Commit 7809effd authored by Ewa Tusień, committed by Scott Cyphers

[Py] Added Dequantize, Quantize, Quantized Convolution, Quantized Dot… (#3527)

* [Py] Added Dequantize, Quantize, Quantized Convolution, Quantized Dot operators to Python API.

* [Py] Removed unnecessary import.

* [Py] Changed docstring.
Co-Authored-By: Tomasz Socha <tomasz.socha@intel.com>

* [Py] Changed docstring.

* [Py] Changed docstring.

* [Py] Added missed imports.
parent 4954b8d5
......@@ -32,6 +32,7 @@ ngraph.ops
cos
cosh
depth_to_space
dequantize
divide
dot
elu
......@@ -68,6 +69,9 @@ ngraph.ops
power
prelu
prod
quantize
quantized_convolution
quantized_dot
relu
replace_slice
reshape
......
......@@ -2,9 +2,9 @@ ngraph package
==============
.. automodule:: ngraph
:members:
:undoc-members:
:show-inheritance:
:members:
:undoc-members:
:show-inheritance:
Submodules
----------
......@@ -13,24 +13,23 @@ ngraph.exceptions module
------------------------
.. automodule:: ngraph.exceptions
:members:
:undoc-members:
:show-inheritance:
:members:
:undoc-members:
:show-inheritance:
ngraph.ops module
-----------------
.. automodule:: ngraph.ops
:members:
:undoc-members:
:show-inheritance:
:members:
:undoc-members:
:show-inheritance:
ngraph.runtime module
---------------------
.. automodule:: ngraph.runtime
:members:
:undoc-members:
:show-inheritance:
:members:
:undoc-members:
:show-inheritance:
......@@ -45,6 +45,7 @@ from ngraph.ops import convolution_backprop_data
from ngraph.ops import cos
from ngraph.ops import cosh
from ngraph.ops import depth_to_space
from ngraph.ops import dequantize
from ngraph.ops import divide
from ngraph.ops import dot
from ngraph.ops import elu
......@@ -81,6 +82,9 @@ from ngraph.ops import parameter
from ngraph.ops import power
from ngraph.ops import prod
from ngraph.ops import prelu
from ngraph.ops import quantize
from ngraph.ops import quantized_convolution
from ngraph.ops import quantized_dot
from ngraph.ops import relu
from ngraph.ops import replace_slice
from ngraph.ops import reshape
......
......@@ -69,6 +69,7 @@ from _pyngraph.op import ConvolutionBackpropFilters
from _pyngraph.op import Cos
from _pyngraph.op import Cosh
from _pyngraph.op import DepthToSpace
from _pyngraph.op import Dequantize
from _pyngraph.op import Divide
from _pyngraph.op import Dot
from _pyngraph.op import Elu
......@@ -106,6 +107,9 @@ from _pyngraph.op import Parameter
from _pyngraph.op import Power
from _pyngraph.op import PRelu
from _pyngraph.op import Product
from _pyngraph.op import Quantize
from _pyngraph.op import QuantizedConvolution
from _pyngraph.op import QuantizedDot
from _pyngraph.op import Relu
from _pyngraph.op import ReluBackprop
from _pyngraph.op import ReplaceSlice
......
......@@ -22,10 +22,11 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \
Convolution, ConvolutionBackpropData, Cos, Cosh, DepthToSpace, Divide, Dot, Elu, FakeQuantize, \
Equal, Exp, Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, HardSigmoid, Less, \
LessEq, Log, LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, MVN, Negative, Not, NotEqual, \
OneHot, Or, Pad, Parameter, Product, Power, PRelu, Relu, ReplaceSlice, Reshape, Reverse, \
Convolution, ConvolutionBackpropData, Cos, Cosh, DepthToSpace, Dequantize, Divide, Dot, Elu, \
FakeQuantize, Equal, Exp, Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, \
HardSigmoid, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, MVN, \
Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, Power, Quantize, \
QuantizedConvolution, QuantizedDot, PRelu, Relu, ReplaceSlice, Reshape, Reverse, \
ScaleShift, Select, ShuffleChannels, Sign, Sin, Sinh, Slice, Softmax, SpaceToDepth, Sqrt, \
SquaredDifference, Squeeze, Subtract, Sum, Tan, Tanh, TopK, Unsqueeze
......@@ -243,6 +244,190 @@ def mvn(data, axes, normalize_variance, eps, name=None):
return MVN(data, AxisSet(axes), normalize_variance, eps)
@nameable_op
def quantize(data, scale, zero_point, new_type, axes, round_mode, name=None):
# type: (Node, Node, Node, NumericType, Set[int], Quantize.RoundMode, str) -> Node
r"""Perform quantize operation on data from input node.
Computes quantize on the input tensor:
.. math:: output = ROUND((input / scale) + zero\_point)
:param data: The node with data tensor.
:param scale: Scale used for mapping.
:param zero_point: Zero point used for mapping.
:param new_type: Output element type.
:param axes: Set of axes for channel wise quantization.
:param round_mode: Rounding mode that specifies how the ROUND function is performed.
ROUND_NEAREST_TOWARD_INFINITY: Round to nearest integer. In case of two
equidistant integers round away from zero e.g. 2.5 -> 3, -3.5 -> -4
ROUND_NEAREST_TOWARD_ZERO: Round to nearest integer. In case of two equidistant
integers round toward zero e.g. 2.5 -> 2, -3.5 -> -3
ROUND_NEAREST_UPWARD: Round to nearest integer. In case of two equidistant
integers round up e.g. 2.5 -> 3, -3.5 -> -3
ROUND_NEAREST_DOWNWARD: Round to nearest integer. In case of two equidistant
integers round down e.g. 2.5 -> 2, -3.5 -> -4
ROUND_NEAREST_TOWARD_EVEN: Round to nearest integer. In case of two equidistant
integers round to the nearest even integer e.g. 2.5 -> 2, -3.5 -> -4
ROUND_TOWARD_INFINITY: Round to nearest integer away from zero.
ROUND_TOWARD_ZERO: Round to nearest integer toward zero.
ROUND_UP: Round to nearest integer toward infinity (ceiling).
ROUND_DOWN: Round to nearest integer toward negative infinity (floor).
:param name: Optional output node name.
:return: The new node performing a quantize operation on input tensor.
"""
new_element_type = get_element_type(new_type)
return Quantize(data,
scale,
zero_point,
new_element_type,
AxisSet(axes),
round_mode)
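For reference, a minimal NumPy sketch of the formula above (illustrative only, not part of this change; the helper name is hypothetical). It mirrors ROUND_NEAREST_TOWARD_INFINITY tie handling and the saturation to uint8 seen in the quantize test at the end of this diff:

import numpy as np

def quantize_reference(data, scale, zero_point):
    # output = ROUND((input / scale) + zero_point), ties rounded away from zero,
    # then saturated to the uint8 range.
    scaled = data / scale + zero_point
    rounded = np.sign(scaled) * np.floor(np.abs(scaled) + 0.5)
    return np.clip(rounded, 0, 255).astype(np.uint8)

quantize_reference(np.array([0., 2., 3., 1000.], dtype=np.float32), 2.0, 128)
# -> [128, 129, 130, 255], matching the test values below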
@nameable_op
def dequantize(data, scale, zero_point, element_type, axes, name=None):
# type: (Node, Node, Node, NumericType, Set[int], str) -> Node
r"""Perform dequantize operation on data from input node.
Computes dequantize on the input tensor:
.. math:: output = (input - zero\_point) * scale
:param data: The node with data tensor.
:param scale: Scale used for mapping.
:param zero_point: Zero point used for mapping.
:param element_type: Output element type.
:param axes: Set of axes for channel wise quantization.
:param name: Optional output node name.
:return: The new node performing a dequantize operation on input tensor.
"""
new_element_type = get_element_type(element_type)
return Dequantize(data, scale, zero_point, new_element_type, AxisSet(axes))
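And the corresponding dequantize sketch, assuming a scalar (per-tensor) scale and zero point; the values match the dequantize test at the end of this diff, and the helper name is hypothetical:

import numpy as np

def dequantize_reference(data, scale, zero_point):
    # output = (input - zero_point) * scale, computed in float32.
    return (data.astype(np.float32) - np.float32(zero_point)) * np.float32(scale)

dequantize_reference(np.array([1, 2, -1], dtype=np.int8), 2.0, 1)
# -> [ 0.,  2., -4.]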
@nameable_op
def quantized_convolution(data, # type: Node
filters, # type: Node
window_movement_strides, # type: List[int]
window_dilation_strides, # type: List[int]
padding_below, # type: List[int]
padding_above, # type: List[int]
data_dilation_strides, # type: List[int]
input_scale, # type: Node
input_zero_point, # type: Node
filter_scale, # type: Node
filter_zero_point, # type: Node
output_scale, # type: Node
output_zero_point, # type: Node
output_type, # type: NumericType
input_axes, # type: Set[int]
filter_axes, # type: Set[int]
output_axes, # type: Set[int]
name=None, # type: str
):
# type: (...) -> Node
r"""Perform quantized convolution operation on data from input node.
:param data: The node producing the input data batch tensor.
:param filters: The node producing the filters tensor.
:param window_movement_strides: The window movement strides.
:param window_dilation_strides: The window dilation strides.
:param padding_below: The padding-below sizes.
:param padding_above: The padding-above sizes.
:param data_dilation_strides: The data dilation strides.
:param input_scale: Scale to transform the input.
:param input_zero_point: Zero point used for mapping.
:param filter_scale: Scale to transform the filters.
:param filter_zero_point: Zero point used for mapping.
:param output_scale: Scale to transform the output.
:param output_zero_point: Zero point used for mapping.
:param output_type: Output element type.
:param input_axes: Input axes set for channel wise quantization.
:param filter_axes: Filter axes set for channel wise quantization.
:param output_axes: Output axes set for channel wise quantization.
:param name: Optional output node name.
:return: The new node performing a quantized convolution operation on input tensor.
"""
new_output_type = get_element_type(output_type)
return QuantizedConvolution(data,
filters,
Strides(window_movement_strides),
Strides(window_dilation_strides),
CoordinateDiff(padding_below),
CoordinateDiff(padding_above),
Strides(data_dilation_strides),
input_scale,
input_zero_point,
filter_scale,
filter_zero_point,
output_scale,
output_zero_point,
new_output_type,
AxisSet(input_axes),
AxisSet(filter_axes),
AxisSet(output_axes))
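Conceptually, quantized_convolution is equivalent to dequantizing data and filters, running a regular float convolution, and requantizing the result to output_type. A hedged graph-level sketch of that equivalence using the ops defined in this file (scalar scales/zero points, default strides/padding, and the round mode are assumptions; the helper name is hypothetical):

import numpy as np
import ngraph as ng
from ngraph.impl.op import Quantize

def quantized_conv_via_float(data, filters,
                             input_scale, input_zero_point,
                             filter_scale, filter_zero_point,
                             output_scale, output_zero_point):
    # Dequantize both operands to float32, convolve, then requantize the result.
    data_f = ng.dequantize(data, input_scale, input_zero_point, np.float32, set())
    filters_f = ng.dequantize(filters, filter_scale, filter_zero_point, np.float32, set())
    conv = ng.convolution(data_f, filters_f)
    return ng.quantize(conv, output_scale, output_zero_point, np.int32, set(),
                       Quantize.RoundMode.ROUND_NEAREST_TOWARD_EVEN)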
@nameable_op
def quantized_dot(input0, # type: Node
input1, # type: Node
reduction_axes_count, # type: int
input0_scale, # type: Node
input0_zero_point, # type: Node
input1_scale, # type: Node
input1_zero_point, # type: Node
output_scale, # type: Node
output_zero_point, # type: Node
output_type, # type: NumericType
input0_axes, # type: Set[int]
input1_axes, # type: Set[int]
output_axes, # type: Set[int]
name=None, # type: str
):
# type: (...) -> Node
r"""Perform quantized dot operation on data from input node.
:param input0: The node producing the first (left-hand) input tensor.
:param input1: The node producing the second (right-hand) input tensor.
:param reduction_axes_count: Number of reduction axes.
:param input0_scale: Scale to transform input0.
:param input0_zero_point: Zero point used for mapping input0.
:param input1_scale: Scale to transform input1.
:param input1_zero_point: Zero point used for mapping input1.
:param output_scale: Scale to transform the output.
:param output_zero_point: Zero point used for mapping.
:param output_type: Output element type.
:param input0_axes: Input0 axes set for channel wise quantization.
:param input1_axes: Input1 axes set for channel wise quantization.
:param output_axes: Output axes set for channel wise quantization.
:param name: Optional output node name.
:return: The new node performing a quantized dot operation on input tensor.
"""
new_output_type = get_element_type(output_type)
return QuantizedDot(input0,
input1,
reduction_axes_count,
input0_scale,
input0_zero_point,
input1_scale,
input1_zero_point,
output_scale,
output_zero_point,
new_output_type,
AxisSet(input0_axes),
AxisSet(input1_axes),
AxisSet(output_axes))
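As a sanity check on how the scales interact, a hedged NumPy walk-through of the quantized_dot test further down in this diff (values copied from that test; not part of this change):

import numpy as np

# Dequantize both inputs, take the dot product in real space, then requantize.
input0_real = (np.array([[2, 3]], dtype=np.float32) - 0) * 2                 # input0_scale = 2, zero point 0
input1_real = (np.array([[0, 2, 4], [1, 3, 5]], dtype=np.float32) - 0) * 1   # input1_scale = 1, zero point 0
real_dot = input0_real.dot(input1_real)                                      # [[ 6., 26., 46.]]
quantized = np.round(real_dot / 2 + 0).astype(np.uint8)                      # output_scale = 2 -> [[ 3, 13, 23]]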
# Unary ops
@unary_op
def absolute(node, name=None): # type: (NodeInput, str) -> Node
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/dequantize.hpp"
#include "pyngraph/ops/dequantize.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Dequantize(py::module m)
{
py::class_<ngraph::op::Dequantize, std::shared_ptr<ngraph::op::Dequantize>, ngraph::op::Op>
dequantize(m, "Dequantize");
dequantize.doc() = "ngraph.impl.op.Dequantize wraps ngraph::op::Dequantize";
dequantize.def(py::init<const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const ngraph::element::Type&,
const ngraph::AxisSet&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Dequantize(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/quantize.hpp"
#include "pyngraph/ops/quantize.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Quantize(py::module m)
{
py::class_<ngraph::op::Quantize, std::shared_ptr<ngraph::op::Quantize>, ngraph::op::Op>
quantize(m, "Quantize");
quantize.doc() = "ngraph.impl.op.Quantize wraps ngraph::op::Quantize";
quantize.def(py::init<const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const ngraph::element::Type&,
const ngraph::AxisSet&,
ngraph::op::Quantize::RoundMode>());
py::enum_<ngraph::op::Quantize::RoundMode>(quantize, "RoundMode", py::arithmetic())
.value("ROUND_NEAREST_TOWARD_INFINITY",
ngraph::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY)
.value("ROUND_NEAREST_TOWARD_ZERO",
ngraph::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_ZERO)
.value("ROUND_NEAREST_UPWARD", ngraph::op::Quantize::RoundMode::ROUND_NEAREST_UPWARD)
.value("ROUND_NEAREST_DOWNWARD", ngraph::op::Quantize::RoundMode::ROUND_NEAREST_DOWNWARD)
.value("ROUND_NEAREST_TOWARD_EVEN",
ngraph::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)
.value("ROUND_TOWARD_INFINITY", ngraph::op::Quantize::RoundMode::ROUND_TOWARD_INFINITY)
.value("ROUND_TOWARD_ZERO", ngraph::op::Quantize::RoundMode::ROUND_TOWARD_ZERO)
.value("ROUND_UP", ngraph::op::Quantize::RoundMode::ROUND_UP)
.value("ROUND_DOWN", ngraph::op::Quantize::RoundMode::ROUND_DOWN)
.export_values();
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Quantize(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/quantized_convolution.hpp"
#include "pyngraph/ops/quantized_convolution.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_QuantizedConvolution(py::module m)
{
py::class_<ngraph::op::QuantizedConvolution,
std::shared_ptr<ngraph::op::QuantizedConvolution>,
ngraph::op::Op>
quantizedconvolution(m, "QuantizedConvolution");
quantizedconvolution.doc() =
"ngraph.impl.op.QuantizedConvolution wraps ngraph::op::QuantizedConvolution";
quantizedconvolution.def(py::init<const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const ngraph::Strides&,
const ngraph::Strides&,
const ngraph::CoordinateDiff&,
const ngraph::CoordinateDiff&,
const ngraph::Strides&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const ngraph::element::Type&,
const ngraph::AxisSet&,
const ngraph::AxisSet&,
const ngraph::AxisSet&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_QuantizedConvolution(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/quantized_dot.hpp"
#include "pyngraph/ops/quantized_dot.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_QuantizedDot(py::module m)
{
py::class_<ngraph::op::QuantizedDot, std::shared_ptr<ngraph::op::QuantizedDot>, ngraph::op::Op>
quantizeddot(m, "QuantizedDot");
quantizeddot.doc() = "ngraph.impl.op.QuantizedDot wraps ngraph::op::QuantizedDot";
quantizeddot.def(py::init<const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const int,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const ngraph::element::Type&,
const ngraph::AxisSet&,
const ngraph::AxisSet&,
const ngraph::AxisSet&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_QuantizedDot(py::module m);
......@@ -49,6 +49,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Cos(m_op);
regclass_pyngraph_op_Cosh(m_op);
regclass_pyngraph_op_DepthToSpace(m_op);
regclass_pyngraph_op_Dequantize(m_op);
regclass_pyngraph_op_Divide(m_op);
regclass_pyngraph_op_Dot(m_op);
regclass_pyngraph_op_Elu(m_op);
......@@ -86,6 +87,9 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Power(m_op);
regclass_pyngraph_op_PRelu(m_op);
regclass_pyngraph_op_Product(m_op);
regclass_pyngraph_op_Quantize(m_op);
regclass_pyngraph_op_QuantizedConvolution(m_op);
regclass_pyngraph_op_QuantizedDot(m_op);
regclass_pyngraph_op_Relu(m_op);
regclass_pyngraph_op_ReluBackprop(m_op);
regclass_pyngraph_op_ReplaceSlice(m_op);
......
......@@ -37,6 +37,7 @@
#include "pyngraph/ops/convolution.hpp"
#include "pyngraph/ops/cos.hpp"
#include "pyngraph/ops/cosh.hpp"
#include "pyngraph/ops/dequantize.hpp"
#include "pyngraph/ops/divide.hpp"
#include "pyngraph/ops/dot.hpp"
#include "pyngraph/ops/equal.hpp"
......@@ -81,6 +82,9 @@
#include "pyngraph/ops/passthrough.hpp"
#include "pyngraph/ops/power.hpp"
#include "pyngraph/ops/product.hpp"
#include "pyngraph/ops/quantize.hpp"
#include "pyngraph/ops/quantized_convolution.hpp"
#include "pyngraph/ops/quantized_dot.hpp"
#include "pyngraph/ops/relu.hpp"
#include "pyngraph/ops/replace_slice.hpp"
#include "pyngraph/ops/reshape.hpp"
......
......@@ -179,6 +179,7 @@ sources = [
'pyngraph/ops/cosh.cpp',
'pyngraph/ops/ceiling.cpp',
'pyngraph/ops/fused/depth_to_space.cpp',
'pyngraph/ops/dequantize.cpp',
'pyngraph/ops/divide.cpp',
'pyngraph/ops/dot.cpp',
'pyngraph/ops/fused/elu.cpp',
......@@ -214,6 +215,9 @@ sources = [
'pyngraph/ops/passthrough.cpp',
'pyngraph/ops/power.cpp',
'pyngraph/ops/fused/prelu.cpp',
'pyngraph/ops/quantize.cpp',
'pyngraph/ops/quantized_convolution.cpp',
'pyngraph/ops/quantized_dot.cpp',
'pyngraph/ops/regmodule_pyngraph_op.cpp',
'pyngraph/ops/relu.cpp',
'pyngraph/ops/replace_slice.cpp',
......
# ******************************************************************************
# Copyright 2017-2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import numpy as np
import ngraph as ng
from test.ngraph.util import get_runtime
from ngraph.impl.op import Quantize
def test_quantize_operator():
runtime = get_runtime()
data_shape = [6]
scale_shape = []
zero_point_shape = []
data_value = np.array([0, 2, 3, 1000, -254, -1000]).astype(np.float32)
scale_value = np.float32(2)
zero_point_value = np.uint8(128)
new_type = np.uint8
axis_set = []
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
parameter_scale = ng.parameter(scale_shape, name='Scale', dtype=np.float32)
parameter_zero_point = ng.parameter(zero_point_shape, name='Zero_Point', dtype=np.uint8)
model = ng.quantize(parameter_data,
parameter_scale,
parameter_zero_point,
new_type,
axis_set,
Quantize.RoundMode.ROUND_NEAREST_TOWARD_INFINITY)
computation = runtime.computation(model,
parameter_data,
parameter_scale,
parameter_zero_point)
result = computation(data_value, scale_value, zero_point_value)
expected = np.array([128, 129, 130, 255, 1, 0]).astype(np.uint8)
assert np.allclose(result, expected)
def test_quantized_convolution_operator():
runtime = get_runtime()
data_shape = [1, 1, 3, 4]
filters_shape = [1, 1, 3, 3]
result_shape = [1, 1, 3, 4]
shape = []
data_value = np.array([1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4]).astype(np.uint8).reshape(data_shape)
filters_value = np.array([1, 2, 3, 4, 5, 0, 0, 1, 2]).astype(np.uint8).reshape(filters_shape)
window_movement_strides = [1, 1]
window_dilation_strides = [1, 1]
padding_below = [1, 1]
padding_above = [1, 1]
data_dilation_strides = [1, 1]
input_scale_value = 1
input_zero_point_value = 0
filter_scale_value = 1
filter_zero_point_value = 0
output_scale_value = 1
output_zero_point_value = 0
output_type = np.int32
input_axes = []
filter_axes = []
output_axes = []
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.uint8)
parameter_filters = ng.parameter(filters_shape, name='Filters', dtype=np.uint8)
parameter_input_scale = ng.parameter(shape, name='Input_scale', dtype=np.float32)
parameter_input_zero_point = ng.parameter(shape, name='Input_zero_point', dtype=np.uint8)
parameter_filter_scale = ng.parameter(shape, name='Filter_scale', dtype=np.float32)
parameter_filter_zero_point = ng.parameter(shape, name='Filter_zero_point', dtype=np.uint8)
parameter_output_scale = ng.parameter(shape, name='Output_scale', dtype=np.float32)
parameter_output_zero_point = ng.parameter(shape, name='Output_zero_point', dtype=np.int32)
model = ng.quantized_convolution(parameter_data,
parameter_filters,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
parameter_input_scale,
parameter_input_zero_point,
parameter_filter_scale,
parameter_filter_zero_point,
parameter_output_scale,
parameter_output_zero_point,
output_type,
input_axes,
filter_axes,
output_axes)
computation = runtime.computation(model,
parameter_data,
parameter_filters,
parameter_input_scale,
parameter_input_zero_point,
parameter_filter_scale,
parameter_filter_zero_point,
parameter_output_scale,
parameter_output_zero_point)
result = computation(data_value,
filters_value,
input_scale_value,
input_zero_point_value,
filter_scale_value,
filter_zero_point_value,
output_scale_value,
output_zero_point_value)
expected = np.array([22, 34, 30, 32, 38, 72,
90, 43, 33, 52, 43, 39]).astype(np.int32).reshape(result_shape)
assert np.allclose(result, expected)
def test_quantized_dot_operator():
runtime = get_runtime()
input0_shape = [1, 2]
input1_shape = [2, 3]
result_shape = [1, 3]
shape = []
input0_value = np.array([2, 3]).astype(np.uint8).reshape(input0_shape)
input1_value = np.array([0, 2, 4, 1, 3, 5]).astype(np.uint8).reshape(input1_shape)
reduction_axes_count = 1
input0_scale_value = 2
input0_zero_point_value = 0
input1_scale_value = 1
input1_zero_point_value = 0
output_scale_value = 2
output_zero_point_value = 0
output_type = np.uint8
input0_axes = []
input1_axes = []
output_axes = []
parameter_input0 = ng.parameter(input0_shape, name='Input0', dtype=np.uint8)
parameter_input1 = ng.parameter(input1_shape, name='Input1', dtype=np.uint8)
parameter_input0_scale = ng.parameter(shape, name='Input0_scale', dtype=np.float32)
parameter_input0_zero_point = ng.parameter(shape, name='Input0_zero_point', dtype=np.uint8)
parameter_input1_scale = ng.parameter(shape, name='Input1_scale', dtype=np.float32)
parameter_input1_zero_point = ng.parameter(shape, name='Input1_zero_point', dtype=np.uint8)
parameter_output_scale = ng.parameter(shape, name='Output_scale', dtype=np.float32)
parameter_output_zero_point = ng.parameter(shape, name='Output_zero_point', dtype=np.uint8)
model = ng.quantized_dot(parameter_input0,
parameter_input1,
reduction_axes_count,
parameter_input0_scale,
parameter_input0_zero_point,
parameter_input1_scale,
parameter_input1_zero_point,
parameter_output_scale,
parameter_output_zero_point,
output_type,
input0_axes,
input1_axes,
output_axes)
computation = runtime.computation(model,
parameter_input0,
parameter_input1,
parameter_input0_scale,
parameter_input0_zero_point,
parameter_input1_scale,
parameter_input1_zero_point,
parameter_output_scale,
parameter_output_zero_point)
result = computation(input0_value,
input1_value,
input0_scale_value,
input0_zero_point_value,
input1_scale_value,
input1_zero_point_value,
output_scale_value,
output_zero_point_value)
expected = np.array([3, 13, 23]).astype(np.uint8).reshape(result_shape)
assert np.allclose(result, expected)
def test_dequantize_operator():
runtime = get_runtime()
data_shape = [4, 3]
scale_shape = []
zero_point_shape = []
result_shape = [4, 3]
data_value = np.array([1, 1, 2, -1, 3, -1,
4, -3, 5, -3, 6, -5]).astype(np.int8).reshape(data_shape)
scale_value = np.float32(2)
zero_point_value = np.int8(1)
element_type = np.float32
axis_set = []
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.int8)
parameter_scale = ng.parameter(scale_shape, name='Scale', dtype=np.float32)
parameter_zero_point = ng.parameter(zero_point_shape, name='Zero_Point', dtype=np.int8)
model = ng.dequantize(parameter_data,
parameter_scale,
parameter_zero_point,
element_type,
axis_set)
computation = runtime.computation(model,
parameter_data,
parameter_scale,
parameter_zero_point)
result = computation(data_value, scale_value, zero_point_value)
expected = np.array([0, 0, 2, -4, 4, -4,
6, -8, 8, -8, 10, -12]).astype(np.float32).reshape(result_shape)
assert np.allclose(result, expected)