Commit 51104813 authored by Adam Straw's avatar Adam Straw Committed by Scott Cyphers

add support for Quantize round mode (#1859)

* added half_toward_zero; all previous tests passing

* all rounding modes added with unit tests

* fix cpu emitter

* round mode doc

* round out round modes

* doc typo

* using  names for round modes

* use ceil/floor for rounding functions instead of round/nearbyint

* clean up doc

* equidistant
parent e765956a
......@@ -40,11 +40,49 @@ Attributes
+-------------------------------+----------------------------------------------------------------+
| ``axes`` | Axis positions on which ``scale`` and ``offset`` are specified |
+-------------------------------+----------------------------------------------------------------+
| ``round_mode`` | Refer to ``/src/ngraph/op/quantize.hpp`` |
| ``round_mode`` | *ROUND_NEAREST_TOWARD_INFINITY:* |
| | round to nearest integer |
| | in case of two equidistant integers round away from zero e.g. |
| | 2.5 -> 3 |
| | -3.5 -> -4 |
| | |
| | *ROUND_NEAREST_TOWARD_ZERO:* |
| | round to nearest integer |
| | in case of two equidistant integers round toward zero e.g. |
| | 2.5 -> 2 |
| | -3.5 to -3 |
| | |
| | *ROUND_NEAREST_UPWARD:* |
| | round to nearest integer |
| | in case of two equidistant integers round up e.g. |
| | 2.5 to 3 |
| | -3.5 to -3 |
| | |
| | *ROUND_NEAREST_DOWNWARD:* |
| | round to nearest integer |
| | in case of two equidistant integers round down e.g. |
| | 2.5 to 2 |
| | -3.5 to -4 |
| | |
| | *ROUND_NEAREST_TOWARD_EVEN:* |
| | round to nearest integer |
| | in case of two equidistant integers round to even e.g. |
| | 2.5 to 2 |
| | -3.5 to -4 |
| | |
| | *ROUND_TOWARD_INFINITY:* |
| | round to nearest integer away from zero |
| | |
| | *ROUND_TOWARD_ZERO:* |
| | round to nearest integer toward zero |
| | |
| | *ROUND_UP:* |
| | round to nearest integer toward infinity (ceiling) |
| | |
| | *ROUND_DOWN:* |
| | round to nearest integer toward negative infinity (floor) |
+-------------------------------+----------------------------------------------------------------+
Outputs
-------
......
......@@ -44,9 +44,6 @@ void op::Quantize::validate_and_infer_types()
OFFSET
};
NODE_VALIDATION_ASSERT(this, m_round_mode == RoundMode::HALF_AWAY_FROM_ZERO)
<< "Only RoundMode = HALF_AWAY_FROM_ZERO is supported, for now";
NODE_VALIDATION_ASSERT(this, m_type.is_static()) << "Output element type must not be dynamic";
NODE_VALIDATION_ASSERT(this, m_type.is_quantized()) << "Output element type (" << m_type
......
......@@ -32,12 +32,48 @@ namespace ngraph
public:
enum class RoundMode
{
// -3.5 -> 4
// round to nearest integer
// in case of two equidistant integers round away from zero e.g.
// 2.5 -> 3
HALF_AWAY_FROM_ZERO,
// -3.5 -> 4
// 2.5 -> 2 (nearest even)
HALF_TO_EVEN
// -3.5 -> -4
ROUND_NEAREST_TOWARD_INFINITY,
HALF_AWAY_FROM_ZERO, // TF mode for backward compatability
// round to nearest integer
// in case of two equidistant integers round toward zero e.g.
// 2.5 -> 2
// -3.5 -> -3
ROUND_NEAREST_TOWARD_ZERO,
// round to nearest integer
// in case of two equidistant integers round up e.g.
// 2.5 -> 3
// -3.5 -> -3
ROUND_NEAREST_UPWARD,
// round to nearest integer
// in case of two equidistant integers round down e.g.
// 2.5 -> 2
// -3.5 -> -4
ROUND_NEAREST_DOWNWARD,
// round to nearest integer
// in case of two equidistant integers round to even e.g.
// 2.5 -> 2
// -3.5 -> -4
ROUND_NEAREST_TOWARD_EVEN,
// round to nearest integer away from zero
ROUND_TOWARD_INFINITY,
// round to nearest integer toward zero
ROUND_TOWARD_ZERO,
// round to nearest integer toward infinity (ceiling)
ROUND_UP,
// round to nearest integer toward negative infinity (floor)
ROUND_DOWN,
};
/// \brief Constructs a Quantize operation
......@@ -46,7 +82,7 @@ namespace ngraph
/// \param offset offset used for mapping
/// \param type output element type
/// \param axes axis positions on which `scale` and `offset` are specified
/// \param round_mode describes how to perform ROUND function
/// \param round_mode describes how to perform ROUND function (see above)
Quantize(std::shared_ptr<Node> input,
std::shared_ptr<Node> scale,
std::shared_ptr<Node> offset,
......
......@@ -547,7 +547,8 @@ shared_ptr<op::Constant> make_constant_quantize(shared_ptr<op::Constant> constan
out_vec.data(),
constant->get_shape(),
scale->get_shape(),
quant->get_axes());
quant->get_axes(),
quant->get_round_mode());
return make_shared<op::Constant>(quant->get_element_type(), out_shape, out_vec);
}
......
......@@ -142,34 +142,39 @@ namespace ngraph
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto daxes = quantize->get_axes();
op::Quantize::RoundMode round_mode = quantize->get_round_mode();
if (args[0].get_element_type() == element::f32)
{
if (out[0].get_element_type() == element::i8)
{
functor = [&, arg0_shape, arg1_shape, daxes](CPURuntimeContext* ctx) {
ngraph::runtime::reference::quantize<float>(
static_cast<float*>(arg0_tensor),
static_cast<float*>(arg1_tensor),
static_cast<int8_t*>(arg2_tensor),
static_cast<int8_t*>(out_tensor),
arg0_shape,
arg1_shape,
daxes);
};
functor =
[&, arg0_shape, arg1_shape, daxes, round_mode](CPURuntimeContext* ctx) {
ngraph::runtime::reference::quantize<float>(
static_cast<float*>(arg0_tensor),
static_cast<float*>(arg1_tensor),
static_cast<int8_t*>(arg2_tensor),
static_cast<int8_t*>(out_tensor),
arg0_shape,
arg1_shape,
daxes,
round_mode);
};
}
else if (out[0].get_element_type() == element::u8)
{
functor = [&, arg0_shape, arg1_shape, daxes](CPURuntimeContext* ctx) {
ngraph::runtime::reference::quantize<float>(
static_cast<float*>(arg0_tensor),
static_cast<float*>(arg1_tensor),
static_cast<uint8_t*>(arg2_tensor),
static_cast<uint8_t*>(out_tensor),
arg0_shape,
arg1_shape,
daxes);
};
functor =
[&, arg0_shape, arg1_shape, daxes, round_mode](CPURuntimeContext* ctx) {
ngraph::runtime::reference::quantize<float>(
static_cast<float*>(arg0_tensor),
static_cast<float*>(arg1_tensor),
static_cast<uint8_t*>(arg2_tensor),
static_cast<uint8_t*>(out_tensor),
arg0_shape,
arg1_shape,
daxes,
round_mode);
};
}
else
{
......@@ -180,29 +185,33 @@ namespace ngraph
{
if (out[0].get_element_type() == element::i8)
{
functor = [&, arg0_shape, arg1_shape, daxes](CPURuntimeContext* ctx) {
ngraph::runtime::reference::quantize<double>(
static_cast<double*>(arg0_tensor),
static_cast<double*>(arg1_tensor),
static_cast<int8_t*>(arg2_tensor),
static_cast<int8_t*>(out_tensor),
arg0_shape,
arg1_shape,
daxes);
};
functor =
[&, arg0_shape, arg1_shape, daxes, round_mode](CPURuntimeContext* ctx) {
ngraph::runtime::reference::quantize<double>(
static_cast<double*>(arg0_tensor),
static_cast<double*>(arg1_tensor),
static_cast<int8_t*>(arg2_tensor),
static_cast<int8_t*>(out_tensor),
arg0_shape,
arg1_shape,
daxes,
round_mode);
};
}
else if (out[0].get_element_type() == element::u8)
{
functor = [&, arg0_shape, arg1_shape, daxes](CPURuntimeContext* ctx) {
ngraph::runtime::reference::quantize<double>(
static_cast<double*>(arg0_tensor),
static_cast<double*>(arg1_tensor),
static_cast<uint8_t*>(arg2_tensor),
static_cast<uint8_t*>(out_tensor),
arg0_shape,
arg1_shape,
daxes);
};
functor =
[&, arg0_shape, arg1_shape, daxes, round_mode](CPURuntimeContext* ctx) {
ngraph::runtime::reference::quantize<double>(
static_cast<double*>(arg0_tensor),
static_cast<double*>(arg1_tensor),
static_cast<uint8_t*>(arg2_tensor),
static_cast<uint8_t*>(out_tensor),
arg0_shape,
arg1_shape,
daxes,
round_mode);
};
}
else
{
......
......@@ -4660,7 +4660,7 @@ namespace ngraph
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Quantize)
{
auto quantize = static_cast<const ngraph::op::Dequantize*>(node);
auto quantize = static_cast<const ngraph::op::Quantize*>(node);
writer << "reference::quantize(";
writer << " " << args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
......@@ -4668,7 +4668,9 @@ namespace ngraph
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(args[1].get_shape()) << "},\n";
writer << " {" << join(quantize->get_axes()) << "});\n";
writer << " {" << join(quantize->get_axes()) << "},\n";
writer << " static_cast<op::Quantize::RoundMode>("
<< static_cast<int>(quantize->get_round_mode()) << "));\n";
}
#undef TI
......
......@@ -902,7 +902,8 @@ private:
out[0]->get_data_ptr<uint8_t>(),
args[0]->get_shape(),
args[1]->get_shape(),
quantize->get_axes());
quantize->get_axes(),
quantize->get_round_mode());
}
else if (type == element::i8)
{
......@@ -912,7 +913,8 @@ private:
out[0]->get_data_ptr<int8_t>(),
args[0]->get_shape(),
args[1]->get_shape(),
quantize->get_axes());
quantize->get_axes(),
quantize->get_round_mode());
}
else
{
......
......@@ -16,11 +16,8 @@
#pragma once
#include <cmath>
#include "ngraph/axis_set.hpp"
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
#include "ngraph/op/quantize.hpp"
namespace ngraph
{
......@@ -35,7 +32,8 @@ namespace ngraph
QUANT* output,
const Shape& input_shape,
const Shape& scale_offset_shape,
const AxisSet& axes)
const AxisSet& axes,
op::Quantize::RoundMode round_mode)
{
CoordinateTransform input_transform(input_shape);
CoordinateTransform scale_offset_transform(scale_offset_shape);
......@@ -44,11 +42,62 @@ namespace ngraph
{
Coordinate scale_offset_coord = project(input_coord, axes);
// apply scale and offset
REAL qvalue =
std::round(input[input_transform.index(input_coord)] /
scale[scale_offset_transform.index(scale_offset_coord)]) +
offset[scale_offset_transform.index(scale_offset_coord)];
// apply scale
REAL qvalue = input[input_transform.index(input_coord)] /
scale[scale_offset_transform.index(scale_offset_coord)];
// round
if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY ||
round_mode == op::Quantize::RoundMode::HALF_AWAY_FROM_ZERO)
{
auto abs_qvalue = std::fabs(qvalue);
auto abs_qvalue_toward_inf = std::floor(abs_qvalue + 0.5);
qvalue = (qvalue < 0.0) ? -abs_qvalue_toward_inf : abs_qvalue_toward_inf;
}
else if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_ZERO)
{
auto abs_qvalue = std::fabs(qvalue);
auto abs_qvalue_toward_zero = std::ceil(abs_qvalue - 0.5);
qvalue = (qvalue < 0.0) ? -abs_qvalue_toward_zero : abs_qvalue_toward_zero;
}
else if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_UPWARD)
{
qvalue = std::floor(qvalue + 0.5);
}
else if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_DOWNWARD)
{
qvalue = std::ceil(qvalue - 0.5);
}
else if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)
{
auto up_qvalue = std::floor(qvalue + 0.5);
auto dn_qvalue = std::ceil(qvalue - 0.5);
auto rem = std::fmod(up_qvalue, 2.0);
qvalue = (rem == 0.0) ? up_qvalue : dn_qvalue;
}
else if (round_mode == op::Quantize::RoundMode::ROUND_TOWARD_INFINITY)
{
auto abs_qvalue = std::fabs(qvalue);
auto abs_qvalue_toward_inf = std::ceil(abs_qvalue);
qvalue = (qvalue < 0.0) ? -abs_qvalue_toward_inf : abs_qvalue_toward_inf;
}
else if (round_mode == op::Quantize::RoundMode::ROUND_TOWARD_ZERO)
{
auto abs_qvalue = std::fabs(qvalue);
auto abs_qvalue_toward_zero = std::floor(abs_qvalue);
qvalue = (qvalue < 0.0) ? -abs_qvalue_toward_zero : abs_qvalue_toward_zero;
}
else if (round_mode == op::Quantize::RoundMode::ROUND_UP)
{
qvalue = std::ceil(qvalue);
}
else if (round_mode == op::Quantize::RoundMode::ROUND_DOWN)
{
qvalue = std::floor(qvalue);
}
// apply offset
qvalue += offset[scale_offset_transform.index(scale_offset_coord)];
// clamp
qvalue = std::max<REAL>(qvalue,
......
This diff is collapsed.
......@@ -10088,40 +10088,6 @@ TEST(type_prop, quantize_offset_shape_mismatch_different_rank_fails)
}
}
TEST(type_prop, quantize_offset_unsupported_round_mode_fails)
{
Shape batch_shape{64, 3, 480, 640};
Shape scale_shape{64, 3};
Shape offset_shape{64, 3};
element::Type unquantized_type = element::f32;
element::Type quantized_type = element::i8;
element::Type batch_type = unquantized_type;
element::Type scale_type = unquantized_type;
element::Type offset_type = quantized_type;
AxisSet axes{0, 1};
auto round_mode = op::Quantize::RoundMode::HALF_TO_EVEN;
auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
auto offset = make_shared<op::Parameter>(offset_type, offset_shape);
try
{
auto quant =
make_shared<op::Quantize>(batch, scale, offset, quantized_type, axes, round_mode);
FAIL() << "Unsupported round mode not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Only RoundMode = HALF_AWAY_FROM_ZERO is supported, for now");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, quantize_partial_all_rank_dynamic_ok)
{
PartialShape batch_shape{PartialShape::dynamic()};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment