Unverified commit 34499001, authored by Scott Cyphers and committed by GitHub

Quantization conversion from nodes to outputs (#3316)

parent 8eb63379
...@@ -33,27 +33,27 @@ namespace ngraph ...@@ -33,27 +33,27 @@ namespace ngraph
{ {
namespace builder namespace builder
{ {
shared_ptr<Node> ScaledQuantize(const shared_ptr<Node>& input, shared_ptr<Node> ScaledQuantize(const Output<Node>& input,
const shared_ptr<Node>& min, const Output<Node>& min,
const shared_ptr<Node>& max, const Output<Node>& max,
const ngraph::element::Type& quant_type, const ngraph::element::Type& quant_type,
const ngraph::AxisSet& axes, const ngraph::AxisSet& axes,
op::Quantize::RoundMode round_mode) op::Quantize::RoundMode round_mode)
{ {
auto real_type = input->get_element_type(); auto real_type = input.get_element_type();
if (min->get_element_type() != real_type) if (min.get_element_type() != real_type)
{ {
throw ngraph_error("ScaledQuantize: min must match input type"); throw ngraph_error("ScaledQuantize: min must match input type");
} }
if (max->get_element_type() != real_type) if (max.get_element_type() != real_type)
{ {
throw ngraph_error("ScaledQuantize: max must match input type"); throw ngraph_error("ScaledQuantize: max must match input type");
} }
auto shape = min->get_shape(); auto shape = min.get_shape();
if (shape != max->get_shape()) if (shape != max.get_shape())
{ {
throw ngraph_error("ScaledQuantize: min and max must have same shape"); throw ngraph_error("ScaledQuantize: min and max must have same shape");
} }
...@@ -63,26 +63,26 @@ namespace ngraph ...@@ -63,26 +63,26 @@ namespace ngraph
return make_shared<op::Quantize>(input, scale, zero, quant_type, axes, round_mode); return make_shared<op::Quantize>(input, scale, zero, quant_type, axes, round_mode);
} }
shared_ptr<Node> ScaledDequantize(const shared_ptr<Node>& input, shared_ptr<Node> ScaledDequantize(const Output<Node>& input,
const shared_ptr<Node>& min, const Output<Node>& min,
const shared_ptr<Node>& max, const Output<Node>& max,
const ngraph::element::Type& real_type, const ngraph::element::Type& real_type,
const ngraph::AxisSet& axes) const ngraph::AxisSet& axes)
{ {
auto quant_type = input->get_element_type(); auto quant_type = input.get_element_type();
if (min->get_element_type() != real_type) if (min.get_element_type() != real_type)
{ {
throw ngraph_error("ScaledDequantize: min must match output type"); throw ngraph_error("ScaledDequantize: min must match output type");
} }
if (max->get_element_type() != real_type) if (max.get_element_type() != real_type)
{ {
throw ngraph_error("ScaledDequantize: max must match output type"); throw ngraph_error("ScaledDequantize: max must match output type");
} }
auto shape = min->get_shape(); auto shape = min.get_shape();
if (shape != max->get_shape()) if (shape != max.get_shape())
{ {
throw ngraph_error("ScaledDequantize: min and max must have same shape"); throw ngraph_error("ScaledDequantize: min and max must have same shape");
} }
...@@ -127,14 +127,14 @@ namespace ngraph ...@@ -127,14 +127,14 @@ namespace ngraph
return make_shared<op::QuantizedConcat>(rescaled_args, concatenation_axis); return make_shared<op::QuantizedConcat>(rescaled_args, concatenation_axis);
} }
shared_ptr<Node> ScaledQuantizedAvgPool(const shared_ptr<Node>& input, shared_ptr<Node> ScaledQuantizedAvgPool(const Output<Node>& input,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
bool include_padding_in_avg_computation, bool include_padding_in_avg_computation,
const shared_ptr<Node>& min, const Output<Node>& min,
const shared_ptr<Node>& max) const Output<Node>& max)
{ {
return make_shared<op::QuantizedAvgPool>(input, return make_shared<op::QuantizedAvgPool>(input,
window_shape, window_shape,
...@@ -144,20 +144,20 @@ namespace ngraph ...@@ -144,20 +144,20 @@ namespace ngraph
include_padding_in_avg_computation); include_padding_in_avg_computation);
} }
shared_ptr<Node> ScaledQuantizedConvolutionBias(const shared_ptr<Node>& input, shared_ptr<Node> ScaledQuantizedConvolutionBias(const Output<Node>& input,
const shared_ptr<Node>& filters, const Output<Node>& filters,
const shared_ptr<Node>& bias, const Output<Node>& bias,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const shared_ptr<Node>& min_input, const Output<Node>& min_input,
const shared_ptr<Node>& max_input, const Output<Node>& max_input,
const shared_ptr<Node>& min_filter, const Output<Node>& min_filter,
const shared_ptr<Node>& max_filter, const Output<Node>& max_filter,
const shared_ptr<Node>& min_output, const Output<Node>& min_output,
const shared_ptr<Node>& max_output, const Output<Node>& max_output,
const bool with_relu) const bool with_relu)
{ {
auto output_et = with_relu ? element::u8 : element::i8; auto output_et = with_relu ? element::u8 : element::i8;
...@@ -165,9 +165,9 @@ namespace ngraph ...@@ -165,9 +165,9 @@ namespace ngraph
min_input, max_input, min_filter, max_filter, min_output, max_output, output_et); min_input, max_input, min_filter, max_filter, min_output, max_output, output_et);
auto mybias = bias; auto mybias = bias;
if (bias->get_element_type() != element::i32) if (bias.get_element_type() != element::i32)
{ {
auto zero = make_constant(element::i32, min_input->get_shape(), 0); auto zero = make_constant(element::i32, min_input.get_shape(), 0);
AxisSet quantization_axes; AxisSet quantization_axes;
auto bias_scale = auto bias_scale =
quantization_util::get_bias_scale(min_input, max_input, min_filter, max_filter); quantization_util::get_bias_scale(min_input, max_input, min_filter, max_filter);
...@@ -190,19 +190,19 @@ namespace ngraph ...@@ -190,19 +190,19 @@ namespace ngraph
with_relu); with_relu);
} }
shared_ptr<Node> ScaledQuantizedConvolutionRelu(const shared_ptr<Node>& input, shared_ptr<Node> ScaledQuantizedConvolutionRelu(const Output<Node>& input,
const shared_ptr<Node>& filters, const Output<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const shared_ptr<Node>& min_input, const Output<Node>& min_input,
const shared_ptr<Node>& max_input, const Output<Node>& max_input,
const shared_ptr<Node>& min_filter, const Output<Node>& min_filter,
const shared_ptr<Node>& max_filter, const Output<Node>& max_filter,
const shared_ptr<Node>& min_output, const Output<Node>& min_output,
const shared_ptr<Node>& max_output) const Output<Node>& max_output)
{ {
auto requantization_scale = quantization_util::get_scale( auto requantization_scale = quantization_util::get_scale(
min_input, max_input, min_filter, max_filter, min_output, max_output, element::u8); min_input, max_input, min_filter, max_filter, min_output, max_output, element::u8);
...@@ -217,35 +217,35 @@ namespace ngraph ...@@ -217,35 +217,35 @@ namespace ngraph
requantization_scale); requantization_scale);
} }
shared_ptr<Node> ScaledQuantizedMaxPool(const shared_ptr<Node>& input, shared_ptr<Node> ScaledQuantizedMaxPool(const Output<Node>& input,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
const shared_ptr<Node>& min, const Output<Node>& min,
const shared_ptr<Node>& max) const Output<Node>& max)
{ {
return make_shared<op::QuantizedMaxPool>( return make_shared<op::QuantizedMaxPool>(
input, window_shape, window_movement_strides, padding_below, padding_above); input, window_shape, window_movement_strides, padding_below, padding_above);
} }
shared_ptr<Node> ScaledQuantizedConvolutionBiasAdd(const shared_ptr<Node>& input, shared_ptr<Node> ScaledQuantizedConvolutionBiasAdd(const Output<Node>& input,
const shared_ptr<Node>& filters, const Output<Node>& filters,
const shared_ptr<Node>& bias, const Output<Node>& bias,
const shared_ptr<Node>& sum_input, const Output<Node>& sum_input,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const shared_ptr<Node>& min_input, const Output<Node>& min_input,
const shared_ptr<Node>& max_input, const Output<Node>& max_input,
const shared_ptr<Node>& min_filter, const Output<Node>& min_filter,
const shared_ptr<Node>& max_filter, const Output<Node>& max_filter,
const shared_ptr<Node>& min_output, const Output<Node>& min_output,
const shared_ptr<Node>& max_output, const Output<Node>& max_output,
const shared_ptr<Node>& min_sum_input, const Output<Node>& min_sum_input,
const shared_ptr<Node>& max_sum_input, const Output<Node>& max_sum_input,
const bool with_relu) const bool with_relu)
{ {
auto output_et = with_relu ? element::u8 : element::i8; auto output_et = with_relu ? element::u8 : element::i8;
...@@ -256,9 +256,9 @@ namespace ngraph ...@@ -256,9 +256,9 @@ namespace ngraph
min_output, max_output, min_sum_input, max_sum_input); min_output, max_output, min_sum_input, max_sum_input);
auto mybias = bias; auto mybias = bias;
if (bias->get_element_type() != element::i32) if (bias.get_element_type() != element::i32)
{ {
auto zero = make_constant(element::i32, min_input->get_shape(), 0); auto zero = make_constant(element::i32, min_input.get_shape(), 0);
AxisSet quantization_axes; AxisSet quantization_axes;
auto bias_scale = auto bias_scale =
quantization_util::get_bias_scale(min_input, max_input, min_filter, max_filter); quantization_util::get_bias_scale(min_input, max_input, min_filter, max_filter);
...@@ -284,23 +284,23 @@ namespace ngraph ...@@ -284,23 +284,23 @@ namespace ngraph
} }
shared_ptr<Node> shared_ptr<Node>
ScaledQuantizedConvolutionBiasSignedAdd(const shared_ptr<Node>& input, ScaledQuantizedConvolutionBiasSignedAdd(const Output<Node>& input,
const shared_ptr<Node>& filters, const Output<Node>& filters,
const shared_ptr<Node>& bias, const Output<Node>& bias,
const shared_ptr<Node>& sum_input, const Output<Node>& sum_input,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const shared_ptr<Node>& min_input, const Output<Node>& min_input,
const shared_ptr<Node>& max_input, const Output<Node>& max_input,
const shared_ptr<Node>& min_filter, const Output<Node>& min_filter,
const shared_ptr<Node>& max_filter, const Output<Node>& max_filter,
const shared_ptr<Node>& min_output, const Output<Node>& min_output,
const shared_ptr<Node>& max_output, const Output<Node>& max_output,
const shared_ptr<Node>& min_sum_input, const Output<Node>& min_sum_input,
const shared_ptr<Node>& max_sum_input, const Output<Node>& max_sum_input,
const bool with_relu) const bool with_relu)
{ {
auto output_et = with_relu ? element::u8 : element::i8; auto output_et = with_relu ? element::u8 : element::i8;
...@@ -317,9 +317,9 @@ namespace ngraph ...@@ -317,9 +317,9 @@ namespace ngraph
} }
auto mybias = bias; auto mybias = bias;
if (bias->get_element_type() != element::i32) if (bias.get_element_type() != element::i32)
{ {
auto zero = make_constant(element::i32, min_input->get_shape(), 0); auto zero = make_constant(element::i32, min_input.get_shape(), 0);
AxisSet quantization_axes; AxisSet quantization_axes;
auto bias_scale = auto bias_scale =
quantization_util::get_bias_scale(min_input, max_input, min_filter, max_filter); quantization_util::get_bias_scale(min_input, max_input, min_filter, max_filter);
...@@ -344,15 +344,15 @@ namespace ngraph ...@@ -344,15 +344,15 @@ namespace ngraph
return make_shared<op::Convert>(qconv, element::u8); return make_shared<op::Convert>(qconv, element::u8);
} }
shared_ptr<Node> ScaledQuantizedDotBias(const shared_ptr<Node>& input, shared_ptr<Node> ScaledQuantizedDotBias(const Output<Node>& input,
const shared_ptr<Node>& filters, const Output<Node>& filters,
const shared_ptr<Node>& bias, const Output<Node>& bias,
const shared_ptr<Node>& min_input, const Output<Node>& min_input,
const shared_ptr<Node>& max_input, const Output<Node>& max_input,
const shared_ptr<Node>& min_filter, const Output<Node>& min_filter,
const shared_ptr<Node>& max_filter, const Output<Node>& max_filter,
const shared_ptr<Node>& min_output, const Output<Node>& min_output,
const shared_ptr<Node>& max_output, const Output<Node>& max_output,
const bool requantize, const bool requantize,
const bool with_relu) const bool with_relu)
{ {
...@@ -363,14 +363,14 @@ namespace ngraph ...@@ -363,14 +363,14 @@ namespace ngraph
max_filter, max_filter,
min_output, min_output,
max_output, max_output,
input->get_element_type(), input.get_element_type(),
with_relu ? element::u8 : element::i8, with_relu ? element::u8 : element::i8,
requantize); requantize);
auto mybias = bias; auto mybias = bias;
if (bias->get_element_type() != element::i32) if (bias.get_element_type() != element::i32)
{ {
auto zero = make_constant(element::i32, min_input->get_shape(), 0); auto zero = make_constant(element::i32, min_input.get_shape(), 0);
AxisSet quantization_axes; AxisSet quantization_axes;
auto bias_scale = auto bias_scale =
quantization_util::get_bias_scale(min_input, max_input, min_filter, max_filter); quantization_util::get_bias_scale(min_input, max_input, min_filter, max_filter);
...@@ -384,14 +384,14 @@ namespace ngraph ...@@ -384,14 +384,14 @@ namespace ngraph
input, filters, mybias, requantization_scale, requantize, with_relu); input, filters, mybias, requantization_scale, requantize, with_relu);
} }
shared_ptr<Node> ScaledQuantizedDot(const shared_ptr<Node>& input, shared_ptr<Node> ScaledQuantizedDot(const Output<Node>& input,
const shared_ptr<Node>& filters, const Output<Node>& filters,
const shared_ptr<Node>& min_input, const Output<Node>& min_input,
const shared_ptr<Node>& max_input, const Output<Node>& max_input,
const shared_ptr<Node>& min_filter, const Output<Node>& min_filter,
const shared_ptr<Node>& max_filter, const Output<Node>& max_filter,
const shared_ptr<Node>& min_output, const Output<Node>& min_output,
const shared_ptr<Node>& max_output, const Output<Node>& max_output,
const bool requantize, const bool requantize,
const bool with_relu) const bool with_relu)
{ {
...@@ -402,7 +402,7 @@ namespace ngraph ...@@ -402,7 +402,7 @@ namespace ngraph
max_filter, max_filter,
min_output, min_output,
max_output, max_output,
input->get_element_type(), input.get_element_type(),
with_relu ? element::u8 : element::i8, with_relu ? element::u8 : element::i8,
requantize); requantize);
return make_shared<op::QuantizedDot>( return make_shared<op::QuantizedDot>(
......
...@@ -32,16 +32,16 @@ namespace ngraph ...@@ -32,16 +32,16 @@ namespace ngraph
{ {
namespace builder namespace builder
{ {
std::shared_ptr<Node> ScaledQuantize(const std::shared_ptr<Node>& input, std::shared_ptr<Node> ScaledQuantize(const Output<Node>& input,
const std::shared_ptr<Node>& min, const Output<Node>& min,
const std::shared_ptr<Node>& max, const Output<Node>& max,
const ngraph::element::Type& type, const ngraph::element::Type& type,
const ngraph::AxisSet& axes, const ngraph::AxisSet& axes,
op::Quantize::RoundMode round_mode); op::Quantize::RoundMode round_mode);
std::shared_ptr<Node> ScaledDequantize(const std::shared_ptr<Node>& input, std::shared_ptr<Node> ScaledDequantize(const Output<Node>& input,
const std::shared_ptr<Node>& min, const Output<Node>& min,
const std::shared_ptr<Node>& max, const Output<Node>& max,
const ngraph::element::Type& type, const ngraph::element::Type& type,
const ngraph::AxisSet& axes); const ngraph::AxisSet& axes);
...@@ -50,115 +50,113 @@ namespace ngraph ...@@ -50,115 +50,113 @@ namespace ngraph
const NodeVector& mins, const NodeVector& mins,
const NodeVector& maxes); const NodeVector& maxes);
std::shared_ptr<Node> ScaledQuantizedAvgPool(const std::shared_ptr<Node>& input, std::shared_ptr<Node> ScaledQuantizedAvgPool(const Output<Node>& input,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
bool include_padding_in_avg_computation, bool include_padding_in_avg_computation,
const std::shared_ptr<Node>& min, const Output<Node>& min,
const std::shared_ptr<Node>& max); const Output<Node>& max);
std::shared_ptr<Node> std::shared_ptr<Node> ScaledQuantizedConvolutionBias(const Output<Node>& input,
ScaledQuantizedConvolutionBias(const std::shared_ptr<Node>& input, const Output<Node>& filters,
const std::shared_ptr<Node>& filters, const Output<Node>& bias,
const std::shared_ptr<Node>& bias, const Strides& window_movement_strides,
const Strides& window_movement_strides, const Strides& window_dilation_strides,
const Strides& window_dilation_strides, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_above,
const CoordinateDiff& padding_above, const Strides& data_dilation_strides,
const Strides& data_dilation_strides, const Output<Node>& min_input,
const std::shared_ptr<Node>& min_input, const Output<Node>& max_input,
const std::shared_ptr<Node>& max_input, const Output<Node>& min_filter,
const std::shared_ptr<Node>& min_filter, const Output<Node>& max_filter,
const std::shared_ptr<Node>& max_filter, const Output<Node>& min_output,
const std::shared_ptr<Node>& min_output, const Output<Node>& max_output,
const std::shared_ptr<Node>& max_output, const bool with_relu = false);
const bool with_relu = false);
std::shared_ptr<Node> std::shared_ptr<Node> ScaledQuantizedConvolutionRelu(const Output<Node>& input,
ScaledQuantizedConvolutionRelu(const std::shared_ptr<Node>& input, const Output<Node>& filters,
const std::shared_ptr<Node>& filters, const Strides& window_movement_strides,
const Strides& window_movement_strides, const Strides& window_dilation_strides,
const Strides& window_dilation_strides, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_above,
const CoordinateDiff& padding_above, const Strides& data_dilation_strides,
const Strides& data_dilation_strides, const Output<Node>& min_input,
const std::shared_ptr<Node>& min_input, const Output<Node>& max_input,
const std::shared_ptr<Node>& max_input, const Output<Node>& min_filter,
const std::shared_ptr<Node>& min_filter, const Output<Node>& max_filter,
const std::shared_ptr<Node>& max_filter, const Output<Node>& min_output,
const std::shared_ptr<Node>& min_output, const Output<Node>& max_output);
const std::shared_ptr<Node>& max_output);
std::shared_ptr<Node> ScaledQuantizedMaxPool(const std::shared_ptr<Node>& input, std::shared_ptr<Node> ScaledQuantizedMaxPool(const Output<Node>& input,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
const std::shared_ptr<Node>& min, const Output<Node>& min,
const std::shared_ptr<Node>& max); const Output<Node>& max);
std::shared_ptr<Node> std::shared_ptr<Node>
ScaledQuantizedConvolutionBiasAdd(const std::shared_ptr<Node>& input, ScaledQuantizedConvolutionBiasAdd(const Output<Node>& input,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const std::shared_ptr<Node>& bias, const Output<Node>& bias,
const std::shared_ptr<Node>& sum_input, const Output<Node>& sum_input,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node>& min_input, const Output<Node>& min_input,
const std::shared_ptr<Node>& max_input, const Output<Node>& max_input,
const std::shared_ptr<Node>& min_filter, const Output<Node>& min_filter,
const std::shared_ptr<Node>& max_filter, const Output<Node>& max_filter,
const std::shared_ptr<Node>& min_output, const Output<Node>& min_output,
const std::shared_ptr<Node>& max_output, const Output<Node>& max_output,
const std::shared_ptr<Node>& min_sum_input, const Output<Node>& min_sum_input,
const std::shared_ptr<Node>& max_sum_input, const Output<Node>& max_sum_input,
const bool with_relu = false); const bool with_relu = false);
std::shared_ptr<Node> std::shared_ptr<Node>
ScaledQuantizedConvolutionBiasSignedAdd(const std::shared_ptr<Node>& input, ScaledQuantizedConvolutionBiasSignedAdd(const Output<Node>& input,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const std::shared_ptr<Node>& bias, const Output<Node>& bias,
const std::shared_ptr<Node>& sum_input, const Output<Node>& sum_input,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node>& min_input, const Output<Node>& min_input,
const std::shared_ptr<Node>& max_input, const Output<Node>& max_input,
const std::shared_ptr<Node>& min_filter, const Output<Node>& min_filter,
const std::shared_ptr<Node>& max_filter, const Output<Node>& max_filter,
const std::shared_ptr<Node>& min_output, const Output<Node>& min_output,
const std::shared_ptr<Node>& max_output, const Output<Node>& max_output,
const std::shared_ptr<Node>& min_sum_input, const Output<Node>& min_sum_input,
const std::shared_ptr<Node>& max_sum_input, const Output<Node>& max_sum_input,
const bool with_relu = false); const bool with_relu = false);
std::shared_ptr<Node> ScaledQuantizedDotBias(const std::shared_ptr<Node>& input, std::shared_ptr<Node> ScaledQuantizedDotBias(const Output<Node>& input,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const std::shared_ptr<Node>& bias, const Output<Node>& bias,
const std::shared_ptr<Node>& min_input, const Output<Node>& min_input,
const std::shared_ptr<Node>& max_input, const Output<Node>& max_input,
const std::shared_ptr<Node>& min_filter, const Output<Node>& min_filter,
const std::shared_ptr<Node>& max_filter, const Output<Node>& max_filter,
const std::shared_ptr<Node>& min_output, const Output<Node>& min_output,
const std::shared_ptr<Node>& max_output, const Output<Node>& max_output,
const bool requantize = true, const bool requantize = true,
const bool with_relu = false); const bool with_relu = false);
std::shared_ptr<Node> ScaledQuantizedDot(const std::shared_ptr<Node>& input, std::shared_ptr<Node> ScaledQuantizedDot(const Output<Node>& input,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const std::shared_ptr<Node>& min_input, const Output<Node>& min_input,
const std::shared_ptr<Node>& max_input, const Output<Node>& max_input,
const std::shared_ptr<Node>& min_filter, const Output<Node>& min_filter,
const std::shared_ptr<Node>& max_filter, const Output<Node>& max_filter,
const std::shared_ptr<Node>& min_output, const Output<Node>& min_output,
const std::shared_ptr<Node>& max_output, const Output<Node>& max_output,
const bool requantize = true, const bool requantize = true,
const bool with_relu = false); const bool with_relu = false);
......
...@@ -36,25 +36,25 @@ namespace ngraph ...@@ -36,25 +36,25 @@ namespace ngraph
{ {
namespace quantization namespace quantization
{ {
shared_ptr<Node> QuantizedLinearConvolutionBias(const shared_ptr<Node>& input, shared_ptr<Node> QuantizedLinearConvolutionBias(const Output<Node>& input,
const shared_ptr<Node>& filter, const Output<Node>& filter,
const shared_ptr<Node>& bias, const Output<Node>& bias,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const shared_ptr<Node>& input_scale, const Output<Node>& input_scale,
const shared_ptr<Node>& filter_scale, const Output<Node>& filter_scale,
const shared_ptr<Node>& output_scale) const Output<Node>& output_scale)
{ {
// TODO: need to establish cross-nGraph view of scale (mult or div) // TODO: need to establish cross-nGraph view of scale (mult or div)
auto requantization_scale = (input_scale * filter_scale) / output_scale; auto requantization_scale = (input_scale * filter_scale) / output_scale;
auto mybias = bias; auto mybias = bias;
if (bias->get_element_type() != element::i32) if (bias.get_element_type() != element::i32)
{ {
const auto zero = make_constant(element::i32, input_scale->get_shape(), 0); const auto zero = make_constant(element::i32, input_scale.get_shape(), 0);
const AxisSet quantization_axes; const AxisSet quantization_axes;
const auto bias_scale = input_scale * filter_scale; const auto bias_scale = input_scale * filter_scale;
op::Quantize::RoundMode round_mode = op::Quantize::RoundMode round_mode =
......
...@@ -26,17 +26,17 @@ namespace ngraph ...@@ -26,17 +26,17 @@ namespace ngraph
namespace quantization namespace quantization
{ {
std::shared_ptr<Node> std::shared_ptr<Node>
QuantizedLinearConvolutionBias(const std::shared_ptr<Node>& input, QuantizedLinearConvolutionBias(const Output<Node>& input,
const std::shared_ptr<Node>& filter, const Output<Node>& filter,
const std::shared_ptr<Node>& bias, const Output<Node>& bias,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node>& input_scale, const Output<Node>& input_scale,
const std::shared_ptr<Node>& filter_scale, const Output<Node>& filter_scale,
const std::shared_ptr<Node>& output_scale); const Output<Node>& output_scale);
} }
} }
} }
...@@ -39,14 +39,14 @@ namespace ngraph ...@@ -39,14 +39,14 @@ namespace ngraph
{ {
// TODO: this code is falling back to fp32 dot // TODO: this code is falling back to fp32 dot
// 1) add support in reference kernel for zero point // 1) add support in reference kernel for zero point
shared_ptr<Node> QuantizedLinearMatmul(const shared_ptr<Node>& input0, shared_ptr<Node> QuantizedLinearMatmul(const Output<Node>& input0,
const shared_ptr<Node>& input1, const Output<Node>& input1,
const shared_ptr<Node>& input0_scale, const Output<Node>& input0_scale,
const shared_ptr<Node>& input0_zero_point, const Output<Node>& input0_zero_point,
const shared_ptr<Node>& input1_scale, const Output<Node>& input1_scale,
const shared_ptr<Node>& input1_zero_point, const Output<Node>& input1_zero_point,
const shared_ptr<Node>& output_scale, const Output<Node>& output_scale,
const shared_ptr<Node>& output_zero_point) const Output<Node>& output_zero_point)
{ {
// Check if zero point is constant and zero // Check if zero point is constant and zero
if (ngraph::is_zero(input0_zero_point) && ngraph::is_zero(input1_zero_point) && if (ngraph::is_zero(input0_zero_point) && ngraph::is_zero(input1_zero_point) &&
...@@ -62,13 +62,13 @@ namespace ngraph ...@@ -62,13 +62,13 @@ namespace ngraph
auto dq_input0 = make_shared<op::Dequantize>(input0, auto dq_input0 = make_shared<op::Dequantize>(input0,
input0_scale, input0_scale,
input0_zero_point, input0_zero_point,
input0_scale->get_element_type(), input0_scale.get_element_type(),
axes); axes);
auto dq_input1 = make_shared<op::Dequantize>(input1, auto dq_input1 = make_shared<op::Dequantize>(input1,
input1_scale, input1_scale,
input1_zero_point, input1_zero_point,
input1_scale->get_element_type(), input1_scale.get_element_type(),
axes); axes);
auto dot = make_shared<op::Dot>(dq_input0, dq_input1, 1); auto dot = make_shared<op::Dot>(dq_input0, dq_input1, 1);
...@@ -76,24 +76,23 @@ namespace ngraph ...@@ -76,24 +76,23 @@ namespace ngraph
dot, dot,
output_scale, output_scale,
output_zero_point, output_zero_point,
output_zero_point->get_element_type(), output_zero_point.get_element_type(),
axes, axes,
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN); op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
} }
} }
shared_ptr<Node> QuantizedLinearMatmulInteger(const shared_ptr<Node>& input0, shared_ptr<Node> QuantizedLinearMatmulInteger(const Output<Node>& input0,
const shared_ptr<Node>& input1) const Output<Node>& input1)
{ {
auto output_scale = make_constant(element::f32, Shape{}, 1); auto output_scale = make_constant(element::f32, Shape{}, 1);
return make_shared<op::QuantizedDot>(input0, input1, output_scale, false, false); return make_shared<op::QuantizedDot>(input0, input1, output_scale, false, false);
} }
shared_ptr<Node> shared_ptr<Node> QuantizedLinearMatmulInteger(const Output<Node>& input0,
QuantizedLinearMatmulInteger(const std::shared_ptr<Node>& input0, const Output<Node>& input1,
const std::shared_ptr<Node>& input1, const Output<Node>& input0_zero_point,
const std::shared_ptr<Node>& input0_zero_point, const Output<Node>& input1_zero_point)
const std::shared_ptr<Node>& input1_zero_point)
{ {
// Check if zero points are constant and zero // Check if zero points are constant and zero
if (ngraph::is_zero(input0_zero_point) && ngraph::is_zero(input1_zero_point)) if (ngraph::is_zero(input0_zero_point) && ngraph::is_zero(input1_zero_point))
......
...@@ -25,24 +25,23 @@ namespace ngraph ...@@ -25,24 +25,23 @@ namespace ngraph
{ {
namespace quantization namespace quantization
{ {
std::shared_ptr<Node> std::shared_ptr<Node> QuantizedLinearMatmul(const Output<Node>& input0,
QuantizedLinearMatmul(const std::shared_ptr<Node>& input0, const Output<Node>& input1,
const std::shared_ptr<Node>& input1, const Output<Node>& input0_scale,
const std::shared_ptr<Node>& input0_scale, const Output<Node>& input0_zero_point,
const std::shared_ptr<Node>& input0_zero_point, const Output<Node>& input1_scale,
const std::shared_ptr<Node>& input1_scale, const Output<Node>& input1_zero_point,
const std::shared_ptr<Node>& input1_zero_point, const Output<Node>& output_scale,
const std::shared_ptr<Node>& output_scale, const Output<Node>& output_zero_point);
const std::shared_ptr<Node>& output_zero_point);
std::shared_ptr<Node> QuantizedLinearMatmulInteger(const std::shared_ptr<Node>& input0, std::shared_ptr<Node> QuantizedLinearMatmulInteger(const Output<Node>& input0,
const std::shared_ptr<Node>& input1); const Output<Node>& input1);
std::shared_ptr<Node> std::shared_ptr<Node>
QuantizedLinearMatmulInteger(const std::shared_ptr<Node>& input0, QuantizedLinearMatmulInteger(const Output<Node>& input0,
const std::shared_ptr<Node>& input1, const Output<Node>& input1,
const std::shared_ptr<Node>& input0_zero_point, const Output<Node>& input0_zero_point,
const std::shared_ptr<Node>& input1_zero_point); const Output<Node>& input1_zero_point);
} }
} }
} }
...@@ -37,7 +37,7 @@ namespace ngraph ...@@ -37,7 +37,7 @@ namespace ngraph
{ {
namespace quantization_util namespace quantization_util
{ {
std::shared_ptr<Node> max_abs(std::shared_ptr<Node> a, std::shared_ptr<Node> b) std::shared_ptr<Node> max_abs(Output<Node> a, Output<Node> b)
{ {
auto abs_a = std::make_shared<op::Abs>(a); auto abs_a = std::make_shared<op::Abs>(a);
auto abs_b = std::make_shared<op::Abs>(b); auto abs_b = std::make_shared<op::Abs>(b);
...@@ -45,22 +45,22 @@ namespace ngraph ...@@ -45,22 +45,22 @@ namespace ngraph
} }
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>
quantization_range_for_multiplication(std::shared_ptr<Node> min_a, quantization_range_for_multiplication(Output<Node> min_a,
std::shared_ptr<Node> max_a, Output<Node> max_a,
std::shared_ptr<Node> min_b, Output<Node> min_b,
std::shared_ptr<Node> max_b) Output<Node> max_b)
{ {
auto type = min_a->get_element_type(); auto type = min_a.get_element_type();
if (type != max_a->get_element_type() || type != min_b->get_element_type() || if (type != max_a.get_element_type() || type != min_b.get_element_type() ||
type != max_b->get_element_type()) type != max_b.get_element_type())
{ {
throw ngraph_error( throw ngraph_error(
"quantization_range_for_multiplication: min and max must have same type"); "quantization_range_for_multiplication: min and max must have same type");
} }
auto shape = min_a->get_shape(); auto shape = min_a.get_shape();
if (shape != max_a->get_shape() || shape != min_b->get_shape() || if (shape != max_a.get_shape() || shape != min_b.get_shape() ||
shape != max_b->get_shape()) shape != max_b.get_shape())
{ {
throw ngraph_error( throw ngraph_error(
"quantization_range_for_multiplication: min and max must have same shape"); "quantization_range_for_multiplication: min and max must have same shape");
...@@ -87,28 +87,27 @@ namespace ngraph ...@@ -87,28 +87,27 @@ namespace ngraph
return std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>(min_c, max_c); return std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>(min_c, max_c);
} }
std::shared_ptr<Node> get_scale(std::shared_ptr<Node> min_input, std::shared_ptr<Node> get_scale(Output<Node> min_input,
std::shared_ptr<Node> max_input, Output<Node> max_input,
std::shared_ptr<Node> min_filter, Output<Node> min_filter,
std::shared_ptr<Node> max_filter, Output<Node> max_filter,
std::shared_ptr<Node> min_freezed_output, Output<Node> min_freezed_output,
std::shared_ptr<Node> max_freezed_output, Output<Node> max_freezed_output,
const ngraph::element::Type& output_type) const ngraph::element::Type& output_type)
{ {
auto type = min_input->get_element_type(); auto type = min_input.get_element_type();
if (type != max_input->get_element_type() || if (type != max_input.get_element_type() || type != min_filter.get_element_type() ||
type != min_filter->get_element_type() || type != max_filter.get_element_type() ||
type != max_filter->get_element_type() || type != min_freezed_output.get_element_type() ||
type != min_freezed_output->get_element_type() || type != max_freezed_output.get_element_type())
type != max_freezed_output->get_element_type())
{ {
throw ngraph_error("get_scale: min and max must have same type"); throw ngraph_error("get_scale: min and max must have same type");
} }
auto shape = min_input->get_shape(); auto shape = min_input.get_shape();
if (shape != max_input->get_shape() || shape != min_filter->get_shape() || if (shape != max_input.get_shape() || shape != min_filter.get_shape() ||
shape != max_filter->get_shape() || shape != min_freezed_output->get_shape() || shape != max_filter.get_shape() || shape != min_freezed_output.get_shape() ||
shape != max_freezed_output->get_shape()) shape != max_freezed_output.get_shape())
{ {
throw ngraph_error("get_scale: min and max must have same shape"); throw ngraph_error("get_scale: min and max must have same shape");
} }
...@@ -147,22 +146,21 @@ namespace ngraph ...@@ -147,22 +146,21 @@ namespace ngraph
(max_abs32 / max_abs8); (max_abs32 / max_abs8);
} }
std::shared_ptr<Node> get_bias_scale(std::shared_ptr<Node> min_input, std::shared_ptr<Node> get_bias_scale(Output<Node> min_input,
std::shared_ptr<Node> max_input, Output<Node> max_input,
std::shared_ptr<Node> min_filter, Output<Node> min_filter,
std::shared_ptr<Node> max_filter) Output<Node> max_filter)
{ {
auto type = min_input->get_element_type(); auto type = min_input.get_element_type();
if (type != max_input->get_element_type() || if (type != max_input.get_element_type() || type != min_filter.get_element_type() ||
type != min_filter->get_element_type() || type != max_filter.get_element_type())
type != max_filter->get_element_type())
{ {
throw ngraph_error("get_bias_scale: min and max must have same type"); throw ngraph_error("get_bias_scale: min and max must have same type");
} }
auto shape = min_input->get_shape(); auto shape = min_input.get_shape();
if (shape != max_input->get_shape() || shape != min_filter->get_shape() || if (shape != max_input.get_shape() || shape != min_filter.get_shape() ||
shape != max_filter->get_shape()) shape != max_filter.get_shape())
{ {
throw ngraph_error("get_bias_scale: min and max must have same shape"); throw ngraph_error("get_bias_scale: min and max must have same shape");
} }
...@@ -178,23 +176,23 @@ namespace ngraph ...@@ -178,23 +176,23 @@ namespace ngraph
return (max_abs_input_range * max_abs_filter_range) / range; return (max_abs_input_range * max_abs_filter_range) / range;
} }
std::shared_ptr<Node> get_sum_scale(std::shared_ptr<Node> min_freezed_output_conv_1, std::shared_ptr<Node> get_sum_scale(Output<Node> min_freezed_output_conv_1,
std::shared_ptr<Node> max_freezed_output_conv_1, Output<Node> max_freezed_output_conv_1,
std::shared_ptr<Node> min_freezed_output_conv_2, Output<Node> min_freezed_output_conv_2,
std::shared_ptr<Node> max_freezed_output_conv_2) Output<Node> max_freezed_output_conv_2)
{ {
auto type = min_freezed_output_conv_1->get_element_type(); auto type = min_freezed_output_conv_1.get_element_type();
if (type != max_freezed_output_conv_1->get_element_type() || if (type != max_freezed_output_conv_1.get_element_type() ||
type != min_freezed_output_conv_2->get_element_type() || type != min_freezed_output_conv_2.get_element_type() ||
type != max_freezed_output_conv_2->get_element_type()) type != max_freezed_output_conv_2.get_element_type())
{ {
throw ngraph_error("get_sum_scale: min and max must have same type"); throw ngraph_error("get_sum_scale: min and max must have same type");
} }
auto shape = min_freezed_output_conv_1->get_shape(); auto shape = min_freezed_output_conv_1.get_shape();
if (shape != max_freezed_output_conv_1->get_shape() || if (shape != max_freezed_output_conv_1.get_shape() ||
shape != min_freezed_output_conv_2->get_shape() || shape != min_freezed_output_conv_2.get_shape() ||
shape != max_freezed_output_conv_2->get_shape()) shape != max_freezed_output_conv_2.get_shape())
{ {
throw ngraph_error("get_sum_scale: min and max must have same shape"); throw ngraph_error("get_sum_scale: min and max must have same shape");
} }
...@@ -204,19 +202,19 @@ namespace ngraph ...@@ -204,19 +202,19 @@ namespace ngraph
return max_abs_conv_2 / max_abs_conv_1; return max_abs_conv_2 / max_abs_conv_1;
} }
std::shared_ptr<Node> get_scale(std::shared_ptr<Node> input_min_range, std::shared_ptr<Node> get_scale(Output<Node> input_min_range,
std::shared_ptr<Node> input_max_range, Output<Node> input_max_range,
const ngraph::element::Type& quant_type, const ngraph::element::Type& quant_type,
bool bump_by_eps = false) bool bump_by_eps = false)
{ {
auto type = input_min_range->get_element_type(); auto type = input_min_range.get_element_type();
if (type != input_max_range->get_element_type()) if (type != input_max_range.get_element_type())
{ {
throw ngraph_error("get_scale: min and max must have same type"); throw ngraph_error("get_scale: min and max must have same type");
} }
auto shape = input_min_range->get_shape(); auto shape = input_min_range.get_shape();
if (shape != input_max_range->get_shape()) if (shape != input_max_range.get_shape())
{ {
throw ngraph_error("get_scale: min and max must have same shape"); throw ngraph_error("get_scale: min and max must have same shape");
} }
...@@ -277,30 +275,29 @@ namespace ngraph ...@@ -277,30 +275,29 @@ namespace ngraph
} }
} }
std::shared_ptr<Node> get_dot_scale(std::shared_ptr<Node> min_input, std::shared_ptr<Node> get_dot_scale(Output<Node> min_input,
std::shared_ptr<Node> max_input, Output<Node> max_input,
std::shared_ptr<Node> min_filter, Output<Node> min_filter,
std::shared_ptr<Node> max_filter, Output<Node> max_filter,
std::shared_ptr<Node> min_freezed_output, Output<Node> min_freezed_output,
std::shared_ptr<Node> max_freezed_output, Output<Node> max_freezed_output,
const ngraph::element::Type& input_type, const ngraph::element::Type& input_type,
const ngraph::element::Type& output_type, const ngraph::element::Type& output_type,
const bool requantize = true) const bool requantize = true)
{ {
auto type = min_input->get_element_type(); auto type = min_input.get_element_type();
if (type != max_input->get_element_type() || if (type != max_input.get_element_type() || type != min_filter.get_element_type() ||
type != min_filter->get_element_type() || type != max_filter.get_element_type() ||
type != max_filter->get_element_type() || type != min_freezed_output.get_element_type() ||
type != min_freezed_output->get_element_type() || type != max_freezed_output.get_element_type())
type != max_freezed_output->get_element_type())
{ {
throw ngraph_error("get_dot_scale: min and max must have same type"); throw ngraph_error("get_dot_scale: min and max must have same type");
} }
auto shape = min_input->get_shape(); auto shape = min_input.get_shape();
if (shape != max_input->get_shape() || shape != min_filter->get_shape() || if (shape != max_input.get_shape() || shape != min_filter.get_shape() ||
shape != max_filter->get_shape() || shape != min_freezed_output->get_shape() || shape != max_filter.get_shape() || shape != min_freezed_output.get_shape() ||
shape != max_freezed_output->get_shape()) shape != max_freezed_output.get_shape())
{ {
throw ngraph_error("get_dot_scale: min and max must have same shape"); throw ngraph_error("get_dot_scale: min and max must have same shape");
} }
......
...@@ -22,26 +22,26 @@ namespace ngraph ...@@ -22,26 +22,26 @@ namespace ngraph
{ {
namespace quantization_utils namespace quantization_utils
{ {
std::shared_ptr<Node> max_abs(std::shared_ptr<Node> a, std::shared_ptr<Node> b) std::shared_ptr<Node> max_abs(const Output<Node>& a, const Output<Node>& b)
{ {
auto abs_a = std::make_shared<op::Abs>(a); auto abs_a = std::make_shared<op::Abs>(a);
auto abs_b = std::make_shared<op::Abs>(b); auto abs_b = std::make_shared<op::Abs>(b);
return std::make_shared<op::Maximum>(abs_a, abs_b); return std::make_shared<op::Maximum>(abs_a, abs_b);
} }
std::shared_ptr<Node> get_scale(std::shared_ptr<Node> input_min_range, std::shared_ptr<Node> get_scale(const Output<Node>& input_min_range,
std::shared_ptr<Node> input_max_range, const Output<Node>& input_max_range,
const ngraph::element::Type& quant_type, const ngraph::element::Type& quant_type,
bool bump_by_eps) bool bump_by_eps)
{ {
auto type = input_min_range->get_element_type(); auto type = input_min_range.get_element_type();
if (type != input_max_range->get_element_type()) if (type != input_max_range.get_element_type())
{ {
throw ngraph_error("get_scale: min and max must have same type"); throw ngraph_error("get_scale: min and max must have same type");
} }
auto shape = input_min_range->get_shape(); auto shape = input_min_range.get_shape();
if (shape != input_max_range->get_shape()) if (shape != input_max_range.get_shape())
{ {
throw ngraph_error("get_scale: min and max must have same shape"); throw ngraph_error("get_scale: min and max must have same shape");
} }
......
...@@ -37,10 +37,10 @@ namespace ngraph ...@@ -37,10 +37,10 @@ namespace ngraph
{ {
namespace quantization_utils namespace quantization_utils
{ {
std::shared_ptr<Node> max_abs(std::shared_ptr<Node> a, std::shared_ptr<Node> b); std::shared_ptr<Node> max_abs(const Output<Node>& a, const Output<Node>& b);
std::shared_ptr<Node> get_scale(std::shared_ptr<Node> input_min_range, std::shared_ptr<Node> get_scale(const Output<Node>& input_min_range,
std::shared_ptr<Node> input_max_range, const Output<Node>& input_max_range,
const ngraph::element::Type& quant_type, const ngraph::element::Type& quant_type,
bool bump_by_eps = false); bool bump_by_eps = false);
} }
......
...@@ -26,35 +26,34 @@ namespace ngraph ...@@ -26,35 +26,34 @@ namespace ngraph
{ {
namespace builder namespace builder
{ {
shared_ptr<Node> QuantizedConvolutionBuilder(const shared_ptr<Node>& input, shared_ptr<Node> QuantizedConvolutionBuilder(const Output<Node>& input,
const shared_ptr<Node>& filters, const Output<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const shared_ptr<Node>& min_input, const Output<Node>& min_input,
const shared_ptr<Node>& max_input, const Output<Node>& max_input,
const shared_ptr<Node>& min_filter, const Output<Node>& min_filter,
const shared_ptr<Node>& max_filter, const Output<Node>& max_filter,
const shared_ptr<Node>& min_output, const Output<Node>& min_output,
const shared_ptr<Node>& max_output, const Output<Node>& max_output,
const ngraph::element::Type& output_type, const ngraph::element::Type& output_type,
const ngraph::AxisSet& input_axes, const ngraph::AxisSet& input_axes,
const ngraph::AxisSet& filter_axes, const ngraph::AxisSet& filter_axes,
const ngraph::AxisSet& output_axes) const ngraph::AxisSet& output_axes)
{ {
auto input_scale = auto input_scale =
quantization_utils::get_scale(min_input, max_input, input->get_element_type()); quantization_utils::get_scale(min_input, max_input, input.get_element_type());
auto filter_scale = auto filter_scale =
quantization_utils::get_scale(min_filter, max_filter, filters->get_element_type()); quantization_utils::get_scale(min_filter, max_filter, filters.get_element_type());
auto output_scale = quantization_utils::get_scale(min_output, max_output, output_type); auto output_scale = quantization_utils::get_scale(min_output, max_output, output_type);
// TODO: Check for this later // TODO: Check for this later
// For Builders the zero point is assumed to be zero (for now) // For Builders the zero point is assumed to be zero (for now)
auto input_zero_point = op::Constant::create(input->get_element_type(), Shape{}, {0}); auto input_zero_point = op::Constant::create(input.get_element_type(), Shape{}, {0});
auto filter_zero_point = auto filter_zero_point = op::Constant::create(filters.get_element_type(), Shape{}, {0});
op::Constant::create(filters->get_element_type(), Shape{}, {0});
return make_shared<op::QuantizedConvolution>( return make_shared<op::QuantizedConvolution>(
input, input,
......
...@@ -26,19 +26,19 @@ namespace ngraph ...@@ -26,19 +26,19 @@ namespace ngraph
namespace builder namespace builder
{ {
std::shared_ptr<Node> std::shared_ptr<Node>
QuantizedConvolutionBuilder(const std::shared_ptr<Node>& input, QuantizedConvolutionBuilder(const Output<Node>& input,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node>& min_input, const Output<Node>& min_input,
const std::shared_ptr<Node>& max_input, const Output<Node>& max_input,
const std::shared_ptr<Node>& min_filter, const Output<Node>& min_filter,
const std::shared_ptr<Node>& max_filter, const Output<Node>& max_filter,
const std::shared_ptr<Node>& min_output, const Output<Node>& min_output,
const std::shared_ptr<Node>& max_output, const Output<Node>& max_output,
const ngraph::element::Type& output_type, const ngraph::element::Type& output_type,
const ngraph::AxisSet& input_axes = ngraph::AxisSet{}, const ngraph::AxisSet& input_axes = ngraph::AxisSet{},
const ngraph::AxisSet& filter_axes = ngraph::AxisSet{}, const ngraph::AxisSet& filter_axes = ngraph::AxisSet{},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment