Commit 7c32240c authored by Ewa Tusień, committed by Michał Karzyński

[Py] Added fake_quantize operator to Python API. (#3364)

parent b2e436cd
@@ -37,6 +37,7 @@ ngraph.ops
    elu
    equal
    exp
    fake_quantize
    floor
    gelu
    gemm
@@ -50,6 +50,7 @@ from ngraph.ops import dot
from ngraph.ops import elu
from ngraph.ops import equal
from ngraph.ops import exp
from ngraph.ops import fake_quantize
from ngraph.ops import floor
from ngraph.ops import gelu
from ngraph.ops import gemm
@@ -74,6 +74,7 @@ from _pyngraph.op import Dot
from _pyngraph.op import Elu
from _pyngraph.op import Equal
from _pyngraph.op import Exp
from _pyngraph.op import FakeQuantize
from _pyngraph.op import Floor
from _pyngraph.op import Gelu
from _pyngraph.op import Gemm
@@ -22,9 +22,9 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Function
from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
    BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \
    Convolution, ConvolutionBackpropData, Cos, Cosh, DepthToSpace, Divide, Dot, Elu, FakeQuantize, \
    Equal, Exp, Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, Less, LessEq, Log, \
    LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, \
    Parameter, Product, Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, \
    Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh, TopK
@@ -537,6 +537,38 @@ def broadcast_to(node, new_shape, axis=None, name=None):
@nameable_op
def fake_quantize(data, input_low, input_high, output_low, output_high, levels, name=None):
    # type: (Node, Node, Node, Node, Node, int, str) -> Node
    r"""Perform an element-wise linear quantization on input data.

    Input floating point values are quantized into a discrete set of floating point values.

    .. code-block:: python

        if x <= input_low:
            output = output_low
        elif x > input_high:
            output = output_high
        else:
            output = quantize(x)  # linear interpolation, see the formula below

    Fake quantize uses the following logic:

    .. math:: output =
            \dfrac{round( \dfrac{(data - input\_low) \cdot (levels-1)}{input\_high - input\_low})}
            {levels-1} \cdot (output\_high - output\_low) + output\_low

    :param data: The node with data tensor.
    :param input_low: The node with the minimum for input values.
    :param input_high: The node with the maximum for input values.
    :param output_low: The node with the minimum quantized value.
    :param output_high: The node with the maximum quantized value.
    :param levels: The number of quantization levels. Integer value.
    :param name: Optional name for the output node.
    :return: New node with quantized value.
    """
    return FakeQuantize(data, input_low, input_high, output_low, output_high, levels)
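As a sanity check of the formula above, here is a minimal plain-numpy sketch (an illustrative addition, not part of this commit; the helper name is ad hoc, and midpoint rounding may differ from nGraph's) of the same quantization logic:

import numpy as np

def fake_quantize_reference(data, input_low, input_high, output_low, output_high, levels):
    # Saturate to the input range, then snap to one of `levels` evenly spaced outputs.
    data = np.clip(data, input_low, input_high)
    bucket = np.round((data - input_low) * (levels - 1) / (input_high - input_low))
    return bucket / (levels - 1) * (output_high - output_low) + output_low

# 5 levels over input range [0, 8], output range [0, 1]:
print(fake_quantize_reference(np.array([0.0, 2.0, 3.0, 8.0]), 0, 8, 0, 1, 5))
# -> [0.   0.25 0.5  1.  ]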
def gemm(A,  # type: Node
         B,  # type: Node
         C,  # type: Node
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ngraph/op/fused/fake_quantize.hpp"
#include "pyngraph/ops/fused/fake_quantize.hpp"

namespace py = pybind11;

void regclass_pyngraph_op_FakeQuantize(py::module m)
{
    py::class_<ngraph::op::FakeQuantize, std::shared_ptr<ngraph::op::FakeQuantize>, ngraph::op::Op>
        fakequantize(m, "FakeQuantize");
    fakequantize.doc() = "ngraph.impl.op.FakeQuantize wraps ngraph::op::FakeQuantize";
    fakequantize.def(py::init<const std::shared_ptr<ngraph::Node>&,
                              const std::shared_ptr<ngraph::Node>&,
                              const std::shared_ptr<ngraph::Node>&,
                              const std::shared_ptr<ngraph::Node>&,
                              const std::shared_ptr<ngraph::Node>&,
                              size_t>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

#include <pybind11/pybind11.h>

namespace py = pybind11;

void regclass_pyngraph_op_FakeQuantize(py::module m);
@@ -54,6 +54,7 @@ void regmodule_pyngraph_op(py::module m_op)
    regclass_pyngraph_op_Elu(m_op);
    regclass_pyngraph_op_Equal(m_op);
    regclass_pyngraph_op_Exp(m_op);
    regclass_pyngraph_op_FakeQuantize(m_op);
    regclass_pyngraph_op_Floor(m_op);
    regclass_pyngraph_op_Gelu(m_op);
    regclass_pyngraph_op_Gemm(m_op);
@@ -45,6 +45,7 @@
#include "pyngraph/ops/fused/clamp.hpp"
#include "pyngraph/ops/fused/depth_to_space.hpp"
#include "pyngraph/ops/fused/elu.hpp"
#include "pyngraph/ops/fused/fake_quantize.hpp"
#include "pyngraph/ops/fused/gelu.hpp"
#include "pyngraph/ops/fused/gemm.hpp"
#include "pyngraph/ops/fused/grn.hpp"
@@ -184,6 +184,7 @@ sources = [
    'pyngraph/ops/fused/elu.cpp',
    'pyngraph/ops/equal.cpp',
    'pyngraph/ops/exp.cpp',
    'pyngraph/ops/fused/fake_quantize.cpp',
    'pyngraph/ops/floor.cpp',
    'pyngraph/ops/fused/gelu.cpp',
    'pyngraph/ops/fused/gemm.cpp',
@@ -69,6 +69,52 @@ def test_elu_operator_with_scalar():
    assert np.allclose(result, expected)

def test_fake_quantize():
    runtime = get_runtime()

    data_value = np.arange(24.0, dtype=np.float32).reshape(1, 2, 3, 4)
    input_low_value = np.float32(0)
    input_high_value = np.float32(23)
    output_low_value = np.float32(2)
    output_high_value = np.float32(16)
    levels = 4  # integer, matching the fake_quantize signature

    data_shape = [1, 2, 3, 4]
    bound_shape = []
    parameter_data = ng.parameter(data_shape, name='data', dtype=np.float32)
    parameter_input_low = ng.parameter(bound_shape, name='input_low', dtype=np.float32)
    parameter_input_high = ng.parameter(bound_shape, name='input_high', dtype=np.float32)
    parameter_output_low = ng.parameter(bound_shape, name='output_low', dtype=np.float32)
    parameter_output_high = ng.parameter(bound_shape, name='output_high', dtype=np.float32)

    model = ng.fake_quantize(parameter_data,
                             parameter_input_low,
                             parameter_input_high,
                             parameter_output_low,
                             parameter_output_high,
                             levels)
    computation = runtime.computation(model,
                                      parameter_data,
                                      parameter_input_low,
                                      parameter_input_high,
                                      parameter_output_low,
                                      parameter_output_high)

    result = computation(data_value,
                         input_low_value,
                         input_high_value,
                         output_low_value,
                         output_high_value)

    expected = np.array([[[[2., 2., 2., 2.],
                           [6.6666669, 6.6666669, 6.6666669, 6.6666669],
                           [6.6666669, 6.6666669, 6.6666669, 6.6666669]],
                          [[11.33333301, 11.33333301, 11.33333301, 11.33333301],
                           [11.33333301, 11.33333301, 11.33333301, 11.33333301],
                           [16., 16., 16., 16.]]]], dtype=np.float32)
    assert np.allclose(result, expected)
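As a quick sanity check of the constants above: with input range [0, 23] and levels = 4, round(x · 3 / 23) sends x = 0..3 to bucket 0, x = 4..11 to bucket 1, x = 12..19 to bucket 2 and x = 20..23 to bucket 3; mapping those buckets into the output range [2, 16] gives 2, 2 + 14/3 ≈ 6.6666669, 2 + 28/3 ≈ 11.333333 and 16, exactly the four values in the expected tensor.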

def test_depth_to_space():
    runtime = get_runtime()

@@ -219,4 +265,5 @@ def test_grn_operator():
                          [[0.9970545, 0.98994946, 0.9805807, 0.97014254],
                           [0.9593655, 0.9486833, 0.9383431, 0.9284767],
                           [0.91914505, 0.9103665, 0.9021342, 0.8944272]]]], dtype=np.float32)
    assert np.allclose(result, expected)