Commit 34806616 authored by nishant.b.patel's avatar nishant.b.patel

Add support for non-zero zero point in reference conv

parent 4b55a21d
......@@ -63,9 +63,9 @@ namespace ngraph
padding_above,
data_dilation_strides,
static_cast<const float*>(input_scale),
static_cast<const INPUT*>(input_zero_point),
static_cast<INPUT*>(input_zero_point),
static_cast<const float*>(filter_scale),
static_cast<const FILTER*>(filter_zero_point),
static_cast<FILTER*>(filter_zero_point),
static_cast<const float*>(output_scale),
static_cast<OUTPUT*>(output_zero_point));
}
......
......@@ -74,12 +74,19 @@ namespace ngraph
size_t out_batch_axis,
size_t out_channel_axis,
const float* input_scale = nullptr,
const INPUT* input_zero_point = nullptr,
INPUT* input_zero_point = nullptr,
const float* filter_scale = nullptr,
const FILTER* filter_zero_point = nullptr,
FILTER* filter_zero_point = nullptr,
const float* output_scale = nullptr,
const OUTPUT* output_zero_point = nullptr)
OUTPUT* output_zero_point = nullptr)
{
bool is_quantized = false;
if (input_scale && input_zero_point && filter_scale && filter_zero_point &&
output_scale && output_zero_point)
{
is_quantized = true;
}
auto old_mode = std::fegetround();
std::fesetround(FE_TONEAREST);
// Comments throughout assume without loss of generality that:
......@@ -218,6 +225,11 @@ namespace ngraph
{
ACCUMULATION in_v = in[in_idx];
ACCUMULATION f_v = filter[filter_idx];
if (is_quantized)
{
in_v = in_v - *(input_zero_point);
f_v = f_v - *(filter_zero_point);
}
result += in_v * f_v;
in_idx += in_channel_stride;
filter_idx += filter_in_channel_stride;
......@@ -226,7 +238,16 @@ namespace ngraph
++in_it;
++filter_it;
}
out[out_transform.index(out_coord)] = result;
if (is_quantized)
{
float scale = ((*(input_scale)) * (*(filter_scale))) / (*(output_scale));
out[out_transform.index(out_coord)] =
static_cast<OUTPUT>((result * scale) + *(output_zero_point));
}
else
{
out[out_transform.index(out_coord)] = result;
}
}
std::fesetround(old_mode);
}
......@@ -247,11 +268,11 @@ namespace ngraph
const CoordinateDiff& in_pad_above,
const Strides& in_dilation,
const float* input_scale = nullptr,
const INPUT* input_zero_point = nullptr,
INPUT* input_zero_point = nullptr,
const float* filter_scale = nullptr,
const FILTER* filter_zero_point = nullptr,
FILTER* filter_zero_point = nullptr,
const float* output_scale = nullptr,
const OUTPUT* output_zero_point = nullptr)
OUTPUT* output_zero_point = nullptr)
{
general_convolution<INPUT, FILTER, OUTPUT, ACCUMULATION>(in,
......
......@@ -1429,7 +1429,6 @@ TEST(builder, dynamic_scaled_QD_with_bias)
read_vector<uint8_t>(f_requantize_relu_r));
}
#if 0
TEST(builder, scaled_QC_u8u8)
{
Shape shape_a{1, 1, 3, 4}; // input shape
......@@ -1439,7 +1438,10 @@ TEST(builder, scaled_QC_u8u8)
vector<uint8_t> b_data = {1, 2, 3, 4, 5, 0, 0, 1, 2}; //{0, -1, 0, -2, -3, 5, 0, 2, 1};
auto A = make_shared<op::Parameter>(element::u8, shape_a);
auto B = make_shared<op::Parameter>(element::u8, shape_b);
auto scale = op::Constant::create(element::f32, Shape{}, {2});
auto input_scale = op::Constant::create(element::f32, Shape{}, {2});
auto filter_scale = op::Constant::create(element::f32, Shape{}, {2});
auto output_scale = op::Constant::create(element::f32, Shape{}, {2});
auto u8_zero = op::Constant::create(element::u8, Shape{}, {0});
auto CV = make_shared<ngraph::op::QuantizedConvolution>(A,
B,
Strides{1, 1}, // move_strides
......@@ -1447,8 +1449,14 @@ TEST(builder, scaled_QC_u8u8)
CoordinateDiff{1, 1}, // below_pads
CoordinateDiff{1, 1}, // above_pads
Strides{1, 1}, // data_dilation
scale,
false);
input_scale,
u8_zero,
filter_scale,
u8_zero,
output_scale,
u8_zero,
element::u8,
AxisSet{});
auto f = make_shared<Function>(NodeVector{CV}, ParameterVector{A, B});
constant_fold(f);
......@@ -1475,7 +1483,6 @@ TEST(builder, scaled_QC_u8u8)
39 * 2} /*{1, 28, -3, 16, -7, -14, 3, -7, -3}*/),
read_vector<uint8_t>(result));
}
#endif
TEST(builder, scaled_QDot_u8u8)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment