Commit 10402fdd authored by nishant.b.patel

Change onnx importer quant conv op

parent c94347ca
......@@ -27,8 +27,8 @@
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/quantized_convolution.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/strides.hpp"
......@@ -51,6 +51,13 @@ namespace ngraph
std::shared_ptr<ngraph::Node> output_scale;
};
// Bundles the three zero-point nodes that make_ng_quant_conv receives as a
// single argument, alongside the OpScale bundle of scale nodes. The members
// are forwarded unchanged to the quantized convolution op constructor.
struct OpZeroPoint
{
// Zero point paired with the data input (importer input 2 in the visible caller).
std::shared_ptr<ngraph::Node> data_zero_point;
// Zero point paired with the filters input (importer input 5 in the visible caller).
std::shared_ptr<ngraph::Node> filter_zero_point;
// Zero point paired with the output (importer input 7 in the visible caller).
std::shared_ptr<ngraph::Node> output_zero_point;
};
std::shared_ptr<ngraph::Node>
make_ng_quant_conv(const std::shared_ptr<ngraph::Node>& data,
const std::shared_ptr<ngraph::Node>& filters,
......@@ -61,8 +68,20 @@ namespace ngraph
const Strides& data_dilations,
int groups,
const OpScale& op_scale,
const OpZeroPoint& op_zero_point,
const std::shared_ptr<ngraph::Node>& bias = nullptr)
{
ngraph::element::Type output_type;
if (data->get_element_type() == ngraph::element::u8 &&
filters->get_element_type() == ngraph::element::i8)
{
output_type = ngraph::element::i8;
}
else if (data->get_element_type() == ngraph::element::u8 &&
filters->get_element_type() == ngraph::element::u8)
{
output_type = ngraph::element::u8;
}
if (groups > 1)
{
// Split one convolution op to N ops where N is the number of groups
......@@ -104,7 +123,7 @@ namespace ngraph
else
{
convolution_nodes.push_back(
ngraph::builder::quantization::QuantizedLinearConvolution(
std::make_shared<ngraph::op::QuantizedConvolution>(
sliced_data,
sliced_filters,
strides,
......@@ -113,8 +132,15 @@ namespace ngraph
padding_above,
data_dilations,
op_scale.data_scale,
op_zero_point.data_zero_point,
op_scale.filter_scale,
op_scale.output_scale));
op_zero_point.filter_zero_point,
op_scale.output_scale,
op_zero_point.output_zero_point,
output_type,
ngraph::AxisSet{},
ngraph::AxisSet{},
ngraph::AxisSet{}));
}
}
std::size_t concatenation_axis = 1;
......@@ -140,7 +166,7 @@ namespace ngraph
}
else
{
return ngraph::builder::quantization::QuantizedLinearConvolution(
return std::make_shared<ngraph::op::QuantizedConvolution>(
data,
filters,
strides,
......@@ -149,8 +175,15 @@ namespace ngraph
padding_above,
data_dilations,
op_scale.data_scale,
op_zero_point.data_zero_point,
op_scale.filter_scale,
op_scale.output_scale);
op_zero_point.filter_zero_point,
op_scale.output_scale,
op_zero_point.output_zero_point,
output_type,
ngraph::AxisSet{},
ngraph::AxisSet{},
ngraph::AxisSet{});
}
}
}
......@@ -166,8 +199,11 @@ namespace ngraph
int64_t groups{node.get_attribute_value<int64_t>("group", 1)};
auto data_scale = inputs.at(1);
auto data_zero_point = inputs.at(2);
auto filters_scale = inputs.at(4);
auto filters_zero_point = inputs.at(5);
auto output_scale = inputs.at(6);
auto output_zero_point = inputs.at(7);
ASSERT_VALID_ARGUMENT(node,
((groups >= 0) && (groups <= data->get_shape().at(1)) &&
......@@ -197,50 +233,36 @@ namespace ngraph
if (inputs.size() == 9 && !inputs.at(8)->is_null())
{
auto bias = inputs.at(8);
conv_node =
make_ng_quant_conv(data,
filters,
strides,
filter_dilations,
padding_below,
padding_above,
data_dilations,
groups,
OpScale{data_scale, filters_scale, output_scale},
bias);
conv_node = make_ng_quant_conv(
data,
filters,
strides,
filter_dilations,
padding_below,
padding_above,
data_dilations,
groups,
OpScale{data_scale, filters_scale, output_scale},
OpZeroPoint{data_zero_point, filters_zero_point, output_zero_point},
bias);
}
else
{
if (filters->get_element_type() == ngraph::element::u8 && groups == 1)
{
conv_node = ngraph::builder::quantization::QuantizedLinearConvolution(
data,
filters,
strides,
filter_dilations,
padding_below,
padding_above,
data_dilations,
data_scale,
inputs.at(2),
filters_scale,
inputs.at(5),
output_scale,
inputs.at(7));
}
else
{
conv_node = make_ng_quant_conv(
data,
filters,
strides,
filter_dilations,
padding_below,
padding_above,
data_dilations,
groups,
OpScale{data_scale, filters_scale, output_scale});
}
conv_node = make_ng_quant_conv(
data,
filters,
strides,
filter_dilations,
padding_below,
padding_above,
data_dilations,
groups,
OpScale{data_scale, filters_scale, output_scale},
OpZeroPoint{data_zero_point, filters_zero_point, output_zero_point},
output_type,
ngraph::AxisSet{},
ngraph::AxisSet{},
ngraph::AxisSet{});
}
return {conv_node};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment