Commit 7c32240c authored by Ewa Tusień's avatar Ewa Tusień Committed by Michał Karzyński

[Py] Added fake_quantize operator to Python API. (#3364)

parent b2e436cd
......@@ -37,6 +37,7 @@ ngraph.ops
elu
equal
exp
fake_quantize
floor
gelu
gemm
......
......@@ -50,6 +50,7 @@ from ngraph.ops import dot
from ngraph.ops import elu
from ngraph.ops import equal
from ngraph.ops import exp
from ngraph.ops import fake_quantize
from ngraph.ops import floor
from ngraph.ops import gelu
from ngraph.ops import gemm
......
......@@ -74,6 +74,7 @@ from _pyngraph.op import Dot
from _pyngraph.op import Elu
from _pyngraph.op import Equal
from _pyngraph.op import Exp
from _pyngraph.op import FakeQuantize
from _pyngraph.op import Floor
from _pyngraph.op import Gelu
from _pyngraph.op import Gemm
......
......@@ -22,9 +22,9 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \
-    Convolution, ConvolutionBackpropData, Cos, Cosh, DepthToSpace, Divide, Dot, Elu, Equal, Exp, \
-    Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, Less, LessEq, Log, LRN, Max, \
-    Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, \
+    Convolution, ConvolutionBackpropData, Cos, Cosh, DepthToSpace, Divide, Dot, Elu, FakeQuantize, \
+    Equal, Exp, Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, Less, LessEq, Log, \
+    LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, \
Parameter, Product, Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, \
Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh, TopK
......@@ -537,6 +537,38 @@ def broadcast_to(node, new_shape, axis=None, name=None):
@nameable_op
def fake_quantize(data, input_low, input_high, output_low, output_high, levels, name=None):
# type: (Node, Node, Node, Node, Node, int, str) -> Node
r"""Perform an element-wise linear quantization on input data.
Input floating point values are quantized into a discrete set of floating point values.
.. code-block:: python
if x <= input_low:
output = output_low
if x > input_high:
output = output_high
else:
output = fake_quantize(output)
Fake quantize uses the following logic:
.. math:: output =
\dfrac{round( \dfrac{data - input\_low}{(input\_high - input\_low)\cdot (levels-1)})}
{(levels-1)\cdot (output\_high - output\_low)} + output\_low
:param data: The node with data tensor.
:param input_low: The node with the minimum for input values.
:param input_high: The node with the maximum for input values.
:param output_low: The node with the minimum quantized value.
:param output_high: The node with the maximum quantized value.
:param levels: The number of quantization levels. Integer value.
:return: New node with quantized value.
"""
return FakeQuantize(data, input_low, input_high, output_low, output_high, levels)
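For a quick sanity check of the formula, the same arithmetic can be written in a few lines of NumPy. This is a hypothetical reference helper for illustration only, not part of the nGraph API:

import numpy as np

def fake_quantize_reference(x, input_low, input_high, output_low, output_high, levels):
    # Saturate to the input range, normalize to [0, 1], snap to one of
    # `levels` evenly spaced steps, then rescale to the output range.
    x = np.clip(x, input_low, input_high)
    q = np.round((x - input_low) / (input_high - input_low) * (levels - 1))
    return q / (levels - 1) * (output_high - output_low) + output_low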
def gemm(A, # type: Node
B, # type: Node
C, # type: Node
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/fake_quantize.hpp"
#include "pyngraph/ops/fused/fake_quantize.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_FakeQuantize(py::module m)
{
py::class_<ngraph::op::FakeQuantize, std::shared_ptr<ngraph::op::FakeQuantize>, ngraph::op::Op>
fakequantize(m, "FakeQuantize");
fakequantize.doc() = "ngraph.impl.op.FakeQuantize wraps ngraph::op::FakeQuantize";
fakequantize.def(py::init<const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
int&>());
}
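For context, once registered the class is reachable from Python through ngraph.impl.op. A minimal sketch, assuming the existing Parameter, Type and Shape bindings (the high-level ng.fake_quantize wrapper added above is the intended entry point):

from ngraph.impl import Shape, Type
from ngraph.impl.op import FakeQuantize, Parameter

data = Parameter(Type.f32, Shape([1, 2, 3, 4]))
bounds = [Parameter(Type.f32, Shape([])) for _ in range(4)]
node = FakeQuantize(data, bounds[0], bounds[1], bounds[2], bounds[3], 4)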
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_FakeQuantize(py::module m);
......@@ -54,6 +54,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Elu(m_op);
regclass_pyngraph_op_Equal(m_op);
regclass_pyngraph_op_Exp(m_op);
regclass_pyngraph_op_FakeQuantize(m_op);
regclass_pyngraph_op_Floor(m_op);
regclass_pyngraph_op_Gelu(m_op);
regclass_pyngraph_op_Gemm(m_op);
......
......@@ -45,6 +45,7 @@
#include "pyngraph/ops/fused/clamp.hpp"
#include "pyngraph/ops/fused/depth_to_space.hpp"
#include "pyngraph/ops/fused/elu.hpp"
#include "pyngraph/ops/fused/fake_quantize.hpp"
#include "pyngraph/ops/fused/gelu.hpp"
#include "pyngraph/ops/fused/gemm.hpp"
#include "pyngraph/ops/fused/grn.hpp"
......
......@@ -184,6 +184,7 @@ sources = [
'pyngraph/ops/fused/elu.cpp',
'pyngraph/ops/equal.cpp',
'pyngraph/ops/exp.cpp',
'pyngraph/ops/fused/fake_quantize.cpp',
'pyngraph/ops/floor.cpp',
'pyngraph/ops/fused/gelu.cpp',
'pyngraph/ops/fused/gemm.cpp',
......
......@@ -69,6 +69,52 @@ def test_elu_operator_with_scalar():
assert np.allclose(result, expected)
def test_fake_quantize():
runtime = get_runtime()
data_value = np.arange(24.0, dtype=np.float32).reshape(1, 2, 3, 4)
input_low_value = np.float32(0)
input_high_value = np.float32(23)
output_low_value = np.float32(2)
output_high_value = np.float32(16)
levels = 4
data_shape = [1, 2, 3, 4]
bound_shape = []
parameter_data = ng.parameter(data_shape, name='data', dtype=np.float32)
parameter_input_low = ng.parameter(bound_shape, name='input_low', dtype=np.float32)
parameter_input_high = ng.parameter(bound_shape, name='input_high', dtype=np.float32)
parameter_output_low = ng.parameter(bound_shape, name='output_low', dtype=np.float32)
parameter_output_high = ng.parameter(bound_shape, name='output_high', dtype=np.float32)
model = ng.fake_quantize(parameter_data,
parameter_input_low,
parameter_input_high,
parameter_output_low,
parameter_output_high,
levels)
computation = runtime.computation(model,
parameter_data,
parameter_input_low,
parameter_input_high,
parameter_output_low,
parameter_output_high)
result = computation(data_value,
input_low_value,
input_high_value,
output_low_value,
output_high_value)
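# With levels=4 the output can only take the 4 evenly spaced values
# output_low + k * (output_high - output_low) / (levels - 1) for k = 0..3,
# i.e. 2.0, 6.6667, 11.3333 and 16.0.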
expected = np.array([[[[2., 2., 2., 2.],
                       [6.6666669, 6.6666669, 6.6666669, 6.6666669],
                       [6.6666669, 6.6666669, 6.6666669, 6.6666669]],
                      [[11.33333301, 11.33333301, 11.33333301, 11.33333301],
                       [11.33333301, 11.33333301, 11.33333301, 11.33333301],
                       [16., 16., 16., 16.]]]], dtype=np.float32)
assert np.allclose(result, expected)
def test_depth_to_space():
runtime = get_runtime()
......@@ -219,4 +265,5 @@ def test_grn_operator():
[[0.9970545, 0.98994946, 0.9805807, 0.97014254],
[0.9593655, 0.9486833, 0.9383431, 0.9284767],
[0.91914505, 0.9103665, 0.9021342, 0.8944272]]]], dtype=np.float32)
assert np.allclose(result, expected)