Commit 41f7cb5f authored by Adam Straw's avatar Adam Straw Committed by Scott Cyphers

Quantization API cleanup (non breaking) (#2873)

* quantization cleanup

* offset changed to zero point

* fix failing tests

* code style

* code style
parent eb83f267
......@@ -13,33 +13,33 @@ Description
Produces a tensor of element type ``type`` and the same shape as ``input``
where the value of each coordinate :math:`i` of ``output`` is the corresponding coordinate of
``input`` minus ``offset`` quantity multiplied by ``scale``.
The coordinate :math:`j` of ``scale`` and ``offset`` is the coordinate of ``output``
``input`` minus ``zero_point`` quantity multiplied by ``scale``.
The coordinate :math:`j` of ``scale`` and ``zero_point`` is the coordinate of ``output``
projected onto ``axes``.
Inputs
------
+-----------------+-------------------------+------------------------------------------+
| Name | Element Type | Shape |
+=================+=========================+==========================================+
| ``input`` | Any quantized type | Any |
+-----------------+-------------------------+------------------------------------------+
| ``scale`` | Same as ``output`` | ``input`` shape projected onto ``axes`` |
+-----------------+-------------------------+------------------------------------------+
| ``offset`` | Same as ``input`` | ``input`` shape projected onto ``axes`` |
+-----------------+-------------------------+------------------------------------------+
+-----------------+-------------------------+----------------------------------------------+
| Name | Element Type | Shape |
+=================+=========================+==============================================+
| ``input`` | Any quantized type | Any |
+-----------------+-------------------------+----------------------------------------------+
| ``scale`` | Same as ``output`` | ``input`` shape projected onto ``axes`` |
+-----------------+-------------------------+----------------------------------------------+
| ``zero_point`` | Same as ``input`` | ``input`` shape projected onto ``axes`` |
+-----------------+-------------------------+----------------------------------------------+
Attributes
----------
+-------------------------------+----------------------------------------------------------------+
| Name | Description |
+===============================+================================================================+
| ``type`` | ``output`` element type; any real type |
+-------------------------------+----------------------------------------------------------------+
| ``axes`` | Axis positions on which ``scale`` and ``offset`` are specified |
+-------------------------------+----------------------------------------------------------------+
+-------------------------------+--------------------------------------------------------------------+
| Name | Description |
+===============================+====================================================================+
| ``type`` | ``output`` element type; any real type |
+-------------------------------+--------------------------------------------------------------------+
| ``axes`` | Axis positions on which ``scale`` and ``zero_point`` are specified |
+-------------------------------+--------------------------------------------------------------------+
......@@ -59,7 +59,7 @@ Mathematical Definition
.. math::
\mathtt{output}_{i,j} = (\mathtt{input}_{i,j} - \mathtt{offset}_{j}) \mathtt{scale}_{j}
\mathtt{output}_{i,j} = (\mathtt{input}_{i,j} - \mathtt{zero_point}_{j}) \mathtt{scale}_{j}
C++ Interface
=============
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -37,16 +37,16 @@ namespace ngraph
{
namespace quantization
{
shared_ptr<Node> QuantizedLinearConvolution(shared_ptr<Node> input,
shared_ptr<Node> filter,
shared_ptr<Node> QuantizedLinearConvolution(const shared_ptr<Node>& input,
const shared_ptr<Node>& filter,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
shared_ptr<Node> input_scale,
shared_ptr<Node> filter_scale,
shared_ptr<Node> output_scale)
const shared_ptr<Node>& input_scale,
const shared_ptr<Node>& filter_scale,
const shared_ptr<Node>& output_scale)
{
// TODO: need to establish cross-nGraph view of scale (mult or div)
auto requantization_scale = (input_scale * filter_scale) / output_scale;
......@@ -65,19 +65,19 @@ namespace ngraph
// need to make this the primary builder which means
// 1) add support for zero point in QuantizeConvolution op API
// 2) add QuantizedConvolution reference kernel, including zero point
shared_ptr<Node> QuantizedLinearConvolution(shared_ptr<Node> input,
shared_ptr<Node> filter,
shared_ptr<Node> QuantizedLinearConvolution(const shared_ptr<Node>& input,
const shared_ptr<Node>& filter,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
shared_ptr<Node> input_scale,
shared_ptr<Node> input_zero_point,
shared_ptr<Node> filter_scale,
shared_ptr<Node> filter_zero_point,
shared_ptr<Node> output_scale,
shared_ptr<Node> output_zero_point)
const shared_ptr<Node>& input_scale,
const shared_ptr<Node>& input_zero_point,
const shared_ptr<Node>& filter_scale,
const shared_ptr<Node>& filter_zero_point,
const shared_ptr<Node>& output_scale,
const shared_ptr<Node>& output_zero_point)
{
AxisSet axes;
......@@ -107,21 +107,22 @@ namespace ngraph
return q_convolution;
}
shared_ptr<Node> QuantizedLinearConvolutionBias(shared_ptr<Node> input,
shared_ptr<Node> filter,
shared_ptr<Node> bias,
shared_ptr<Node> QuantizedLinearConvolutionBias(const shared_ptr<Node>& input,
const shared_ptr<Node>& filter,
const shared_ptr<Node>& bias,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
shared_ptr<Node> input_scale,
shared_ptr<Node> filter_scale,
shared_ptr<Node> output_scale)
const shared_ptr<Node>& input_scale,
const shared_ptr<Node>& filter_scale,
const shared_ptr<Node>& output_scale)
{
// TODO: need to establish cross-nGraph view of scale (mult or div)
auto requantization_scale = (input_scale * filter_scale) / output_scale;
auto mybias = bias;
if (bias->get_element_type() != element::i32)
{
auto zero = make_constant(element::i32, input_scale->get_shape(), 0);
......@@ -130,13 +131,13 @@ namespace ngraph
op::Quantize::RoundMode round_mode =
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
bias = make_shared<op::Quantize>(
mybias = make_shared<op::Quantize>(
bias, bias_scale, zero, element::i32, quantization_axes, round_mode);
}
return make_shared<op::QuantizedConvolutionBias>(input,
filter,
bias,
mybias,
window_movement_strides,
window_dilation_strides,
padding_below,
......@@ -146,8 +147,8 @@ namespace ngraph
false);
}
shared_ptr<Node> QuantizedConvInteger(shared_ptr<Node> input,
shared_ptr<Node> filter,
shared_ptr<Node> QuantizedConvInteger(const shared_ptr<Node>& input,
const shared_ptr<Node>& filter,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
......
......@@ -25,47 +25,48 @@ namespace ngraph
{
namespace quantization
{
std::shared_ptr<Node> QuantizedLinearConvolution(std::shared_ptr<Node> input,
std::shared_ptr<Node> filter,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
std::shared_ptr<Node> input_scale,
std::shared_ptr<Node> filter_scale,
std::shared_ptr<Node> output_scale);
std::shared_ptr<Node>
QuantizedLinearConvolution(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& filter,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node>& input_scale,
const std::shared_ptr<Node>& filter_scale,
const std::shared_ptr<Node>& output_scale);
std::shared_ptr<Node>
QuantizedLinearConvolution(std::shared_ptr<Node> input,
std::shared_ptr<Node> filter,
QuantizedLinearConvolution(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& filter,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
std::shared_ptr<Node> input_scale,
std::shared_ptr<Node> input_zero_point,
std::shared_ptr<Node> filter_scale,
std::shared_ptr<Node> filter_zero_point,
std::shared_ptr<Node> output_scale,
std::shared_ptr<Node> output_zero_point);
const std::shared_ptr<Node>& input_scale,
const std::shared_ptr<Node>& input_zero_point,
const std::shared_ptr<Node>& filter_scale,
const std::shared_ptr<Node>& filter_zero_point,
const std::shared_ptr<Node>& output_scale,
const std::shared_ptr<Node>& output_zero_point);
std::shared_ptr<Node>
QuantizedLinearConvolutionBias(std::shared_ptr<Node> input,
std::shared_ptr<Node> filter,
std::shared_ptr<Node> bias,
QuantizedLinearConvolutionBias(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& filter,
const std::shared_ptr<Node>& bias,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
std::shared_ptr<Node> input_scale,
std::shared_ptr<Node> filter_scale,
std::shared_ptr<Node> output_scale);
const std::shared_ptr<Node>& input_scale,
const std::shared_ptr<Node>& filter_scale,
const std::shared_ptr<Node>& output_scale);
std::shared_ptr<Node> QuantizedConvInteger(std::shared_ptr<Node> input,
std::shared_ptr<Node> filter,
std::shared_ptr<Node> QuantizedConvInteger(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& filter,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
......
......@@ -20,13 +20,13 @@
using namespace std;
using namespace ngraph;
op::Dequantize::Dequantize(shared_ptr<Node> input,
shared_ptr<Node> scale,
shared_ptr<Node> offset,
op::Dequantize::Dequantize(const shared_ptr<Node>& input,
const shared_ptr<Node>& scale,
const shared_ptr<Node>& zero_point,
const element::Type& type,
const AxisSet& axes)
: Op("Dequantize", check_single_output_args({input, scale, offset}))
: Op("Dequantize", check_single_output_args({input, scale, zero_point}))
, m_type(type)
, m_axes(axes)
{
......@@ -39,7 +39,7 @@ void op::Dequantize::validate_and_infer_types()
{
INPUT,
SCALE,
OFFSET
ZERO_POINT
};
NODE_VALIDATION_CHECK(this, m_type.is_static(), "Output element type must not be dynamic");
......@@ -52,16 +52,16 @@ void op::Dequantize::validate_and_infer_types()
NODE_VALIDATION_CHECK(this,
element::Type::merge(quantized_type,
get_input_element_type(INPUT),
get_input_element_type(OFFSET)),
"Offset element type (",
get_input_element_type(OFFSET),
get_input_element_type(ZERO_POINT)),
"Zero point element type (",
get_input_element_type(ZERO_POINT),
") must match input element type (",
get_input_element_type(INPUT),
")");
NODE_VALIDATION_CHECK(this,
quantized_type.is_dynamic() || quantized_type.is_quantized(),
"Offset/input element type (",
"Zero point / input element type (",
quantized_type,
") must be a quantized type");
......@@ -90,21 +90,21 @@ void op::Dequantize::validate_and_infer_types()
")");
}
PartialShape scale_offset_shape = get_input_partial_shape(SCALE);
PartialShape scale_zero_point_shape = get_input_partial_shape(SCALE);
NODE_VALIDATION_CHECK(
this,
PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET)),
PartialShape::merge_into(scale_zero_point_shape, get_input_partial_shape(ZERO_POINT)),
"Scale shape (",
get_input_partial_shape(SCALE),
") and offset shape (",
get_input_partial_shape(OFFSET),
") and zero point shape (",
get_input_partial_shape(ZERO_POINT),
") must match");
NODE_VALIDATION_CHECK(this,
scale_offset_shape.rank().compatible(m_axes.size()),
"Scale/offset rank (",
scale_offset_shape.rank(),
scale_zero_point_shape.rank().compatible(m_axes.size()),
"Scale / zero point rank (",
scale_zero_point_shape.rank(),
") does not match the number of ",
"quantization axes (",
m_axes.size(),
......@@ -112,30 +112,30 @@ void op::Dequantize::validate_and_infer_types()
set_output_size(1);
if (input_shape.rank().is_static() && scale_offset_shape.rank().is_static())
if (input_shape.rank().is_static() && scale_zero_point_shape.rank().is_static())
{
size_t i = 0;
std::vector<Dimension> injected_scale_offset_dims;
vector<Dimension> injected_scale_zero_point_dims;
for (size_t j = 0; j < size_t(input_shape.rank()); j++)
{
if (m_axes.count(j) != 0)
{
injected_scale_offset_dims.push_back(scale_offset_shape[i++]);
injected_scale_zero_point_dims.push_back(scale_zero_point_shape[i++]);
}
else
{
injected_scale_offset_dims.push_back(Dimension::dynamic());
injected_scale_zero_point_dims.push_back(Dimension::dynamic());
}
}
PartialShape result_shape = input_shape;
NODE_VALIDATION_CHECK(
this,
PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims}),
"Scale/offset shape (",
scale_offset_shape,
PartialShape::merge_into(result_shape, PartialShape{injected_scale_zero_point_dims}),
"Scale / zero point shape (",
scale_zero_point_shape,
") must match input shape (",
input_shape,
") at the quantization axes (",
......
......@@ -25,7 +25,7 @@ namespace ngraph
namespace op
{
/// \brief Dequantize operation
/// Maps quantized input (q) to real output (r) using scale (s) and offset (o):
/// Maps quantized input (q) to real output (r) using scale (s) and zero point (z):
/// r = (q - o) * s
class Dequantize : public ngraph::op::Op
{
......@@ -33,12 +33,12 @@ namespace ngraph
/// \brief Constructs a Dequantize operation
/// \param input quantized input
/// \param scale scale used for mapping
/// \param offset offset used for mapping
/// \param zero_point zero point used for mapping
/// \param type output element type
/// \param axes axis positions on which `scale` and `offset` are specified
Dequantize(std::shared_ptr<Node> input,
std::shared_ptr<Node> scale,
std::shared_ptr<Node> offset,
/// \param axes axis positions on which `scale` and `zero_point` are specified
Dequantize(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& scale,
const std::shared_ptr<Node>& zero_point,
const ngraph::element::Type& type,
const ngraph::AxisSet& axes);
......
......@@ -30,7 +30,7 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& data_batc
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale,
const shared_ptr<Node>& scale,
const bool requantize)
: Op("QuantizedConvolution", check_single_output_args({data_batch, filters, scale}))
, m_window_movement_strides(window_movement_strides)
......
......@@ -33,7 +33,7 @@ namespace ngraph
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale,
const std::shared_ptr<Node>& scale,
const bool requantize = true);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
......
......@@ -34,7 +34,7 @@ op::QuantizedConvolutionBias::QuantizedConvolutionBias(const shared_ptr<Node>& d
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale,
const shared_ptr<Node>& scale,
const bool with_relu)
: Op("QuantizedConvolutionBias", check_single_output_args({data_batch, filters, bias, scale}))
, m_window_movement_strides(window_movement_strides)
......@@ -100,8 +100,8 @@ op::QuantizedConvolutionBiasAdd::QuantizedConvolutionBiasAdd(const shared_ptr<No
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale,
const std::shared_ptr<Node> sum_scale,
const shared_ptr<Node>& scale,
const shared_ptr<Node>& sum_scale,
const bool with_relu)
: Op("QuantizedConvolutionBiasAdd",
check_single_output_args({data_batch, filters, bias, sum_input, scale, sum_scale}))
......@@ -172,8 +172,8 @@ op::QuantizedConvolutionBiasSignedAdd::QuantizedConvolutionBiasSignedAdd(
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale,
const std::shared_ptr<Node> sum_scale,
const shared_ptr<Node>& scale,
const shared_ptr<Node>& sum_scale,
const bool with_relu)
: Op("QuantizedConvolutionBiasSignedAdd",
check_single_output_args({data_batch, filters, bias, sum_input, scale, sum_scale}))
......
......@@ -39,7 +39,7 @@ namespace ngraph
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale,
const std::shared_ptr<Node>& scale,
const bool with_relu = false);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
......@@ -75,8 +75,8 @@ namespace ngraph
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale,
const std::shared_ptr<Node> sum_scale,
const std::shared_ptr<Node>& scale,
const std::shared_ptr<Node>& sum_scale,
const bool with_relu = false);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
......@@ -112,8 +112,8 @@ namespace ngraph
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale,
const std::shared_ptr<Node> sum_scale,
const std::shared_ptr<Node>& scale,
const std::shared_ptr<Node>& sum_scale,
const bool with_relu = false);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
......
......@@ -24,14 +24,14 @@
using namespace std;
using namespace ngraph;
op::QuantizedConvolutionRelu::QuantizedConvolutionRelu(const std::shared_ptr<Node>& data_batch,
const std::shared_ptr<Node>& filters,
op::QuantizedConvolutionRelu::QuantizedConvolutionRelu(const shared_ptr<Node>& data_batch,
const shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale)
const shared_ptr<Node>& scale)
: Op("QuantizedConvolutionRelu", check_single_output_args({data_batch, filters, scale}))
, m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides)
......@@ -63,20 +63,19 @@ op::QuantizedConvolutionRelu::QuantizedConvolutionRelu(const std::shared_ptr<Nod
));
}
std::shared_ptr<Node>
op::QuantizedConvolutionRelu::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::QuantizedConvolutionRelu::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::shared_ptr<Node>(new QuantizedConvolutionRelu(new_args.at(0),
new_args.at(1),
get_window_movement_strides(),
get_window_dilation_strides(),
get_padding_below(),
get_padding_above(),
get_data_dilation_strides(),
new_args.at(2)));
return shared_ptr<Node>(new QuantizedConvolutionRelu(new_args.at(0),
new_args.at(1),
get_window_movement_strides(),
get_window_dilation_strides(),
get_padding_below(),
get_padding_above(),
get_data_dilation_strides(),
new_args.at(2)));
}
......@@ -36,7 +36,7 @@ namespace ngraph
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node> scale);
const std::shared_ptr<Node>& scale);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
......
......@@ -20,14 +20,14 @@
using namespace std;
using namespace ngraph;
op::Quantize::Quantize(shared_ptr<Node> input,
shared_ptr<Node> scale,
shared_ptr<Node> offset,
op::Quantize::Quantize(const shared_ptr<Node>& input,
const shared_ptr<Node>& scale,
const shared_ptr<Node>& zero_point,
const element::Type& type,
const AxisSet& axes,
RoundMode round_mode)
: Op("Quantize", check_single_output_args({input, scale, offset}))
: Op("Quantize", check_single_output_args({input, scale, zero_point}))
, m_type(type)
, m_axes(axes)
, m_round_mode(round_mode)
......@@ -41,7 +41,7 @@ void op::Quantize::validate_and_infer_types()
{
INPUT,
SCALE,
OFFSET
ZERO_POINT
};
NODE_VALIDATION_CHECK(this, m_type.is_static(), "Output element type must not be dynamic");
......@@ -63,7 +63,7 @@ void op::Quantize::validate_and_infer_types()
NODE_VALIDATION_CHECK(this,
unquantized_type.is_dynamic() || unquantized_type.is_real(),
"Scale/input element type (",
"Scale / input element type (",
unquantized_type,
") must be a floating point number");
......@@ -71,9 +71,9 @@ void op::Quantize::validate_and_infer_types()
NODE_VALIDATION_CHECK(
this,
element::Type::merge(quantized_type, get_input_element_type(OFFSET), m_type),
"Offset element type (",
get_input_element_type(OFFSET),
element::Type::merge(quantized_type, get_input_element_type(ZERO_POINT), m_type),
"Zero point element type (",
get_input_element_type(ZERO_POINT),
") must match output element type (",
m_type,
")");
......@@ -92,21 +92,21 @@ void op::Quantize::validate_and_infer_types()
")");
}
PartialShape scale_offset_shape = get_input_partial_shape(SCALE);
PartialShape scale_zero_point_shape = get_input_partial_shape(SCALE);
NODE_VALIDATION_CHECK(
this,
PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET)),
PartialShape::merge_into(scale_zero_point_shape, get_input_partial_shape(ZERO_POINT)),
"Scale shape (",
get_input_partial_shape(SCALE),
") and offset shape (",
get_input_partial_shape(OFFSET),
") and zero point shape (",
get_input_partial_shape(ZERO_POINT),
") must match");
NODE_VALIDATION_CHECK(this,
scale_offset_shape.rank().compatible(m_axes.size()),
"Scale/offset rank (",
scale_offset_shape.rank(),
scale_zero_point_shape.rank().compatible(m_axes.size()),
"Scale / zero point rank (",
scale_zero_point_shape.rank(),
") does not match the number of ",
"quantization axes (",
m_axes.size(),
......@@ -114,30 +114,30 @@ void op::Quantize::validate_and_infer_types()
set_output_size(1);
if (input_shape.rank().is_static() && scale_offset_shape.rank().is_static())
if (input_shape.rank().is_static() && scale_zero_point_shape.rank().is_static())
{
size_t i = 0;
std::vector<Dimension> injected_scale_offset_dims;
vector<Dimension> injected_scale_zero_point_dims;
for (size_t j = 0; j < size_t(input_shape.rank()); j++)
{
if (m_axes.count(j) != 0)
{
injected_scale_offset_dims.push_back(scale_offset_shape[i++]);
injected_scale_zero_point_dims.push_back(scale_zero_point_shape[i++]);
}
else
{
injected_scale_offset_dims.push_back(Dimension::dynamic());
injected_scale_zero_point_dims.push_back(Dimension::dynamic());
}
}
PartialShape result_shape = input_shape;
NODE_VALIDATION_CHECK(
this,
PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims}),
"Scale/offset shape (",
scale_offset_shape,
PartialShape::merge_into(result_shape, PartialShape{injected_scale_zero_point_dims}),
"Scale / zero point shape (",
scale_zero_point_shape,
") must match input shape (",
input_shape,
") at the quantization axes (",
......
......@@ -25,7 +25,7 @@ namespace ngraph
namespace op
{
/// \brief Quantize operation
/// Maps real input (r) to quantized output (q) using scale (s), offset (o) and round mode:
/// Maps real input (r) to quantized output (q) using scale (s), zero point (z) and round mode:
/// q = ROUND(r / s) + o
class Quantize : public ngraph::op::Op
{
......@@ -78,13 +78,13 @@ namespace ngraph
/// \brief Constructs a Quantize operation
/// \param input real input
/// \param scale scale used for mapping
/// \param offset offset used for mapping
/// \param zero_point zero point used for mapping
/// \param type output element type
/// \param axes axis positions on which `scale` and `offset` are specified
/// \param axes axis positions on which `scale` and `zero_point` are specified
/// \param round_mode describes how to perform ROUND function (see above)
Quantize(std::shared_ptr<Node> input,
std::shared_ptr<Node> scale,
std::shared_ptr<Node> offset,
Quantize(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& scale,
const std::shared_ptr<Node>& zero_point,
const ngraph::element::Type& type,
const ngraph::AxisSet& axes,
RoundMode round_mode);
......
......@@ -31,24 +31,24 @@ namespace ngraph
template <typename QUANT, typename REAL>
void dequantize(const QUANT* input,
const REAL* scale,
const QUANT* offset,
const QUANT* zero_point,
REAL* output,
const Shape& input_shape,
const Shape& scale_offset_shape,
const Shape& scale_zero_point_shape,
const AxisSet& axes)
{
CoordinateTransform input_transform(input_shape);
CoordinateTransform scale_offset_transform(scale_offset_shape);
CoordinateTransform scale_zero_point_transform(scale_zero_point_shape);
for (const Coordinate& input_coord : input_transform)
{
Coordinate scale_offset_coord = project(input_coord, axes);
Coordinate scale_zero_point_coord = project(input_coord, axes);
output[input_transform.index(input_coord)] =
static_cast<REAL>(
(input[input_transform.index(input_coord)] -
offset[scale_offset_transform.index(scale_offset_coord)])) *
scale[scale_offset_transform.index(scale_offset_coord)];
static_cast<REAL>((
input[input_transform.index(input_coord)] -
zero_point[scale_zero_point_transform.index(scale_zero_point_coord)])) *
scale[scale_zero_point_transform.index(scale_zero_point_coord)];
}
}
}
......
......@@ -28,23 +28,23 @@ namespace ngraph
template <typename REAL, typename QUANT>
void quantize(const REAL* input,
const REAL* scale,
const QUANT* offset,
const QUANT* zero_point,
QUANT* output,
const Shape& input_shape,
const Shape& scale_offset_shape,
const Shape& scale_zero_point_shape,
const AxisSet& axes,
op::Quantize::RoundMode round_mode)
{
CoordinateTransform input_transform(input_shape);
CoordinateTransform scale_offset_transform(scale_offset_shape);
CoordinateTransform scale_zero_point_transform(scale_zero_point_shape);
for (const Coordinate& input_coord : input_transform)
{
Coordinate scale_offset_coord = project(input_coord, axes);
Coordinate scale_zero_point_coord = project(input_coord, axes);
// apply scale
REAL qvalue = input[input_transform.index(input_coord)] /
scale[scale_offset_transform.index(scale_offset_coord)];
scale[scale_zero_point_transform.index(scale_zero_point_coord)];
// round
if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY)
......@@ -95,8 +95,8 @@ namespace ngraph
qvalue = std::floor(qvalue);
}
// apply offset
qvalue += offset[scale_offset_transform.index(scale_offset_coord)];
// apply zero_point
qvalue += zero_point[scale_zero_point_transform.index(scale_zero_point_coord)];
// clamp
qvalue = std::max<REAL>(qvalue,
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment