Unverified Commit 34499001 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Quantization conversion from nodes to outputs (#3316)

parent 8eb63379
This diff is collapsed.
This diff is collapsed.
......@@ -36,25 +36,25 @@ namespace ngraph
{
namespace quantization
{
shared_ptr<Node> QuantizedLinearConvolutionBias(const shared_ptr<Node>& input,
const shared_ptr<Node>& filter,
const shared_ptr<Node>& bias,
shared_ptr<Node> QuantizedLinearConvolutionBias(const Output<Node>& input,
const Output<Node>& filter,
const Output<Node>& bias,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const shared_ptr<Node>& input_scale,
const shared_ptr<Node>& filter_scale,
const shared_ptr<Node>& output_scale)
const Output<Node>& input_scale,
const Output<Node>& filter_scale,
const Output<Node>& output_scale)
{
// TODO: need to establish cross-nGraph view of scale (mult or div)
auto requantization_scale = (input_scale * filter_scale) / output_scale;
auto mybias = bias;
if (bias->get_element_type() != element::i32)
if (bias.get_element_type() != element::i32)
{
const auto zero = make_constant(element::i32, input_scale->get_shape(), 0);
const auto zero = make_constant(element::i32, input_scale.get_shape(), 0);
const AxisSet quantization_axes;
const auto bias_scale = input_scale * filter_scale;
op::Quantize::RoundMode round_mode =
......
......@@ -26,17 +26,17 @@ namespace ngraph
namespace quantization
{
std::shared_ptr<Node>
QuantizedLinearConvolutionBias(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& filter,
const std::shared_ptr<Node>& bias,
QuantizedLinearConvolutionBias(const Output<Node>& input,
const Output<Node>& filter,
const Output<Node>& bias,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node>& input_scale,
const std::shared_ptr<Node>& filter_scale,
const std::shared_ptr<Node>& output_scale);
const Output<Node>& input_scale,
const Output<Node>& filter_scale,
const Output<Node>& output_scale);
}
}
}
......@@ -39,14 +39,14 @@ namespace ngraph
{
// TODO: this code is falling back to fp32 dot
// 1) add support in reference kernel for zero point
shared_ptr<Node> QuantizedLinearMatmul(const shared_ptr<Node>& input0,
const shared_ptr<Node>& input1,
const shared_ptr<Node>& input0_scale,
const shared_ptr<Node>& input0_zero_point,
const shared_ptr<Node>& input1_scale,
const shared_ptr<Node>& input1_zero_point,
const shared_ptr<Node>& output_scale,
const shared_ptr<Node>& output_zero_point)
shared_ptr<Node> QuantizedLinearMatmul(const Output<Node>& input0,
const Output<Node>& input1,
const Output<Node>& input0_scale,
const Output<Node>& input0_zero_point,
const Output<Node>& input1_scale,
const Output<Node>& input1_zero_point,
const Output<Node>& output_scale,
const Output<Node>& output_zero_point)
{
// Check if zero point is constant and zero
if (ngraph::is_zero(input0_zero_point) && ngraph::is_zero(input1_zero_point) &&
......@@ -62,13 +62,13 @@ namespace ngraph
auto dq_input0 = make_shared<op::Dequantize>(input0,
input0_scale,
input0_zero_point,
input0_scale->get_element_type(),
input0_scale.get_element_type(),
axes);
auto dq_input1 = make_shared<op::Dequantize>(input1,
input1_scale,
input1_zero_point,
input1_scale->get_element_type(),
input1_scale.get_element_type(),
axes);
auto dot = make_shared<op::Dot>(dq_input0, dq_input1, 1);
......@@ -76,24 +76,23 @@ namespace ngraph
dot,
output_scale,
output_zero_point,
output_zero_point->get_element_type(),
output_zero_point.get_element_type(),
axes,
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
}
}
shared_ptr<Node> QuantizedLinearMatmulInteger(const shared_ptr<Node>& input0,
const shared_ptr<Node>& input1)
shared_ptr<Node> QuantizedLinearMatmulInteger(const Output<Node>& input0,
                                              const Output<Node>& input1)
{
    // Zero-point-free overload: quantized dot with a unit output scale
    // and neither input transposed.
    const auto unit_scale = make_constant(element::f32, Shape{}, 1);
    return make_shared<op::QuantizedDot>(input0, input1, unit_scale, false, false);
}
shared_ptr<Node>
QuantizedLinearMatmulInteger(const std::shared_ptr<Node>& input0,
const std::shared_ptr<Node>& input1,
const std::shared_ptr<Node>& input0_zero_point,
const std::shared_ptr<Node>& input1_zero_point)
shared_ptr<Node> QuantizedLinearMatmulInteger(const Output<Node>& input0,
const Output<Node>& input1,
const Output<Node>& input0_zero_point,
const Output<Node>& input1_zero_point)
{
// Check if zero points are constant and zero
if (ngraph::is_zero(input0_zero_point) && ngraph::is_zero(input1_zero_point))
......
......@@ -25,24 +25,23 @@ namespace ngraph
{
namespace quantization
{
std::shared_ptr<Node>
QuantizedLinearMatmul(const std::shared_ptr<Node>& input0,
const std::shared_ptr<Node>& input1,
const std::shared_ptr<Node>& input0_scale,
const std::shared_ptr<Node>& input0_zero_point,
const std::shared_ptr<Node>& input1_scale,
const std::shared_ptr<Node>& input1_zero_point,
const std::shared_ptr<Node>& output_scale,
const std::shared_ptr<Node>& output_zero_point);
std::shared_ptr<Node> QuantizedLinearMatmul(const Output<Node>& input0,
const Output<Node>& input1,
const Output<Node>& input0_scale,
const Output<Node>& input0_zero_point,
const Output<Node>& input1_scale,
const Output<Node>& input1_zero_point,
const Output<Node>& output_scale,
const Output<Node>& output_zero_point);
std::shared_ptr<Node> QuantizedLinearMatmulInteger(const std::shared_ptr<Node>& input0,
const std::shared_ptr<Node>& input1);
std::shared_ptr<Node> QuantizedLinearMatmulInteger(const Output<Node>& input0,
const Output<Node>& input1);
std::shared_ptr<Node>
QuantizedLinearMatmulInteger(const std::shared_ptr<Node>& input0,
const std::shared_ptr<Node>& input1,
const std::shared_ptr<Node>& input0_zero_point,
const std::shared_ptr<Node>& input1_zero_point);
QuantizedLinearMatmulInteger(const Output<Node>& input0,
const Output<Node>& input1,
const Output<Node>& input0_zero_point,
const Output<Node>& input1_zero_point);
}
}
}
This diff is collapsed.
......@@ -22,26 +22,26 @@ namespace ngraph
{
namespace quantization_utils
{
std::shared_ptr<Node> max_abs(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
std::shared_ptr<Node> max_abs(const Output<Node>& a, const Output<Node>& b)
{
    // Builds the subgraph max(|a|, |b|): an Abs node on each input
    // feeding a single Maximum node.
    return std::make_shared<op::Maximum>(std::make_shared<op::Abs>(a),
                                         std::make_shared<op::Abs>(b));
}
std::shared_ptr<Node> get_scale(std::shared_ptr<Node> input_min_range,
std::shared_ptr<Node> input_max_range,
std::shared_ptr<Node> get_scale(const Output<Node>& input_min_range,
const Output<Node>& input_max_range,
const ngraph::element::Type& quant_type,
bool bump_by_eps)
{
auto type = input_min_range->get_element_type();
if (type != input_max_range->get_element_type())
auto type = input_min_range.get_element_type();
if (type != input_max_range.get_element_type())
{
throw ngraph_error("get_scale: min and max must have same type");
}
auto shape = input_min_range->get_shape();
if (shape != input_max_range->get_shape())
auto shape = input_min_range.get_shape();
if (shape != input_max_range.get_shape())
{
throw ngraph_error("get_scale: min and max must have same shape");
}
......
......@@ -37,10 +37,10 @@ namespace ngraph
{
namespace quantization_utils
{
std::shared_ptr<Node> max_abs(std::shared_ptr<Node> a, std::shared_ptr<Node> b);
std::shared_ptr<Node> max_abs(const Output<Node>& a, const Output<Node>& b);
std::shared_ptr<Node> get_scale(std::shared_ptr<Node> input_min_range,
std::shared_ptr<Node> input_max_range,
std::shared_ptr<Node> get_scale(const Output<Node>& input_min_range,
const Output<Node>& input_max_range,
const ngraph::element::Type& quant_type,
bool bump_by_eps = false);
}
......
......@@ -26,35 +26,34 @@ namespace ngraph
{
namespace builder
{
shared_ptr<Node> QuantizedConvolutionBuilder(const shared_ptr<Node>& input,
const shared_ptr<Node>& filters,
shared_ptr<Node> QuantizedConvolutionBuilder(const Output<Node>& input,
const Output<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const shared_ptr<Node>& min_input,
const shared_ptr<Node>& max_input,
const shared_ptr<Node>& min_filter,
const shared_ptr<Node>& max_filter,
const shared_ptr<Node>& min_output,
const shared_ptr<Node>& max_output,
const Output<Node>& min_input,
const Output<Node>& max_input,
const Output<Node>& min_filter,
const Output<Node>& max_filter,
const Output<Node>& min_output,
const Output<Node>& max_output,
const ngraph::element::Type& output_type,
const ngraph::AxisSet& input_axes,
const ngraph::AxisSet& filter_axes,
const ngraph::AxisSet& output_axes)
{
auto input_scale =
quantization_utils::get_scale(min_input, max_input, input->get_element_type());
quantization_utils::get_scale(min_input, max_input, input.get_element_type());
auto filter_scale =
quantization_utils::get_scale(min_filter, max_filter, filters->get_element_type());
quantization_utils::get_scale(min_filter, max_filter, filters.get_element_type());
auto output_scale = quantization_utils::get_scale(min_output, max_output, output_type);
// TODO: Check for this later
// For Builders the zero point is assumed to be zero (for now)
auto input_zero_point = op::Constant::create(input->get_element_type(), Shape{}, {0});
auto filter_zero_point =
op::Constant::create(filters->get_element_type(), Shape{}, {0});
auto input_zero_point = op::Constant::create(input.get_element_type(), Shape{}, {0});
auto filter_zero_point = op::Constant::create(filters.get_element_type(), Shape{}, {0});
return make_shared<op::QuantizedConvolution>(
input,
......
......@@ -26,19 +26,19 @@ namespace ngraph
namespace builder
{
std::shared_ptr<Node>
QuantizedConvolutionBuilder(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& filters,
QuantizedConvolutionBuilder(const Output<Node>& input,
const Output<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node>& min_input,
const std::shared_ptr<Node>& max_input,
const std::shared_ptr<Node>& min_filter,
const std::shared_ptr<Node>& max_filter,
const std::shared_ptr<Node>& min_output,
const std::shared_ptr<Node>& max_output,
const Output<Node>& min_input,
const Output<Node>& max_input,
const Output<Node>& min_filter,
const Output<Node>& max_filter,
const Output<Node>& min_output,
const Output<Node>& max_output,
const ngraph::element::Type& output_type,
const ngraph::AxisSet& input_axes = ngraph::AxisSet{},
const ngraph::AxisSet& filter_axes = ngraph::AxisSet{},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment