Commit 4b55a21d authored by nishant.b.patel's avatar nishant.b.patel

change reference convolution api

parent ed701920
......@@ -125,7 +125,12 @@ namespace ngraph
padding_below,
padding_above,
data_dilation_strides,
1.0f);
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr);
};
functors.emplace_back(functor);
}
......
......@@ -119,6 +119,13 @@ namespace ngraph
kernel;
kernel = runtime::cpu::kernel::convolution<uint8_t, uint8_t, uint8_t, int32_t>;
auto arg3_buffer_index =
external_function->get_buffer_index(args[3].get_name()); // input scale
auto arg5_buffer_index =
external_function->get_buffer_index(args[5].get_name()); // filter scale
auto arg7_buffer_index =
external_function->get_buffer_index(args[7].get_name()); // output scale
auto window_movement_strides = qconvolution->get_window_movement_strides();
auto window_dilation_strides = qconvolution->get_window_dilation_strides();
auto padding_below = qconvolution->get_padding_below();
......@@ -132,6 +139,11 @@ namespace ngraph
arg0_buffer_index,
arg1_buffer_index,
arg2_buffer_index,
arg3_buffer_index,
arg4_buffer_index,
arg5_buffer_index,
arg6_buffer_index,
arg7_buffer_index,
out0_buffer_index,
result_shape,
window_movement_strides,
......@@ -156,7 +168,12 @@ namespace ngraph
padding_below,
padding_above,
data_dilation_strides,
dyn_scales[0]);
ctx->buffer_data[arg2_buffer_index],
ctx->buffer_data[arg3_buffer_index],
ctx->buffer_data[arg4_buffer_index],
ctx->buffer_data[arg5_buffer_index],
ctx->buffer_data[arg6_buffer_index],
ctx->buffer_data[arg7_buffer_index]);
};
functors.emplace_back(functor);
}
......
......@@ -43,7 +43,12 @@ namespace ngraph
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const float requant_scale)
void* input_scale = nullptr,
void* input_zero_point = nullptr,
void* filter_scale = nullptr,
void* filter_zero_point = nullptr,
void* output_scale = nullptr,
void* output_zero_point = nullptr)
{
reference::convolution<INPUT, FILTER, OUTPUT, ACCUMULATION>(
static_cast<const INPUT*>(input0),
......@@ -57,7 +62,12 @@ namespace ngraph
padding_below,
padding_above,
data_dilation_strides,
requant_scale);
static_cast<const float*>(input_scale),
static_cast<const INPUT*>(input_zero_point),
static_cast<const float*>(filter_scale),
static_cast<const FILTER*>(filter_zero_point),
static_cast<const float*>(output_scale),
static_cast<OUTPUT*>(output_zero_point));
}
template <typename ElementType>
......
......@@ -73,7 +73,12 @@ namespace ngraph
size_t filter_in_channel_axis,
size_t out_batch_axis,
size_t out_channel_axis,
const float requant_scale = 1.0f)
const float* input_scale = nullptr,
const INPUT* input_zero_point = nullptr,
const float* filter_scale = nullptr,
const FILTER* filter_zero_point = nullptr,
const float* output_scale = nullptr,
const OUTPUT* output_zero_point = nullptr)
{
auto old_mode = std::fegetround();
std::fesetround(FE_TONEAREST);
......@@ -221,8 +226,7 @@ namespace ngraph
++in_it;
++filter_it;
}
out[out_transform.index(out_coord)] =
static_cast<OUTPUT>(result * requant_scale);
out[out_transform.index(out_coord)] = result;
}
std::fesetround(old_mode);
}
......@@ -242,7 +246,12 @@ namespace ngraph
const CoordinateDiff& in_pad_below,
const CoordinateDiff& in_pad_above,
const Strides& in_dilation,
const float requant_scale = 1.0f)
const float* input_scale = nullptr,
const INPUT* input_zero_point = nullptr,
const float* filter_scale = nullptr,
const FILTER* filter_zero_point = nullptr,
const float* output_scale = nullptr,
const OUTPUT* output_zero_point = nullptr)
{
general_convolution<INPUT, FILTER, OUTPUT, ACCUMULATION>(in,
......@@ -262,7 +271,12 @@ namespace ngraph
1,
0,
1,
requant_scale);
input_scale,
input_zero_point,
filter_scale,
filter_zero_point,
output_scale,
output_zero_point);
}
template <typename INPUT,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment