Commit 41f7cb5f authored by Adam Straw, committed by Scott Cyphers

Quantization API cleanup (non breaking) (#2873)

* quantization cleanup

* offset changed to zero point

* fix failing tests

* code style

* code style
parent eb83f267
...@@ -13,33 +13,33 @@ Description ...@@ -13,33 +13,33 @@ Description
Produces a tensor of element type ``type`` and the same shape as ``input`` Produces a tensor of element type ``type`` and the same shape as ``input``
where the value of each coordinate :math:`i` of ``output`` is the corresponding coordinate of where the value of each coordinate :math:`i` of ``output`` is the corresponding coordinate of
``input`` minus ``offset`` quantity multiplied by ``scale``. ``input`` minus ``zero_point`` quantity multiplied by ``scale``.
The coordinate :math:`j` of ``scale`` and ``offset`` is the coordinate of ``output`` The coordinate :math:`j` of ``scale`` and ``zero_point`` is the coordinate of ``output``
projected onto ``axes``. projected onto ``axes``.
Inputs Inputs
------ ------
+-----------------+-------------------------+------------------------------------------+ +-----------------+-------------------------+----------------------------------------------+
| Name | Element Type | Shape | | Name | Element Type | Shape |
+=================+=========================+==========================================+ +=================+=========================+==============================================+
| ``input`` | Any quantized type | Any | | ``input`` | Any quantized type | Any |
+-----------------+-------------------------+------------------------------------------+ +-----------------+-------------------------+----------------------------------------------+
| ``scale`` | Same as ``output`` | ``input`` shape projected onto ``axes`` | | ``scale`` | Same as ``output`` | ``input`` shape projected onto ``axes`` |
+-----------------+-------------------------+------------------------------------------+ +-----------------+-------------------------+----------------------------------------------+
| ``offset`` | Same as ``input`` | ``input`` shape projected onto ``axes`` | | ``zero_point`` | Same as ``input`` | ``input`` shape projected onto ``axes`` |
+-----------------+-------------------------+------------------------------------------+ +-----------------+-------------------------+----------------------------------------------+
Attributes Attributes
---------- ----------
+-------------------------------+----------------------------------------------------------------+ +-------------------------------+--------------------------------------------------------------------+
| Name | Description | | Name | Description |
+===============================+================================================================+ +===============================+====================================================================+
| ``type`` | ``output`` element type; any real type | | ``type`` | ``output`` element type; any real type |
+-------------------------------+----------------------------------------------------------------+ +-------------------------------+--------------------------------------------------------------------+
| ``axes`` | Axis positions on which ``scale`` and ``offset`` are specified | | ``axes`` | Axis positions on which ``scale`` and ``zero_point`` are specified |
+-------------------------------+----------------------------------------------------------------+ +-------------------------------+--------------------------------------------------------------------+
...@@ -59,7 +59,7 @@ Mathematical Definition ...@@ -59,7 +59,7 @@ Mathematical Definition
.. math:: .. math::
\mathtt{output}_{i,j} = (\mathtt{input}_{i,j} - \mathtt{offset}_{j}) \mathtt{scale}_{j} \mathtt{output}_{i,j} = (\mathtt{input}_{i,j} - \mathtt{zero\_point}_{j}) \mathtt{scale}_{j}
C++ Interface C++ Interface
============= =============
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -37,16 +37,16 @@ namespace ngraph ...@@ -37,16 +37,16 @@ namespace ngraph
{ {
namespace quantization namespace quantization
{ {
shared_ptr<Node> QuantizedLinearConvolution(shared_ptr<Node> input, shared_ptr<Node> QuantizedLinearConvolution(const shared_ptr<Node>& input,
shared_ptr<Node> filter, const shared_ptr<Node>& filter,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
shared_ptr<Node> input_scale, const shared_ptr<Node>& input_scale,
shared_ptr<Node> filter_scale, const shared_ptr<Node>& filter_scale,
shared_ptr<Node> output_scale) const shared_ptr<Node>& output_scale)
{ {
// TODO: need to establish cross-nGraph view of scale (mult or div) // TODO: need to establish cross-nGraph view of scale (mult or div)
auto requantization_scale = (input_scale * filter_scale) / output_scale; auto requantization_scale = (input_scale * filter_scale) / output_scale;
...@@ -65,19 +65,19 @@ namespace ngraph ...@@ -65,19 +65,19 @@ namespace ngraph
// need to make this the primary builder which means // need to make this the primary builder which means
// 1) add support for zero point in QuantizeConvolution op API // 1) add support for zero point in QuantizeConvolution op API
// 2) add QuantizedConvolution reference kernel, including zero point // 2) add QuantizedConvolution reference kernel, including zero point
shared_ptr<Node> QuantizedLinearConvolution(shared_ptr<Node> input, shared_ptr<Node> QuantizedLinearConvolution(const shared_ptr<Node>& input,
shared_ptr<Node> filter, const shared_ptr<Node>& filter,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
shared_ptr<Node> input_scale, const shared_ptr<Node>& input_scale,
shared_ptr<Node> input_zero_point, const shared_ptr<Node>& input_zero_point,
shared_ptr<Node> filter_scale, const shared_ptr<Node>& filter_scale,
shared_ptr<Node> filter_zero_point, const shared_ptr<Node>& filter_zero_point,
shared_ptr<Node> output_scale, const shared_ptr<Node>& output_scale,
shared_ptr<Node> output_zero_point) const shared_ptr<Node>& output_zero_point)
{ {
AxisSet axes; AxisSet axes;
...@@ -107,21 +107,22 @@ namespace ngraph ...@@ -107,21 +107,22 @@ namespace ngraph
return q_convolution; return q_convolution;
} }
shared_ptr<Node> QuantizedLinearConvolutionBias(shared_ptr<Node> input, shared_ptr<Node> QuantizedLinearConvolutionBias(const shared_ptr<Node>& input,
shared_ptr<Node> filter, const shared_ptr<Node>& filter,
shared_ptr<Node> bias, const shared_ptr<Node>& bias,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
shared_ptr<Node> input_scale, const shared_ptr<Node>& input_scale,
shared_ptr<Node> filter_scale, const shared_ptr<Node>& filter_scale,
shared_ptr<Node> output_scale) const shared_ptr<Node>& output_scale)
{ {
// TODO: need to establish cross-nGraph view of scale (mult or div) // TODO: need to establish cross-nGraph view of scale (mult or div)
auto requantization_scale = (input_scale * filter_scale) / output_scale; auto requantization_scale = (input_scale * filter_scale) / output_scale;
auto mybias = bias;
if (bias->get_element_type() != element::i32) if (bias->get_element_type() != element::i32)
{ {
auto zero = make_constant(element::i32, input_scale->get_shape(), 0); auto zero = make_constant(element::i32, input_scale->get_shape(), 0);
...@@ -130,13 +131,13 @@ namespace ngraph ...@@ -130,13 +131,13 @@ namespace ngraph
op::Quantize::RoundMode round_mode = op::Quantize::RoundMode round_mode =
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN; op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
bias = make_shared<op::Quantize>( mybias = make_shared<op::Quantize>(
bias, bias_scale, zero, element::i32, quantization_axes, round_mode); bias, bias_scale, zero, element::i32, quantization_axes, round_mode);
} }
return make_shared<op::QuantizedConvolutionBias>(input, return make_shared<op::QuantizedConvolutionBias>(input,
filter, filter,
bias, mybias,
window_movement_strides, window_movement_strides,
window_dilation_strides, window_dilation_strides,
padding_below, padding_below,
...@@ -146,8 +147,8 @@ namespace ngraph ...@@ -146,8 +147,8 @@ namespace ngraph
false); false);
} }
shared_ptr<Node> QuantizedConvInteger(shared_ptr<Node> input, shared_ptr<Node> QuantizedConvInteger(const shared_ptr<Node>& input,
shared_ptr<Node> filter, const shared_ptr<Node>& filter,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
......
...@@ -25,47 +25,48 @@ namespace ngraph ...@@ -25,47 +25,48 @@ namespace ngraph
{ {
namespace quantization namespace quantization
{ {
std::shared_ptr<Node> QuantizedLinearConvolution(std::shared_ptr<Node> input, std::shared_ptr<Node>
std::shared_ptr<Node> filter, QuantizedLinearConvolution(const std::shared_ptr<Node>& input,
const Strides& window_movement_strides, const std::shared_ptr<Node>& filter,
const Strides& window_dilation_strides, const Strides& window_movement_strides,
const CoordinateDiff& padding_below, const Strides& window_dilation_strides,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_below,
const Strides& data_dilation_strides, const CoordinateDiff& padding_above,
std::shared_ptr<Node> input_scale, const Strides& data_dilation_strides,
std::shared_ptr<Node> filter_scale, const std::shared_ptr<Node>& input_scale,
std::shared_ptr<Node> output_scale); const std::shared_ptr<Node>& filter_scale,
const std::shared_ptr<Node>& output_scale);
std::shared_ptr<Node> std::shared_ptr<Node>
QuantizedLinearConvolution(std::shared_ptr<Node> input, QuantizedLinearConvolution(const std::shared_ptr<Node>& input,
std::shared_ptr<Node> filter, const std::shared_ptr<Node>& filter,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
std::shared_ptr<Node> input_scale, const std::shared_ptr<Node>& input_scale,
std::shared_ptr<Node> input_zero_point, const std::shared_ptr<Node>& input_zero_point,
std::shared_ptr<Node> filter_scale, const std::shared_ptr<Node>& filter_scale,
std::shared_ptr<Node> filter_zero_point, const std::shared_ptr<Node>& filter_zero_point,
std::shared_ptr<Node> output_scale, const std::shared_ptr<Node>& output_scale,
std::shared_ptr<Node> output_zero_point); const std::shared_ptr<Node>& output_zero_point);
std::shared_ptr<Node> std::shared_ptr<Node>
QuantizedLinearConvolutionBias(std::shared_ptr<Node> input, QuantizedLinearConvolutionBias(const std::shared_ptr<Node>& input,
std::shared_ptr<Node> filter, const std::shared_ptr<Node>& filter,
std::shared_ptr<Node> bias, const std::shared_ptr<Node>& bias,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
std::shared_ptr<Node> input_scale, const std::shared_ptr<Node>& input_scale,
std::shared_ptr<Node> filter_scale, const std::shared_ptr<Node>& filter_scale,
std::shared_ptr<Node> output_scale); const std::shared_ptr<Node>& output_scale);
std::shared_ptr<Node> QuantizedConvInteger(std::shared_ptr<Node> input, std::shared_ptr<Node> QuantizedConvInteger(const std::shared_ptr<Node>& input,
std::shared_ptr<Node> filter, const std::shared_ptr<Node>& filter,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
......
...@@ -20,13 +20,13 @@ ...@@ -20,13 +20,13 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Dequantize::Dequantize(shared_ptr<Node> input, op::Dequantize::Dequantize(const shared_ptr<Node>& input,
shared_ptr<Node> scale, const shared_ptr<Node>& scale,
shared_ptr<Node> offset, const shared_ptr<Node>& zero_point,
const element::Type& type, const element::Type& type,
const AxisSet& axes) const AxisSet& axes)
: Op("Dequantize", check_single_output_args({input, scale, offset})) : Op("Dequantize", check_single_output_args({input, scale, zero_point}))
, m_type(type) , m_type(type)
, m_axes(axes) , m_axes(axes)
{ {
...@@ -39,7 +39,7 @@ void op::Dequantize::validate_and_infer_types() ...@@ -39,7 +39,7 @@ void op::Dequantize::validate_and_infer_types()
{ {
INPUT, INPUT,
SCALE, SCALE,
OFFSET ZERO_POINT
}; };
NODE_VALIDATION_CHECK(this, m_type.is_static(), "Output element type must not be dynamic"); NODE_VALIDATION_CHECK(this, m_type.is_static(), "Output element type must not be dynamic");
...@@ -52,16 +52,16 @@ void op::Dequantize::validate_and_infer_types() ...@@ -52,16 +52,16 @@ void op::Dequantize::validate_and_infer_types()
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
element::Type::merge(quantized_type, element::Type::merge(quantized_type,
get_input_element_type(INPUT), get_input_element_type(INPUT),
get_input_element_type(OFFSET)), get_input_element_type(ZERO_POINT)),
"Offset element type (", "Zero point element type (",
get_input_element_type(OFFSET), get_input_element_type(ZERO_POINT),
") must match input element type (", ") must match input element type (",
get_input_element_type(INPUT), get_input_element_type(INPUT),
")"); ")");
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
quantized_type.is_dynamic() || quantized_type.is_quantized(), quantized_type.is_dynamic() || quantized_type.is_quantized(),
"Offset/input element type (", "Zero point / input element type (",
quantized_type, quantized_type,
") must be a quantized type"); ") must be a quantized type");
...@@ -90,21 +90,21 @@ void op::Dequantize::validate_and_infer_types() ...@@ -90,21 +90,21 @@ void op::Dequantize::validate_and_infer_types()
")"); ")");
} }
PartialShape scale_offset_shape = get_input_partial_shape(SCALE); PartialShape scale_zero_point_shape = get_input_partial_shape(SCALE);
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, this,
PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET)), PartialShape::merge_into(scale_zero_point_shape, get_input_partial_shape(ZERO_POINT)),
"Scale shape (", "Scale shape (",
get_input_partial_shape(SCALE), get_input_partial_shape(SCALE),
") and offset shape (", ") and zero point shape (",
get_input_partial_shape(OFFSET), get_input_partial_shape(ZERO_POINT),
") must match"); ") must match");
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
scale_offset_shape.rank().compatible(m_axes.size()), scale_zero_point_shape.rank().compatible(m_axes.size()),
"Scale/offset rank (", "Scale / zero point rank (",
scale_offset_shape.rank(), scale_zero_point_shape.rank(),
") does not match the number of ", ") does not match the number of ",
"quantization axes (", "quantization axes (",
m_axes.size(), m_axes.size(),
...@@ -112,30 +112,30 @@ void op::Dequantize::validate_and_infer_types() ...@@ -112,30 +112,30 @@ void op::Dequantize::validate_and_infer_types()
set_output_size(1); set_output_size(1);
if (input_shape.rank().is_static() && scale_offset_shape.rank().is_static()) if (input_shape.rank().is_static() && scale_zero_point_shape.rank().is_static())
{ {
size_t i = 0; size_t i = 0;
std::vector<Dimension> injected_scale_offset_dims; vector<Dimension> injected_scale_zero_point_dims;
for (size_t j = 0; j < size_t(input_shape.rank()); j++) for (size_t j = 0; j < size_t(input_shape.rank()); j++)
{ {
if (m_axes.count(j) != 0) if (m_axes.count(j) != 0)
{ {
injected_scale_offset_dims.push_back(scale_offset_shape[i++]); injected_scale_zero_point_dims.push_back(scale_zero_point_shape[i++]);
} }
else else
{ {
injected_scale_offset_dims.push_back(Dimension::dynamic()); injected_scale_zero_point_dims.push_back(Dimension::dynamic());
} }
} }
PartialShape result_shape = input_shape; PartialShape result_shape = input_shape;
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, this,
PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims}), PartialShape::merge_into(result_shape, PartialShape{injected_scale_zero_point_dims}),
"Scale/offset shape (", "Scale / zero point shape (",
scale_offset_shape, scale_zero_point_shape,
") must match input shape (", ") must match input shape (",
input_shape, input_shape,
") at the quantization axes (", ") at the quantization axes (",
......
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
namespace op namespace op
{ {
/// \brief Dequantize operation /// \brief Dequantize operation
/// Maps quantized input (q) to real output (r) using scale (s) and offset (o): /// Maps quantized input (q) to real output (r) using scale (s) and zero point (z):
/// r = (q - o) * s /// r = (q - z) * s
class Dequantize : public ngraph::op::Op class Dequantize : public ngraph::op::Op
{ {
...@@ -33,12 +33,12 @@ namespace ngraph ...@@ -33,12 +33,12 @@ namespace ngraph
/// \brief Constructs a Dequantize operation /// \brief Constructs a Dequantize operation
/// \param input quantized input /// \param input quantized input
/// \param scale scale used for mapping /// \param scale scale used for mapping
/// \param offset offset used for mapping /// \param zero_point zero point used for mapping
/// \param type output element type /// \param type output element type
/// \param axes axis positions on which `scale` and `offset` are specified /// \param axes axis positions on which `scale` and `zero_point` are specified
Dequantize(std::shared_ptr<Node> input, Dequantize(const std::shared_ptr<Node>& input,
std::shared_ptr<Node> scale, const std::shared_ptr<Node>& scale,
std::shared_ptr<Node> offset, const std::shared_ptr<Node>& zero_point,
const ngraph::element::Type& type, const ngraph::element::Type& type,
const ngraph::AxisSet& axes); const ngraph::AxisSet& axes);
......
...@@ -30,7 +30,7 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& data_batc ...@@ -30,7 +30,7 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& data_batc
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale, const shared_ptr<Node>& scale,
const bool requantize) const bool requantize)
: Op("QuantizedConvolution", check_single_output_args({data_batch, filters, scale})) : Op("QuantizedConvolution", check_single_output_args({data_batch, filters, scale}))
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
......
...@@ -33,7 +33,7 @@ namespace ngraph ...@@ -33,7 +33,7 @@ namespace ngraph
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale, const std::shared_ptr<Node>& scale,
const bool requantize = true); const bool requantize = true);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; } const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
......
...@@ -34,7 +34,7 @@ op::QuantizedConvolutionBias::QuantizedConvolutionBias(const shared_ptr<Node>& d ...@@ -34,7 +34,7 @@ op::QuantizedConvolutionBias::QuantizedConvolutionBias(const shared_ptr<Node>& d
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale, const shared_ptr<Node>& scale,
const bool with_relu) const bool with_relu)
: Op("QuantizedConvolutionBias", check_single_output_args({data_batch, filters, bias, scale})) : Op("QuantizedConvolutionBias", check_single_output_args({data_batch, filters, bias, scale}))
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
...@@ -100,8 +100,8 @@ op::QuantizedConvolutionBiasAdd::QuantizedConvolutionBiasAdd(const shared_ptr<No ...@@ -100,8 +100,8 @@ op::QuantizedConvolutionBiasAdd::QuantizedConvolutionBiasAdd(const shared_ptr<No
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale, const shared_ptr<Node>& scale,
const std::shared_ptr<Node> sum_scale, const shared_ptr<Node>& sum_scale,
const bool with_relu) const bool with_relu)
: Op("QuantizedConvolutionBiasAdd", : Op("QuantizedConvolutionBiasAdd",
check_single_output_args({data_batch, filters, bias, sum_input, scale, sum_scale})) check_single_output_args({data_batch, filters, bias, sum_input, scale, sum_scale}))
...@@ -172,8 +172,8 @@ op::QuantizedConvolutionBiasSignedAdd::QuantizedConvolutionBiasSignedAdd( ...@@ -172,8 +172,8 @@ op::QuantizedConvolutionBiasSignedAdd::QuantizedConvolutionBiasSignedAdd(
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale, const shared_ptr<Node>& scale,
const std::shared_ptr<Node> sum_scale, const shared_ptr<Node>& sum_scale,
const bool with_relu) const bool with_relu)
: Op("QuantizedConvolutionBiasSignedAdd", : Op("QuantizedConvolutionBiasSignedAdd",
check_single_output_args({data_batch, filters, bias, sum_input, scale, sum_scale})) check_single_output_args({data_batch, filters, bias, sum_input, scale, sum_scale}))
......
...@@ -39,7 +39,7 @@ namespace ngraph ...@@ -39,7 +39,7 @@ namespace ngraph
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale, const std::shared_ptr<Node>& scale,
const bool with_relu = false); const bool with_relu = false);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
...@@ -75,8 +75,8 @@ namespace ngraph ...@@ -75,8 +75,8 @@ namespace ngraph
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale, const std::shared_ptr<Node>& scale,
const std::shared_ptr<Node> sum_scale, const std::shared_ptr<Node>& sum_scale,
const bool with_relu = false); const bool with_relu = false);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
...@@ -112,8 +112,8 @@ namespace ngraph ...@@ -112,8 +112,8 @@ namespace ngraph
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale, const std::shared_ptr<Node>& scale,
const std::shared_ptr<Node> sum_scale, const std::shared_ptr<Node>& sum_scale,
const bool with_relu = false); const bool with_relu = false);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
......
...@@ -24,14 +24,14 @@ ...@@ -24,14 +24,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::QuantizedConvolutionRelu::QuantizedConvolutionRelu(const std::shared_ptr<Node>& data_batch, op::QuantizedConvolutionRelu::QuantizedConvolutionRelu(const shared_ptr<Node>& data_batch,
const std::shared_ptr<Node>& filters, const shared_ptr<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale) const shared_ptr<Node>& scale)
: Op("QuantizedConvolutionRelu", check_single_output_args({data_batch, filters, scale})) : Op("QuantizedConvolutionRelu", check_single_output_args({data_batch, filters, scale}))
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides) , m_window_dilation_strides(window_dilation_strides)
...@@ -63,20 +63,19 @@ op::QuantizedConvolutionRelu::QuantizedConvolutionRelu(const std::shared_ptr<Nod ...@@ -63,20 +63,19 @@ op::QuantizedConvolutionRelu::QuantizedConvolutionRelu(const std::shared_ptr<Nod
)); ));
} }
std::shared_ptr<Node> shared_ptr<Node> op::QuantizedConvolutionRelu::copy_with_new_args(const NodeVector& new_args) const
op::QuantizedConvolutionRelu::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 3) if (new_args.size() != 3)
{ {
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
} }
return std::shared_ptr<Node>(new QuantizedConvolutionRelu(new_args.at(0), return shared_ptr<Node>(new QuantizedConvolutionRelu(new_args.at(0),
new_args.at(1), new_args.at(1),
get_window_movement_strides(), get_window_movement_strides(),
get_window_dilation_strides(), get_window_dilation_strides(),
get_padding_below(), get_padding_below(),
get_padding_above(), get_padding_above(),
get_data_dilation_strides(), get_data_dilation_strides(),
new_args.at(2))); new_args.at(2)));
} }
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale); const std::shared_ptr<Node>& scale);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; } const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
......
...@@ -20,14 +20,14 @@ ...@@ -20,14 +20,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Quantize::Quantize(shared_ptr<Node> input, op::Quantize::Quantize(const shared_ptr<Node>& input,
shared_ptr<Node> scale, const shared_ptr<Node>& scale,
shared_ptr<Node> offset, const shared_ptr<Node>& zero_point,
const element::Type& type, const element::Type& type,
const AxisSet& axes, const AxisSet& axes,
RoundMode round_mode) RoundMode round_mode)
: Op("Quantize", check_single_output_args({input, scale, offset})) : Op("Quantize", check_single_output_args({input, scale, zero_point}))
, m_type(type) , m_type(type)
, m_axes(axes) , m_axes(axes)
, m_round_mode(round_mode) , m_round_mode(round_mode)
...@@ -41,7 +41,7 @@ void op::Quantize::validate_and_infer_types() ...@@ -41,7 +41,7 @@ void op::Quantize::validate_and_infer_types()
{ {
INPUT, INPUT,
SCALE, SCALE,
OFFSET ZERO_POINT
}; };
NODE_VALIDATION_CHECK(this, m_type.is_static(), "Output element type must not be dynamic"); NODE_VALIDATION_CHECK(this, m_type.is_static(), "Output element type must not be dynamic");
...@@ -63,7 +63,7 @@ void op::Quantize::validate_and_infer_types() ...@@ -63,7 +63,7 @@ void op::Quantize::validate_and_infer_types()
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
unquantized_type.is_dynamic() || unquantized_type.is_real(), unquantized_type.is_dynamic() || unquantized_type.is_real(),
"Scale/input element type (", "Scale / input element type (",
unquantized_type, unquantized_type,
") must be a floating point number"); ") must be a floating point number");
...@@ -71,9 +71,9 @@ void op::Quantize::validate_and_infer_types() ...@@ -71,9 +71,9 @@ void op::Quantize::validate_and_infer_types()
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, this,
element::Type::merge(quantized_type, get_input_element_type(OFFSET), m_type), element::Type::merge(quantized_type, get_input_element_type(ZERO_POINT), m_type),
"Offset element type (", "Zero point element type (",
get_input_element_type(OFFSET), get_input_element_type(ZERO_POINT),
") must match output element type (", ") must match output element type (",
m_type, m_type,
")"); ")");
...@@ -92,21 +92,21 @@ void op::Quantize::validate_and_infer_types() ...@@ -92,21 +92,21 @@ void op::Quantize::validate_and_infer_types()
")"); ")");
} }
PartialShape scale_offset_shape = get_input_partial_shape(SCALE); PartialShape scale_zero_point_shape = get_input_partial_shape(SCALE);
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, this,
PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET)), PartialShape::merge_into(scale_zero_point_shape, get_input_partial_shape(ZERO_POINT)),
"Scale shape (", "Scale shape (",
get_input_partial_shape(SCALE), get_input_partial_shape(SCALE),
") and offset shape (", ") and zero point shape (",
get_input_partial_shape(OFFSET), get_input_partial_shape(ZERO_POINT),
") must match"); ") must match");
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
scale_offset_shape.rank().compatible(m_axes.size()), scale_zero_point_shape.rank().compatible(m_axes.size()),
"Scale/offset rank (", "Scale / zero point rank (",
scale_offset_shape.rank(), scale_zero_point_shape.rank(),
") does not match the number of ", ") does not match the number of ",
"quantization axes (", "quantization axes (",
m_axes.size(), m_axes.size(),
...@@ -114,30 +114,30 @@ void op::Quantize::validate_and_infer_types() ...@@ -114,30 +114,30 @@ void op::Quantize::validate_and_infer_types()
set_output_size(1); set_output_size(1);
if (input_shape.rank().is_static() && scale_offset_shape.rank().is_static()) if (input_shape.rank().is_static() && scale_zero_point_shape.rank().is_static())
{ {
size_t i = 0; size_t i = 0;
std::vector<Dimension> injected_scale_offset_dims; vector<Dimension> injected_scale_zero_point_dims;
for (size_t j = 0; j < size_t(input_shape.rank()); j++) for (size_t j = 0; j < size_t(input_shape.rank()); j++)
{ {
if (m_axes.count(j) != 0) if (m_axes.count(j) != 0)
{ {
injected_scale_offset_dims.push_back(scale_offset_shape[i++]); injected_scale_zero_point_dims.push_back(scale_zero_point_shape[i++]);
} }
else else
{ {
injected_scale_offset_dims.push_back(Dimension::dynamic()); injected_scale_zero_point_dims.push_back(Dimension::dynamic());
} }
} }
PartialShape result_shape = input_shape; PartialShape result_shape = input_shape;
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, this,
PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims}), PartialShape::merge_into(result_shape, PartialShape{injected_scale_zero_point_dims}),
"Scale/offset shape (", "Scale / zero point shape (",
scale_offset_shape, scale_zero_point_shape,
") must match input shape (", ") must match input shape (",
input_shape, input_shape,
") at the quantization axes (", ") at the quantization axes (",
......
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
namespace op namespace op
{ {
/// \brief Quantize operation /// \brief Quantize operation
/// Maps real input (r) to quantized output (q) using scale (s), offset (o) and round mode: /// Maps real input (r) to quantized output (q) using scale (s), zero point (z) and round mode:
/// q = ROUND(r / s) + o /// q = ROUND(r / s) + z
class Quantize : public ngraph::op::Op class Quantize : public ngraph::op::Op
{ {
...@@ -78,13 +78,13 @@ namespace ngraph ...@@ -78,13 +78,13 @@ namespace ngraph
/// \brief Constructs a Quantize operation /// \brief Constructs a Quantize operation
/// \param input real input /// \param input real input
/// \param scale scale used for mapping /// \param scale scale used for mapping
/// \param offset offset used for mapping /// \param zero_point zero point used for mapping
/// \param type output element type /// \param type output element type
/// \param axes axis positions on which `scale` and `offset` are specified /// \param axes axis positions on which `scale` and `zero_point` are specified
/// \param round_mode describes how to perform ROUND function (see above) /// \param round_mode describes how to perform ROUND function (see above)
Quantize(std::shared_ptr<Node> input, Quantize(const std::shared_ptr<Node>& input,
std::shared_ptr<Node> scale, const std::shared_ptr<Node>& scale,
std::shared_ptr<Node> offset, const std::shared_ptr<Node>& zero_point,
const ngraph::element::Type& type, const ngraph::element::Type& type,
const ngraph::AxisSet& axes, const ngraph::AxisSet& axes,
RoundMode round_mode); RoundMode round_mode);
......
...@@ -31,24 +31,24 @@ namespace ngraph ...@@ -31,24 +31,24 @@ namespace ngraph
template <typename QUANT, typename REAL> template <typename QUANT, typename REAL>
void dequantize(const QUANT* input, void dequantize(const QUANT* input,
const REAL* scale, const REAL* scale,
const QUANT* offset, const QUANT* zero_point,
REAL* output, REAL* output,
const Shape& input_shape, const Shape& input_shape,
const Shape& scale_offset_shape, const Shape& scale_zero_point_shape,
const AxisSet& axes) const AxisSet& axes)
{ {
CoordinateTransform input_transform(input_shape); CoordinateTransform input_transform(input_shape);
CoordinateTransform scale_offset_transform(scale_offset_shape); CoordinateTransform scale_zero_point_transform(scale_zero_point_shape);
for (const Coordinate& input_coord : input_transform) for (const Coordinate& input_coord : input_transform)
{ {
Coordinate scale_offset_coord = project(input_coord, axes); Coordinate scale_zero_point_coord = project(input_coord, axes);
output[input_transform.index(input_coord)] = output[input_transform.index(input_coord)] =
static_cast<REAL>( static_cast<REAL>((
(input[input_transform.index(input_coord)] - input[input_transform.index(input_coord)] -
offset[scale_offset_transform.index(scale_offset_coord)])) * zero_point[scale_zero_point_transform.index(scale_zero_point_coord)])) *
scale[scale_offset_transform.index(scale_offset_coord)]; scale[scale_zero_point_transform.index(scale_zero_point_coord)];
} }
} }
} }
......
...@@ -28,23 +28,23 @@ namespace ngraph ...@@ -28,23 +28,23 @@ namespace ngraph
template <typename REAL, typename QUANT> template <typename REAL, typename QUANT>
void quantize(const REAL* input, void quantize(const REAL* input,
const REAL* scale, const REAL* scale,
const QUANT* offset, const QUANT* zero_point,
QUANT* output, QUANT* output,
const Shape& input_shape, const Shape& input_shape,
const Shape& scale_offset_shape, const Shape& scale_zero_point_shape,
const AxisSet& axes, const AxisSet& axes,
op::Quantize::RoundMode round_mode) op::Quantize::RoundMode round_mode)
{ {
CoordinateTransform input_transform(input_shape); CoordinateTransform input_transform(input_shape);
CoordinateTransform scale_offset_transform(scale_offset_shape); CoordinateTransform scale_zero_point_transform(scale_zero_point_shape);
for (const Coordinate& input_coord : input_transform) for (const Coordinate& input_coord : input_transform)
{ {
Coordinate scale_offset_coord = project(input_coord, axes); Coordinate scale_zero_point_coord = project(input_coord, axes);
// apply scale // apply scale
REAL qvalue = input[input_transform.index(input_coord)] / REAL qvalue = input[input_transform.index(input_coord)] /
scale[scale_offset_transform.index(scale_offset_coord)]; scale[scale_zero_point_transform.index(scale_zero_point_coord)];
// round // round
if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY) if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY)
...@@ -95,8 +95,8 @@ namespace ngraph ...@@ -95,8 +95,8 @@ namespace ngraph
qvalue = std::floor(qvalue); qvalue = std::floor(qvalue);
} }
// apply offset // apply zero_point
qvalue += offset[scale_offset_transform.index(scale_offset_coord)]; qvalue += zero_point[scale_zero_point_transform.index(scale_zero_point_coord)];
// clamp // clamp
qvalue = std::max<REAL>(qvalue, qvalue = std::max<REAL>(qvalue,
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment