Commit 34806616 authored by nishant.b.patel's avatar nishant.b.patel

Add support for non-zero zero point in reference conv

parent 4b55a21d
...@@ -63,9 +63,9 @@ namespace ngraph ...@@ -63,9 +63,9 @@ namespace ngraph
padding_above, padding_above,
data_dilation_strides, data_dilation_strides,
static_cast<const float*>(input_scale), static_cast<const float*>(input_scale),
static_cast<const INPUT*>(input_zero_point), static_cast<INPUT*>(input_zero_point),
static_cast<const float*>(filter_scale), static_cast<const float*>(filter_scale),
static_cast<const FILTER*>(filter_zero_point), static_cast<FILTER*>(filter_zero_point),
static_cast<const float*>(output_scale), static_cast<const float*>(output_scale),
static_cast<OUTPUT*>(output_zero_point)); static_cast<OUTPUT*>(output_zero_point));
} }
......
...@@ -74,12 +74,19 @@ namespace ngraph ...@@ -74,12 +74,19 @@ namespace ngraph
size_t out_batch_axis, size_t out_batch_axis,
size_t out_channel_axis, size_t out_channel_axis,
const float* input_scale = nullptr, const float* input_scale = nullptr,
const INPUT* input_zero_point = nullptr, INPUT* input_zero_point = nullptr,
const float* filter_scale = nullptr, const float* filter_scale = nullptr,
const FILTER* filter_zero_point = nullptr, FILTER* filter_zero_point = nullptr,
const float* output_scale = nullptr, const float* output_scale = nullptr,
const OUTPUT* output_zero_point = nullptr) OUTPUT* output_zero_point = nullptr)
{ {
bool is_quantized = false;
if (input_scale && input_zero_point && filter_scale && filter_zero_point &&
output_scale && output_zero_point)
{
is_quantized = true;
}
auto old_mode = std::fegetround(); auto old_mode = std::fegetround();
std::fesetround(FE_TONEAREST); std::fesetround(FE_TONEAREST);
// Comments throughout assume without loss of generality that: // Comments throughout assume without loss of generality that:
...@@ -218,6 +225,11 @@ namespace ngraph ...@@ -218,6 +225,11 @@ namespace ngraph
{ {
ACCUMULATION in_v = in[in_idx]; ACCUMULATION in_v = in[in_idx];
ACCUMULATION f_v = filter[filter_idx]; ACCUMULATION f_v = filter[filter_idx];
if (is_quantized)
{
in_v = in_v - *(input_zero_point);
f_v = f_v - *(filter_zero_point);
}
result += in_v * f_v; result += in_v * f_v;
in_idx += in_channel_stride; in_idx += in_channel_stride;
filter_idx += filter_in_channel_stride; filter_idx += filter_in_channel_stride;
...@@ -226,7 +238,16 @@ namespace ngraph ...@@ -226,7 +238,16 @@ namespace ngraph
++in_it; ++in_it;
++filter_it; ++filter_it;
} }
out[out_transform.index(out_coord)] = result; if (is_quantized)
{
float scale = ((*(input_scale)) * (*(filter_scale))) / (*(output_scale));
out[out_transform.index(out_coord)] =
static_cast<OUTPUT>((result * scale) + *(output_zero_point));
}
else
{
out[out_transform.index(out_coord)] = result;
}
} }
std::fesetround(old_mode); std::fesetround(old_mode);
} }
...@@ -247,11 +268,11 @@ namespace ngraph ...@@ -247,11 +268,11 @@ namespace ngraph
const CoordinateDiff& in_pad_above, const CoordinateDiff& in_pad_above,
const Strides& in_dilation, const Strides& in_dilation,
const float* input_scale = nullptr, const float* input_scale = nullptr,
const INPUT* input_zero_point = nullptr, INPUT* input_zero_point = nullptr,
const float* filter_scale = nullptr, const float* filter_scale = nullptr,
const FILTER* filter_zero_point = nullptr, FILTER* filter_zero_point = nullptr,
const float* output_scale = nullptr, const float* output_scale = nullptr,
const OUTPUT* output_zero_point = nullptr) OUTPUT* output_zero_point = nullptr)
{ {
general_convolution<INPUT, FILTER, OUTPUT, ACCUMULATION>(in, general_convolution<INPUT, FILTER, OUTPUT, ACCUMULATION>(in,
......
...@@ -1429,7 +1429,6 @@ TEST(builder, dynamic_scaled_QD_with_bias) ...@@ -1429,7 +1429,6 @@ TEST(builder, dynamic_scaled_QD_with_bias)
read_vector<uint8_t>(f_requantize_relu_r)); read_vector<uint8_t>(f_requantize_relu_r));
} }
#if 0
TEST(builder, scaled_QC_u8u8) TEST(builder, scaled_QC_u8u8)
{ {
Shape shape_a{1, 1, 3, 4}; // input shape Shape shape_a{1, 1, 3, 4}; // input shape
...@@ -1439,7 +1438,10 @@ TEST(builder, scaled_QC_u8u8) ...@@ -1439,7 +1438,10 @@ TEST(builder, scaled_QC_u8u8)
vector<uint8_t> b_data = {1, 2, 3, 4, 5, 0, 0, 1, 2}; //{0, -1, 0, -2, -3, 5, 0, 2, 1}; vector<uint8_t> b_data = {1, 2, 3, 4, 5, 0, 0, 1, 2}; //{0, -1, 0, -2, -3, 5, 0, 2, 1};
auto A = make_shared<op::Parameter>(element::u8, shape_a); auto A = make_shared<op::Parameter>(element::u8, shape_a);
auto B = make_shared<op::Parameter>(element::u8, shape_b); auto B = make_shared<op::Parameter>(element::u8, shape_b);
auto scale = op::Constant::create(element::f32, Shape{}, {2}); auto input_scale = op::Constant::create(element::f32, Shape{}, {2});
auto filter_scale = op::Constant::create(element::f32, Shape{}, {2});
auto output_scale = op::Constant::create(element::f32, Shape{}, {2});
auto u8_zero = op::Constant::create(element::u8, Shape{}, {0});
auto CV = make_shared<ngraph::op::QuantizedConvolution>(A, auto CV = make_shared<ngraph::op::QuantizedConvolution>(A,
B, B,
Strides{1, 1}, // move_strides Strides{1, 1}, // move_strides
...@@ -1447,8 +1449,14 @@ TEST(builder, scaled_QC_u8u8) ...@@ -1447,8 +1449,14 @@ TEST(builder, scaled_QC_u8u8)
CoordinateDiff{1, 1}, // below_pads CoordinateDiff{1, 1}, // below_pads
CoordinateDiff{1, 1}, // above_pads CoordinateDiff{1, 1}, // above_pads
Strides{1, 1}, // data_dilation Strides{1, 1}, // data_dilation
scale, input_scale,
false); u8_zero,
filter_scale,
u8_zero,
output_scale,
u8_zero,
element::u8,
AxisSet{});
auto f = make_shared<Function>(NodeVector{CV}, ParameterVector{A, B}); auto f = make_shared<Function>(NodeVector{CV}, ParameterVector{A, B});
constant_fold(f); constant_fold(f);
...@@ -1475,7 +1483,6 @@ TEST(builder, scaled_QC_u8u8) ...@@ -1475,7 +1483,6 @@ TEST(builder, scaled_QC_u8u8)
39 * 2} /*{1, 28, -3, 16, -7, -14, 3, -7, -3}*/), 39 * 2} /*{1, 28, -3, 16, -7, -14, 3, -7, -3}*/),
read_vector<uint8_t>(result)); read_vector<uint8_t>(result));
} }
#endif
TEST(builder, scaled_QDot_u8u8) TEST(builder, scaled_QDot_u8u8)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment