Unverified Commit 34499001 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Quantization conversion from nodes to outputs (#3316)

parent 8eb63379
This diff is collapsed.
This diff is collapsed.
...@@ -36,25 +36,25 @@ namespace ngraph ...@@ -36,25 +36,25 @@ namespace ngraph
{ {
namespace quantization namespace quantization
{ {
shared_ptr<Node> QuantizedLinearConvolutionBias(const shared_ptr<Node>& input, shared_ptr<Node> QuantizedLinearConvolutionBias(const Output<Node>& input,
const shared_ptr<Node>& filter, const Output<Node>& filter,
const shared_ptr<Node>& bias, const Output<Node>& bias,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const shared_ptr<Node>& input_scale, const Output<Node>& input_scale,
const shared_ptr<Node>& filter_scale, const Output<Node>& filter_scale,
const shared_ptr<Node>& output_scale) const Output<Node>& output_scale)
{ {
// TODO: need to establish cross-nGraph view of scale (mult or div) // TODO: need to establish cross-nGraph view of scale (mult or div)
auto requantization_scale = (input_scale * filter_scale) / output_scale; auto requantization_scale = (input_scale * filter_scale) / output_scale;
auto mybias = bias; auto mybias = bias;
if (bias->get_element_type() != element::i32) if (bias.get_element_type() != element::i32)
{ {
const auto zero = make_constant(element::i32, input_scale->get_shape(), 0); const auto zero = make_constant(element::i32, input_scale.get_shape(), 0);
const AxisSet quantization_axes; const AxisSet quantization_axes;
const auto bias_scale = input_scale * filter_scale; const auto bias_scale = input_scale * filter_scale;
op::Quantize::RoundMode round_mode = op::Quantize::RoundMode round_mode =
......
...@@ -26,17 +26,17 @@ namespace ngraph ...@@ -26,17 +26,17 @@ namespace ngraph
namespace quantization namespace quantization
{ {
std::shared_ptr<Node> std::shared_ptr<Node>
QuantizedLinearConvolutionBias(const std::shared_ptr<Node>& input, QuantizedLinearConvolutionBias(const Output<Node>& input,
const std::shared_ptr<Node>& filter, const Output<Node>& filter,
const std::shared_ptr<Node>& bias, const Output<Node>& bias,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node>& input_scale, const Output<Node>& input_scale,
const std::shared_ptr<Node>& filter_scale, const Output<Node>& filter_scale,
const std::shared_ptr<Node>& output_scale); const Output<Node>& output_scale);
} }
} }
} }
...@@ -39,14 +39,14 @@ namespace ngraph ...@@ -39,14 +39,14 @@ namespace ngraph
{ {
// TODO: this code is falling back to fp32 dot // TODO: this code is falling back to fp32 dot
// 1) add support in reference kernel for zero point // 1) add support in reference kernel for zero point
shared_ptr<Node> QuantizedLinearMatmul(const shared_ptr<Node>& input0, shared_ptr<Node> QuantizedLinearMatmul(const Output<Node>& input0,
const shared_ptr<Node>& input1, const Output<Node>& input1,
const shared_ptr<Node>& input0_scale, const Output<Node>& input0_scale,
const shared_ptr<Node>& input0_zero_point, const Output<Node>& input0_zero_point,
const shared_ptr<Node>& input1_scale, const Output<Node>& input1_scale,
const shared_ptr<Node>& input1_zero_point, const Output<Node>& input1_zero_point,
const shared_ptr<Node>& output_scale, const Output<Node>& output_scale,
const shared_ptr<Node>& output_zero_point) const Output<Node>& output_zero_point)
{ {
// Check if zero point is constant and zero // Check if zero point is constant and zero
if (ngraph::is_zero(input0_zero_point) && ngraph::is_zero(input1_zero_point) && if (ngraph::is_zero(input0_zero_point) && ngraph::is_zero(input1_zero_point) &&
...@@ -62,13 +62,13 @@ namespace ngraph ...@@ -62,13 +62,13 @@ namespace ngraph
auto dq_input0 = make_shared<op::Dequantize>(input0, auto dq_input0 = make_shared<op::Dequantize>(input0,
input0_scale, input0_scale,
input0_zero_point, input0_zero_point,
input0_scale->get_element_type(), input0_scale.get_element_type(),
axes); axes);
auto dq_input1 = make_shared<op::Dequantize>(input1, auto dq_input1 = make_shared<op::Dequantize>(input1,
input1_scale, input1_scale,
input1_zero_point, input1_zero_point,
input1_scale->get_element_type(), input1_scale.get_element_type(),
axes); axes);
auto dot = make_shared<op::Dot>(dq_input0, dq_input1, 1); auto dot = make_shared<op::Dot>(dq_input0, dq_input1, 1);
...@@ -76,24 +76,23 @@ namespace ngraph ...@@ -76,24 +76,23 @@ namespace ngraph
dot, dot,
output_scale, output_scale,
output_zero_point, output_zero_point,
output_zero_point->get_element_type(), output_zero_point.get_element_type(),
axes, axes,
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN); op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
} }
} }
shared_ptr<Node> QuantizedLinearMatmulInteger(const shared_ptr<Node>& input0, shared_ptr<Node> QuantizedLinearMatmulInteger(const Output<Node>& input0,
const shared_ptr<Node>& input1) const Output<Node>& input1)
{ {
auto output_scale = make_constant(element::f32, Shape{}, 1); auto output_scale = make_constant(element::f32, Shape{}, 1);
return make_shared<op::QuantizedDot>(input0, input1, output_scale, false, false); return make_shared<op::QuantizedDot>(input0, input1, output_scale, false, false);
} }
shared_ptr<Node> shared_ptr<Node> QuantizedLinearMatmulInteger(const Output<Node>& input0,
QuantizedLinearMatmulInteger(const std::shared_ptr<Node>& input0, const Output<Node>& input1,
const std::shared_ptr<Node>& input1, const Output<Node>& input0_zero_point,
const std::shared_ptr<Node>& input0_zero_point, const Output<Node>& input1_zero_point)
const std::shared_ptr<Node>& input1_zero_point)
{ {
// Check if zero points are constant and zero // Check if zero points are constant and zero
if (ngraph::is_zero(input0_zero_point) && ngraph::is_zero(input1_zero_point)) if (ngraph::is_zero(input0_zero_point) && ngraph::is_zero(input1_zero_point))
......
...@@ -25,24 +25,23 @@ namespace ngraph ...@@ -25,24 +25,23 @@ namespace ngraph
{ {
namespace quantization namespace quantization
{ {
std::shared_ptr<Node> std::shared_ptr<Node> QuantizedLinearMatmul(const Output<Node>& input0,
QuantizedLinearMatmul(const std::shared_ptr<Node>& input0, const Output<Node>& input1,
const std::shared_ptr<Node>& input1, const Output<Node>& input0_scale,
const std::shared_ptr<Node>& input0_scale, const Output<Node>& input0_zero_point,
const std::shared_ptr<Node>& input0_zero_point, const Output<Node>& input1_scale,
const std::shared_ptr<Node>& input1_scale, const Output<Node>& input1_zero_point,
const std::shared_ptr<Node>& input1_zero_point, const Output<Node>& output_scale,
const std::shared_ptr<Node>& output_scale, const Output<Node>& output_zero_point);
const std::shared_ptr<Node>& output_zero_point);
std::shared_ptr<Node> QuantizedLinearMatmulInteger(const std::shared_ptr<Node>& input0, std::shared_ptr<Node> QuantizedLinearMatmulInteger(const Output<Node>& input0,
const std::shared_ptr<Node>& input1); const Output<Node>& input1);
std::shared_ptr<Node> std::shared_ptr<Node>
QuantizedLinearMatmulInteger(const std::shared_ptr<Node>& input0, QuantizedLinearMatmulInteger(const Output<Node>& input0,
const std::shared_ptr<Node>& input1, const Output<Node>& input1,
const std::shared_ptr<Node>& input0_zero_point, const Output<Node>& input0_zero_point,
const std::shared_ptr<Node>& input1_zero_point); const Output<Node>& input1_zero_point);
} }
} }
} }
This diff is collapsed.
...@@ -22,26 +22,26 @@ namespace ngraph ...@@ -22,26 +22,26 @@ namespace ngraph
{ {
namespace quantization_utils namespace quantization_utils
{ {
std::shared_ptr<Node> max_abs(std::shared_ptr<Node> a, std::shared_ptr<Node> b) std::shared_ptr<Node> max_abs(const Output<Node>& a, const Output<Node>& b)
{ {
auto abs_a = std::make_shared<op::Abs>(a); auto abs_a = std::make_shared<op::Abs>(a);
auto abs_b = std::make_shared<op::Abs>(b); auto abs_b = std::make_shared<op::Abs>(b);
return std::make_shared<op::Maximum>(abs_a, abs_b); return std::make_shared<op::Maximum>(abs_a, abs_b);
} }
std::shared_ptr<Node> get_scale(std::shared_ptr<Node> input_min_range, std::shared_ptr<Node> get_scale(const Output<Node>& input_min_range,
std::shared_ptr<Node> input_max_range, const Output<Node>& input_max_range,
const ngraph::element::Type& quant_type, const ngraph::element::Type& quant_type,
bool bump_by_eps) bool bump_by_eps)
{ {
auto type = input_min_range->get_element_type(); auto type = input_min_range.get_element_type();
if (type != input_max_range->get_element_type()) if (type != input_max_range.get_element_type())
{ {
throw ngraph_error("get_scale: min and max must have same type"); throw ngraph_error("get_scale: min and max must have same type");
} }
auto shape = input_min_range->get_shape(); auto shape = input_min_range.get_shape();
if (shape != input_max_range->get_shape()) if (shape != input_max_range.get_shape())
{ {
throw ngraph_error("get_scale: min and max must have same shape"); throw ngraph_error("get_scale: min and max must have same shape");
} }
......
...@@ -37,10 +37,10 @@ namespace ngraph ...@@ -37,10 +37,10 @@ namespace ngraph
{ {
namespace quantization_utils namespace quantization_utils
{ {
std::shared_ptr<Node> max_abs(std::shared_ptr<Node> a, std::shared_ptr<Node> b); std::shared_ptr<Node> max_abs(const Output<Node>& a, const Output<Node>& b);
std::shared_ptr<Node> get_scale(std::shared_ptr<Node> input_min_range, std::shared_ptr<Node> get_scale(const Output<Node>& input_min_range,
std::shared_ptr<Node> input_max_range, const Output<Node>& input_max_range,
const ngraph::element::Type& quant_type, const ngraph::element::Type& quant_type,
bool bump_by_eps = false); bool bump_by_eps = false);
} }
......
...@@ -26,35 +26,34 @@ namespace ngraph ...@@ -26,35 +26,34 @@ namespace ngraph
{ {
namespace builder namespace builder
{ {
shared_ptr<Node> QuantizedConvolutionBuilder(const shared_ptr<Node>& input, shared_ptr<Node> QuantizedConvolutionBuilder(const Output<Node>& input,
const shared_ptr<Node>& filters, const Output<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const shared_ptr<Node>& min_input, const Output<Node>& min_input,
const shared_ptr<Node>& max_input, const Output<Node>& max_input,
const shared_ptr<Node>& min_filter, const Output<Node>& min_filter,
const shared_ptr<Node>& max_filter, const Output<Node>& max_filter,
const shared_ptr<Node>& min_output, const Output<Node>& min_output,
const shared_ptr<Node>& max_output, const Output<Node>& max_output,
const ngraph::element::Type& output_type, const ngraph::element::Type& output_type,
const ngraph::AxisSet& input_axes, const ngraph::AxisSet& input_axes,
const ngraph::AxisSet& filter_axes, const ngraph::AxisSet& filter_axes,
const ngraph::AxisSet& output_axes) const ngraph::AxisSet& output_axes)
{ {
auto input_scale = auto input_scale =
quantization_utils::get_scale(min_input, max_input, input->get_element_type()); quantization_utils::get_scale(min_input, max_input, input.get_element_type());
auto filter_scale = auto filter_scale =
quantization_utils::get_scale(min_filter, max_filter, filters->get_element_type()); quantization_utils::get_scale(min_filter, max_filter, filters.get_element_type());
auto output_scale = quantization_utils::get_scale(min_output, max_output, output_type); auto output_scale = quantization_utils::get_scale(min_output, max_output, output_type);
// TODO: Check for this later // TODO: Check for this later
// For Builders the zero point is assumed to be zero (for now) // For Builders the zero point is assumed to be zero (for now)
auto input_zero_point = op::Constant::create(input->get_element_type(), Shape{}, {0}); auto input_zero_point = op::Constant::create(input.get_element_type(), Shape{}, {0});
auto filter_zero_point = auto filter_zero_point = op::Constant::create(filters.get_element_type(), Shape{}, {0});
op::Constant::create(filters->get_element_type(), Shape{}, {0});
return make_shared<op::QuantizedConvolution>( return make_shared<op::QuantizedConvolution>(
input, input,
......
...@@ -26,19 +26,19 @@ namespace ngraph ...@@ -26,19 +26,19 @@ namespace ngraph
namespace builder namespace builder
{ {
std::shared_ptr<Node> std::shared_ptr<Node>
QuantizedConvolutionBuilder(const std::shared_ptr<Node>& input, QuantizedConvolutionBuilder(const Output<Node>& input,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node>& min_input, const Output<Node>& min_input,
const std::shared_ptr<Node>& max_input, const Output<Node>& max_input,
const std::shared_ptr<Node>& min_filter, const Output<Node>& min_filter,
const std::shared_ptr<Node>& max_filter, const Output<Node>& max_filter,
const std::shared_ptr<Node>& min_output, const Output<Node>& min_output,
const std::shared_ptr<Node>& max_output, const Output<Node>& max_output,
const ngraph::element::Type& output_type, const ngraph::element::Type& output_type,
const ngraph::AxisSet& input_axes = ngraph::AxisSet{}, const ngraph::AxisSet& input_axes = ngraph::AxisSet{},
const ngraph::AxisSet& filter_axes = ngraph::AxisSet{}, const ngraph::AxisSet& filter_axes = ngraph::AxisSet{},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment