Commit 51104813 authored by Adam Straw, committed by Scott Cyphers

add support for Quantize round mode (#1859)

* added half_toward_zero; all previous tests passing

* all rounding modes added with unit tests

* fix cpu emitter

* round mode doc

* round out round modes

* doc typo

* using  names for round modes

* use ceil/floor for rounding functions instead of round/nearbyint

* clean up doc

* equidistant
parent e765956a
...@@ -40,11 +40,49 @@ Attributes ...@@ -40,11 +40,49 @@ Attributes
+-------------------------------+----------------------------------------------------------------+ +-------------------------------+----------------------------------------------------------------+
| ``axes`` | Axis positions on which ``scale`` and ``offset`` are specified | | ``axes`` | Axis positions on which ``scale`` and ``offset`` are specified |
+-------------------------------+----------------------------------------------------------------+ +-------------------------------+----------------------------------------------------------------+
| ``round_mode`` | Refer to ``/src/ngraph/op/quantize.hpp`` | | ``round_mode`` | *ROUND_NEAREST_TOWARD_INFINITY:* |
| | round to nearest integer |
| | in case of two equidistant integers round away from zero e.g. |
| | 2.5 -> 3 |
| | -3.5 -> -4 |
| | |
| | *ROUND_NEAREST_TOWARD_ZERO:* |
| | round to nearest integer |
| | in case of two equidistant integers round toward zero e.g. |
| | 2.5 -> 2 |
| | -3.5 -> -3 |
| | |
| | *ROUND_NEAREST_UPWARD:* |
| | round to nearest integer |
| | in case of two equidistant integers round up e.g. |
| | 2.5 to 3 |
| | -3.5 to -3 |
| | |
| | *ROUND_NEAREST_DOWNWARD:* |
| | round to nearest integer |
| | in case of two equidistant integers round down e.g. |
| | 2.5 to 2 |
| | -3.5 to -4 |
| | |
| | *ROUND_NEAREST_TOWARD_EVEN:* |
| | round to nearest integer |
| | in case of two equidistant integers round to even e.g. |
| | 2.5 to 2 |
| | -3.5 to -4 |
| | |
| | *ROUND_TOWARD_INFINITY:* |
| | round away from zero to the next integer, e.g. -2.1 -> -3 |
| | |
| | *ROUND_TOWARD_ZERO:* |
| | round toward zero (truncate), e.g. -2.9 -> -2 |
| | |
| | *ROUND_UP:* |
| | round toward positive infinity (ceiling), e.g. -2.9 -> -2 |
| | |
| | *ROUND_DOWN:* |
| | round toward negative infinity (floor), e.g. -2.1 -> -3 |
+-------------------------------+----------------------------------------------------------------+ +-------------------------------+----------------------------------------------------------------+
Outputs Outputs
------- -------
......
...@@ -44,9 +44,6 @@ void op::Quantize::validate_and_infer_types() ...@@ -44,9 +44,6 @@ void op::Quantize::validate_and_infer_types()
OFFSET OFFSET
}; };
NODE_VALIDATION_ASSERT(this, m_round_mode == RoundMode::HALF_AWAY_FROM_ZERO)
<< "Only RoundMode = HALF_AWAY_FROM_ZERO is supported, for now";
NODE_VALIDATION_ASSERT(this, m_type.is_static()) << "Output element type must not be dynamic"; NODE_VALIDATION_ASSERT(this, m_type.is_static()) << "Output element type must not be dynamic";
NODE_VALIDATION_ASSERT(this, m_type.is_quantized()) << "Output element type (" << m_type NODE_VALIDATION_ASSERT(this, m_type.is_quantized()) << "Output element type (" << m_type
......
...@@ -32,12 +32,48 @@ namespace ngraph ...@@ -32,12 +32,48 @@ namespace ngraph
public: public:
enum class RoundMode enum class RoundMode
{ {
// -3.5 -> 4 // round to nearest integer
// in case of two equidistant integers round away from zero e.g.
// 2.5 -> 3 // 2.5 -> 3
HALF_AWAY_FROM_ZERO, // -3.5 -> -4
// -3.5 -> 4 ROUND_NEAREST_TOWARD_INFINITY,
// 2.5 -> 2 (nearest even) HALF_AWAY_FROM_ZERO, // TF mode for backward compatibility
HALF_TO_EVEN
// round to nearest integer
// in case of two equidistant integers round toward zero e.g.
// 2.5 -> 2
// -3.5 -> -3
ROUND_NEAREST_TOWARD_ZERO,
// round to nearest integer
// in case of two equidistant integers round up e.g.
// 2.5 -> 3
// -3.5 -> -3
ROUND_NEAREST_UPWARD,
// round to nearest integer
// in case of two equidistant integers round down e.g.
// 2.5 -> 2
// -3.5 -> -4
ROUND_NEAREST_DOWNWARD,
// round to nearest integer
// in case of two equidistant integers round to even e.g.
// 2.5 -> 2
// -3.5 -> -4
ROUND_NEAREST_TOWARD_EVEN,
// round away from zero to the next integer (not round-to-nearest)
ROUND_TOWARD_INFINITY,
// round toward zero to the next integer (truncate)
ROUND_TOWARD_ZERO,
// round toward positive infinity to the next integer (ceiling)
ROUND_UP,
// round toward negative infinity to the next integer (floor)
ROUND_DOWN,
}; };
/// \brief Constructs a Quantize operation /// \brief Constructs a Quantize operation
...@@ -46,7 +82,7 @@ namespace ngraph ...@@ -46,7 +82,7 @@ namespace ngraph
/// \param offset offset used for mapping /// \param offset offset used for mapping
/// \param type output element type /// \param type output element type
/// \param axes axis positions on which `scale` and `offset` are specified /// \param axes axis positions on which `scale` and `offset` are specified
/// \param round_mode describes how to perform ROUND function /// \param round_mode describes how to perform ROUND function (see above)
Quantize(std::shared_ptr<Node> input, Quantize(std::shared_ptr<Node> input,
std::shared_ptr<Node> scale, std::shared_ptr<Node> scale,
std::shared_ptr<Node> offset, std::shared_ptr<Node> offset,
......
...@@ -547,7 +547,8 @@ shared_ptr<op::Constant> make_constant_quantize(shared_ptr<op::Constant> constan ...@@ -547,7 +547,8 @@ shared_ptr<op::Constant> make_constant_quantize(shared_ptr<op::Constant> constan
out_vec.data(), out_vec.data(),
constant->get_shape(), constant->get_shape(),
scale->get_shape(), scale->get_shape(),
quant->get_axes()); quant->get_axes(),
quant->get_round_mode());
return make_shared<op::Constant>(quant->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(quant->get_element_type(), out_shape, out_vec);
} }
......
...@@ -142,34 +142,39 @@ namespace ngraph ...@@ -142,34 +142,39 @@ namespace ngraph
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape(); auto arg1_shape = args[1].get_shape();
auto daxes = quantize->get_axes(); auto daxes = quantize->get_axes();
op::Quantize::RoundMode round_mode = quantize->get_round_mode();
if (args[0].get_element_type() == element::f32) if (args[0].get_element_type() == element::f32)
{ {
if (out[0].get_element_type() == element::i8) if (out[0].get_element_type() == element::i8)
{ {
functor = [&, arg0_shape, arg1_shape, daxes](CPURuntimeContext* ctx) { functor =
ngraph::runtime::reference::quantize<float>( [&, arg0_shape, arg1_shape, daxes, round_mode](CPURuntimeContext* ctx) {
static_cast<float*>(arg0_tensor), ngraph::runtime::reference::quantize<float>(
static_cast<float*>(arg1_tensor), static_cast<float*>(arg0_tensor),
static_cast<int8_t*>(arg2_tensor), static_cast<float*>(arg1_tensor),
static_cast<int8_t*>(out_tensor), static_cast<int8_t*>(arg2_tensor),
arg0_shape, static_cast<int8_t*>(out_tensor),
arg1_shape, arg0_shape,
daxes); arg1_shape,
}; daxes,
round_mode);
};
} }
else if (out[0].get_element_type() == element::u8) else if (out[0].get_element_type() == element::u8)
{ {
functor = [&, arg0_shape, arg1_shape, daxes](CPURuntimeContext* ctx) { functor =
ngraph::runtime::reference::quantize<float>( [&, arg0_shape, arg1_shape, daxes, round_mode](CPURuntimeContext* ctx) {
static_cast<float*>(arg0_tensor), ngraph::runtime::reference::quantize<float>(
static_cast<float*>(arg1_tensor), static_cast<float*>(arg0_tensor),
static_cast<uint8_t*>(arg2_tensor), static_cast<float*>(arg1_tensor),
static_cast<uint8_t*>(out_tensor), static_cast<uint8_t*>(arg2_tensor),
arg0_shape, static_cast<uint8_t*>(out_tensor),
arg1_shape, arg0_shape,
daxes); arg1_shape,
}; daxes,
round_mode);
};
} }
else else
{ {
...@@ -180,29 +185,33 @@ namespace ngraph ...@@ -180,29 +185,33 @@ namespace ngraph
{ {
if (out[0].get_element_type() == element::i8) if (out[0].get_element_type() == element::i8)
{ {
functor = [&, arg0_shape, arg1_shape, daxes](CPURuntimeContext* ctx) { functor =
ngraph::runtime::reference::quantize<double>( [&, arg0_shape, arg1_shape, daxes, round_mode](CPURuntimeContext* ctx) {
static_cast<double*>(arg0_tensor), ngraph::runtime::reference::quantize<double>(
static_cast<double*>(arg1_tensor), static_cast<double*>(arg0_tensor),
static_cast<int8_t*>(arg2_tensor), static_cast<double*>(arg1_tensor),
static_cast<int8_t*>(out_tensor), static_cast<int8_t*>(arg2_tensor),
arg0_shape, static_cast<int8_t*>(out_tensor),
arg1_shape, arg0_shape,
daxes); arg1_shape,
}; daxes,
round_mode);
};
} }
else if (out[0].get_element_type() == element::u8) else if (out[0].get_element_type() == element::u8)
{ {
functor = [&, arg0_shape, arg1_shape, daxes](CPURuntimeContext* ctx) { functor =
ngraph::runtime::reference::quantize<double>( [&, arg0_shape, arg1_shape, daxes, round_mode](CPURuntimeContext* ctx) {
static_cast<double*>(arg0_tensor), ngraph::runtime::reference::quantize<double>(
static_cast<double*>(arg1_tensor), static_cast<double*>(arg0_tensor),
static_cast<uint8_t*>(arg2_tensor), static_cast<double*>(arg1_tensor),
static_cast<uint8_t*>(out_tensor), static_cast<uint8_t*>(arg2_tensor),
arg0_shape, static_cast<uint8_t*>(out_tensor),
arg1_shape, arg0_shape,
daxes); arg1_shape,
}; daxes,
round_mode);
};
} }
else else
{ {
......
...@@ -4660,7 +4660,7 @@ namespace ngraph ...@@ -4660,7 +4660,7 @@ namespace ngraph
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Quantize) void CPU_Emitter::EMITTER_DECL(ngraph::op::Quantize)
{ {
auto quantize = static_cast<const ngraph::op::Dequantize*>(node); auto quantize = static_cast<const ngraph::op::Quantize*>(node);
writer << "reference::quantize("; writer << "reference::quantize(";
writer << " " << args[0].get_name() << ",\n"; writer << " " << args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n"; writer << " " << args[1].get_name() << ",\n";
...@@ -4668,7 +4668,9 @@ namespace ngraph ...@@ -4668,7 +4668,9 @@ namespace ngraph
writer << " " << out[0].get_name() << ",\n"; writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n"; writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(args[1].get_shape()) << "},\n"; writer << " {" << join(args[1].get_shape()) << "},\n";
writer << " {" << join(quantize->get_axes()) << "});\n"; writer << " {" << join(quantize->get_axes()) << "},\n";
writer << " static_cast<op::Quantize::RoundMode>("
<< static_cast<int>(quantize->get_round_mode()) << "));\n";
} }
#undef TI #undef TI
......
...@@ -902,7 +902,8 @@ private: ...@@ -902,7 +902,8 @@ private:
out[0]->get_data_ptr<uint8_t>(), out[0]->get_data_ptr<uint8_t>(),
args[0]->get_shape(), args[0]->get_shape(),
args[1]->get_shape(), args[1]->get_shape(),
quantize->get_axes()); quantize->get_axes(),
quantize->get_round_mode());
} }
else if (type == element::i8) else if (type == element::i8)
{ {
...@@ -912,7 +913,8 @@ private: ...@@ -912,7 +913,8 @@ private:
out[0]->get_data_ptr<int8_t>(), out[0]->get_data_ptr<int8_t>(),
args[0]->get_shape(), args[0]->get_shape(),
args[1]->get_shape(), args[1]->get_shape(),
quantize->get_axes()); quantize->get_axes(),
quantize->get_round_mode());
} }
else else
{ {
......
...@@ -16,11 +16,8 @@ ...@@ -16,11 +16,8 @@
#pragma once #pragma once
#include <cmath>
#include "ngraph/axis_set.hpp"
#include "ngraph/coordinate_transform.hpp" #include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp" #include "ngraph/op/quantize.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,7 +32,8 @@ namespace ngraph ...@@ -35,7 +32,8 @@ namespace ngraph
QUANT* output, QUANT* output,
const Shape& input_shape, const Shape& input_shape,
const Shape& scale_offset_shape, const Shape& scale_offset_shape,
const AxisSet& axes) const AxisSet& axes,
op::Quantize::RoundMode round_mode)
{ {
CoordinateTransform input_transform(input_shape); CoordinateTransform input_transform(input_shape);
CoordinateTransform scale_offset_transform(scale_offset_shape); CoordinateTransform scale_offset_transform(scale_offset_shape);
...@@ -44,11 +42,62 @@ namespace ngraph ...@@ -44,11 +42,62 @@ namespace ngraph
{ {
Coordinate scale_offset_coord = project(input_coord, axes); Coordinate scale_offset_coord = project(input_coord, axes);
// apply scale and offset // apply scale
REAL qvalue = REAL qvalue = input[input_transform.index(input_coord)] /
std::round(input[input_transform.index(input_coord)] / scale[scale_offset_transform.index(scale_offset_coord)];
scale[scale_offset_transform.index(scale_offset_coord)]) +
offset[scale_offset_transform.index(scale_offset_coord)]; // round
if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY ||
round_mode == op::Quantize::RoundMode::HALF_AWAY_FROM_ZERO)
{
auto abs_qvalue = std::fabs(qvalue);
auto abs_qvalue_toward_inf = std::floor(abs_qvalue + 0.5);
qvalue = (qvalue < 0.0) ? -abs_qvalue_toward_inf : abs_qvalue_toward_inf;
}
else if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_ZERO)
{
auto abs_qvalue = std::fabs(qvalue);
auto abs_qvalue_toward_zero = std::ceil(abs_qvalue - 0.5);
qvalue = (qvalue < 0.0) ? -abs_qvalue_toward_zero : abs_qvalue_toward_zero;
}
else if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_UPWARD)
{
qvalue = std::floor(qvalue + 0.5);
}
else if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_DOWNWARD)
{
qvalue = std::ceil(qvalue - 0.5);
}
else if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)
{
auto up_qvalue = std::floor(qvalue + 0.5);
auto dn_qvalue = std::ceil(qvalue - 0.5);
auto rem = std::fmod(up_qvalue, 2.0);
qvalue = (rem == 0.0) ? up_qvalue : dn_qvalue;
}
else if (round_mode == op::Quantize::RoundMode::ROUND_TOWARD_INFINITY)
{
auto abs_qvalue = std::fabs(qvalue);
auto abs_qvalue_toward_inf = std::ceil(abs_qvalue);
qvalue = (qvalue < 0.0) ? -abs_qvalue_toward_inf : abs_qvalue_toward_inf;
}
else if (round_mode == op::Quantize::RoundMode::ROUND_TOWARD_ZERO)
{
auto abs_qvalue = std::fabs(qvalue);
auto abs_qvalue_toward_zero = std::floor(abs_qvalue);
qvalue = (qvalue < 0.0) ? -abs_qvalue_toward_zero : abs_qvalue_toward_zero;
}
else if (round_mode == op::Quantize::RoundMode::ROUND_UP)
{
qvalue = std::ceil(qvalue);
}
else if (round_mode == op::Quantize::RoundMode::ROUND_DOWN)
{
qvalue = std::floor(qvalue);
}
// apply offset
qvalue += offset[scale_offset_transform.index(scale_offset_coord)];
// clamp // clamp
qvalue = std::max<REAL>(qvalue, qvalue = std::max<REAL>(qvalue,
......
This diff is collapsed.
...@@ -10088,40 +10088,6 @@ TEST(type_prop, quantize_offset_shape_mismatch_different_rank_fails) ...@@ -10088,40 +10088,6 @@ TEST(type_prop, quantize_offset_shape_mismatch_different_rank_fails)
} }
} }
TEST(type_prop, quantize_offset_unsupported_round_mode_fails)
{
Shape batch_shape{64, 3, 480, 640};
Shape scale_shape{64, 3};
Shape offset_shape{64, 3};
element::Type unquantized_type = element::f32;
element::Type quantized_type = element::i8;
element::Type batch_type = unquantized_type;
element::Type scale_type = unquantized_type;
element::Type offset_type = quantized_type;
AxisSet axes{0, 1};
auto round_mode = op::Quantize::RoundMode::HALF_TO_EVEN;
auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
auto offset = make_shared<op::Parameter>(offset_type, offset_shape);
try
{
auto quant =
make_shared<op::Quantize>(batch, scale, offset, quantized_type, axes, round_mode);
FAIL() << "Unsupported round mode not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Only RoundMode = HALF_AWAY_FROM_ZERO is supported, for now");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, quantize_partial_all_rank_dynamic_ok) TEST(type_prop, quantize_partial_all_rank_dynamic_ok)
{ {
PartialShape batch_shape{PartialShape::dynamic()}; PartialShape batch_shape{PartialShape::dynamic()};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment