Commit 4b55a21d authored by nishant.b.patel's avatar nishant.b.patel

change reference convolution api

parent ed701920
...@@ -125,7 +125,12 @@ namespace ngraph ...@@ -125,7 +125,12 @@ namespace ngraph
padding_below, padding_below,
padding_above, padding_above,
data_dilation_strides, data_dilation_strides,
1.0f); nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -119,6 +119,13 @@ namespace ngraph ...@@ -119,6 +119,13 @@ namespace ngraph
kernel; kernel;
kernel = runtime::cpu::kernel::convolution<uint8_t, uint8_t, uint8_t, int32_t>; kernel = runtime::cpu::kernel::convolution<uint8_t, uint8_t, uint8_t, int32_t>;
auto arg3_buffer_index =
external_function->get_buffer_index(args[3].get_name()); // input scale
auto arg5_buffer_index =
external_function->get_buffer_index(args[5].get_name()); // filter scale
auto arg7_buffer_index =
external_function->get_buffer_index(args[7].get_name()); // output scale
auto window_movement_strides = qconvolution->get_window_movement_strides(); auto window_movement_strides = qconvolution->get_window_movement_strides();
auto window_dilation_strides = qconvolution->get_window_dilation_strides(); auto window_dilation_strides = qconvolution->get_window_dilation_strides();
auto padding_below = qconvolution->get_padding_below(); auto padding_below = qconvolution->get_padding_below();
...@@ -132,6 +139,11 @@ namespace ngraph ...@@ -132,6 +139,11 @@ namespace ngraph
arg0_buffer_index, arg0_buffer_index,
arg1_buffer_index, arg1_buffer_index,
arg2_buffer_index, arg2_buffer_index,
arg3_buffer_index,
arg4_buffer_index,
arg5_buffer_index,
arg6_buffer_index,
arg7_buffer_index,
out0_buffer_index, out0_buffer_index,
result_shape, result_shape,
window_movement_strides, window_movement_strides,
...@@ -156,7 +168,12 @@ namespace ngraph ...@@ -156,7 +168,12 @@ namespace ngraph
padding_below, padding_below,
padding_above, padding_above,
data_dilation_strides, data_dilation_strides,
dyn_scales[0]); ctx->buffer_data[arg2_buffer_index],
ctx->buffer_data[arg3_buffer_index],
ctx->buffer_data[arg4_buffer_index],
ctx->buffer_data[arg5_buffer_index],
ctx->buffer_data[arg6_buffer_index],
ctx->buffer_data[arg7_buffer_index]);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -43,7 +43,12 @@ namespace ngraph ...@@ -43,7 +43,12 @@ namespace ngraph
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const float requant_scale) void* input_scale = nullptr,
void* input_zero_point = nullptr,
void* filter_scale = nullptr,
void* filter_zero_point = nullptr,
void* output_scale = nullptr,
void* output_zero_point = nullptr)
{ {
reference::convolution<INPUT, FILTER, OUTPUT, ACCUMULATION>( reference::convolution<INPUT, FILTER, OUTPUT, ACCUMULATION>(
static_cast<const INPUT*>(input0), static_cast<const INPUT*>(input0),
...@@ -57,7 +62,12 @@ namespace ngraph ...@@ -57,7 +62,12 @@ namespace ngraph
padding_below, padding_below,
padding_above, padding_above,
data_dilation_strides, data_dilation_strides,
requant_scale); static_cast<const float*>(input_scale),
static_cast<const INPUT*>(input_zero_point),
static_cast<const float*>(filter_scale),
static_cast<const FILTER*>(filter_zero_point),
static_cast<const float*>(output_scale),
static_cast<OUTPUT*>(output_zero_point));
} }
template <typename ElementType> template <typename ElementType>
......
...@@ -73,7 +73,12 @@ namespace ngraph ...@@ -73,7 +73,12 @@ namespace ngraph
size_t filter_in_channel_axis, size_t filter_in_channel_axis,
size_t out_batch_axis, size_t out_batch_axis,
size_t out_channel_axis, size_t out_channel_axis,
const float requant_scale = 1.0f) const float* input_scale = nullptr,
const INPUT* input_zero_point = nullptr,
const float* filter_scale = nullptr,
const FILTER* filter_zero_point = nullptr,
const float* output_scale = nullptr,
const OUTPUT* output_zero_point = nullptr)
{ {
auto old_mode = std::fegetround(); auto old_mode = std::fegetround();
std::fesetround(FE_TONEAREST); std::fesetround(FE_TONEAREST);
...@@ -221,8 +226,7 @@ namespace ngraph ...@@ -221,8 +226,7 @@ namespace ngraph
++in_it; ++in_it;
++filter_it; ++filter_it;
} }
out[out_transform.index(out_coord)] = out[out_transform.index(out_coord)] = result;
static_cast<OUTPUT>(result * requant_scale);
} }
std::fesetround(old_mode); std::fesetround(old_mode);
} }
...@@ -242,7 +246,12 @@ namespace ngraph ...@@ -242,7 +246,12 @@ namespace ngraph
const CoordinateDiff& in_pad_below, const CoordinateDiff& in_pad_below,
const CoordinateDiff& in_pad_above, const CoordinateDiff& in_pad_above,
const Strides& in_dilation, const Strides& in_dilation,
const float requant_scale = 1.0f) const float* input_scale = nullptr,
const INPUT* input_zero_point = nullptr,
const float* filter_scale = nullptr,
const FILTER* filter_zero_point = nullptr,
const float* output_scale = nullptr,
const OUTPUT* output_zero_point = nullptr)
{ {
general_convolution<INPUT, FILTER, OUTPUT, ACCUMULATION>(in, general_convolution<INPUT, FILTER, OUTPUT, ACCUMULATION>(in,
...@@ -262,7 +271,12 @@ namespace ngraph ...@@ -262,7 +271,12 @@ namespace ngraph
1, 1,
0, 0,
1, 1,
requant_scale); input_scale,
input_zero_point,
filter_scale,
filter_zero_point,
output_scale,
output_zero_point);
} }
template <typename INPUT, template <typename INPUT,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment