Commit 5f1a0679 authored by nishant.b.patel's avatar nishant.b.patel

Change onnx importer conv integer op

parent 4ec19b95
...@@ -16,9 +16,9 @@ ...@@ -16,9 +16,9 @@
#include "op/conv_integer.hpp" #include "op/conv_integer.hpp"
#include "ngraph/builder/make_constant.hpp" #include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/quantization/quantized_linear_convolution.hpp"
#include "ngraph/frontend/onnx_import/exceptions.hpp" #include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/utils/convpool.hpp" #include "ngraph/frontend/onnx_import/utils/convpool.hpp"
#include "ngraph/op/quantized_convolution.hpp"
using namespace ngraph::builder; using namespace ngraph::builder;
...@@ -48,35 +48,58 @@ namespace ngraph ...@@ -48,35 +48,58 @@ namespace ngraph
const auto& padding_below = paddings.first; const auto& padding_below = paddings.first;
const auto& padding_above = paddings.second; const auto& padding_above = paddings.second;
const Strides default_data_dilation_strides(input->get_shape().size() - 2, 1); const Strides default_data_dilation_strides(input->get_shape().size() - 2, 1);
auto scale_one = make_constant(element::f32, Shape{}, 1);
auto input_zero_point = make_constant(input->get_element_type(), Shape{}, 0);
auto filters_zero_point =
make_constant(filters->get_element_type(), Shape{}, 0);
auto output_zero_point = make_constant(output->get_element_type(), Shape{}, 0);
if (num_inputs == 2) if (num_inputs == 2)
{ {
return {quantization::QuantizedConvInteger(input, return {std::make_shared<ngraph::op::QuantizedConvolution>(
input,
filters, filters,
window_movement_strides, window_movement_strides,
window_dilation_strides, window_dilation_strides,
padding_below, padding_below,
padding_above, padding_above,
default_data_dilation_strides)}; default_data_dilation_strides,
scale_one,
input_zero_point,
scale_one,
filters_zero_point,
scale_one,
output_zero_point,
ngraph::element::i32,
ngraph::AxisSet{},
ngraph::AxisSet{},
ngraph::AxisSet{})};
} }
auto input_zero_point = inputs.at(2);
auto filters_zero_point =
make_constant(filters->get_element_type(), Shape{}, 0);
if (num_inputs == 4) if (num_inputs == 4)
{ {
input_zero_point = inputs.at(2);
filters_zero_point = inputs.at(3); filters_zero_point = inputs.at(3);
} }
return {quantization::QuantizedConvInteger(input, return {std::make_shared<ngraph::op::QuantizedConvolution>(
input,
filters, filters,
window_movement_strides, window_movement_strides,
window_dilation_strides, window_dilation_strides,
padding_below, padding_below,
padding_above, padding_above,
default_data_dilation_strides, default_data_dilation_strides,
scale_one,
input_zero_point, input_zero_point,
filters_zero_point)}; scale_one,
filters_zero_point,
scale_one,
output_zero_point,
ngraph::element::i32,
ngraph::AxisSet{},
ngraph::AxisSet{},
ngraph::AxisSet{})};
} }
} // namespace set_1 } // namespace set_1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment