Commit 10402fdd authored by nishant.b.patel

Change onnx importer quant conv op

parent c94347ca
...@@ -27,8 +27,8 @@ ...@@ -27,8 +27,8 @@
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/quantized_convolution.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/strides.hpp" #include "ngraph/strides.hpp"
...@@ -51,6 +51,13 @@ namespace ngraph ...@@ -51,6 +51,13 @@ namespace ngraph
std::shared_ptr<ngraph::Node> output_scale; std::shared_ptr<ngraph::Node> output_scale;
}; };
// Bundles the three zero-point nodes of a quantized convolution so they can
// be handed to make_ng_quant_conv as a single argument (op_zero_point).
// It is brace-initialised positionally, so the member order below
// (data, filter, output) is part of its interface.
struct OpZeroPoint
{
// Zero point for the data input; the importer takes it from inputs.at(2).
std::shared_ptr<ngraph::Node> data_zero_point;
// Zero point for the filters; the importer takes it from inputs.at(5).
std::shared_ptr<ngraph::Node> filter_zero_point;
// Zero point for the output; the importer takes it from inputs.at(7).
std::shared_ptr<ngraph::Node> output_zero_point;
};
std::shared_ptr<ngraph::Node> std::shared_ptr<ngraph::Node>
make_ng_quant_conv(const std::shared_ptr<ngraph::Node>& data, make_ng_quant_conv(const std::shared_ptr<ngraph::Node>& data,
const std::shared_ptr<ngraph::Node>& filters, const std::shared_ptr<ngraph::Node>& filters,
...@@ -61,8 +68,20 @@ namespace ngraph ...@@ -61,8 +68,20 @@ namespace ngraph
const Strides& data_dilations, const Strides& data_dilations,
int groups, int groups,
const OpScale& op_scale, const OpScale& op_scale,
const OpZeroPoint& op_zero_point,
const std::shared_ptr<ngraph::Node>& bias = nullptr) const std::shared_ptr<ngraph::Node>& bias = nullptr)
{ {
ngraph::element::Type output_type;
if (data->get_element_type() == ngraph::element::u8 &&
filters->get_element_type() == ngraph::element::i8)
{
output_type = ngraph::element::i8;
}
else if (data->get_element_type() == ngraph::element::u8 &&
filters->get_element_type() == ngraph::element::u8)
{
output_type = ngraph::element::u8;
}
if (groups > 1) if (groups > 1)
{ {
// Split one convolution op to N ops where N is the number of groups // Split one convolution op to N ops where N is the number of groups
...@@ -104,7 +123,7 @@ namespace ngraph ...@@ -104,7 +123,7 @@ namespace ngraph
else else
{ {
convolution_nodes.push_back( convolution_nodes.push_back(
ngraph::builder::quantization::QuantizedLinearConvolution( std::make_shared<ngraph::op::QuantizedConvolution>(
sliced_data, sliced_data,
sliced_filters, sliced_filters,
strides, strides,
...@@ -113,8 +132,15 @@ namespace ngraph ...@@ -113,8 +132,15 @@ namespace ngraph
padding_above, padding_above,
data_dilations, data_dilations,
op_scale.data_scale, op_scale.data_scale,
op_zero_point.data_zero_point,
op_scale.filter_scale, op_scale.filter_scale,
op_scale.output_scale)); op_zero_point.filter_zero_point,
op_scale.output_scale,
op_zero_point.output_zero_point,
output_type,
ngraph::AxisSet{},
ngraph::AxisSet{},
ngraph::AxisSet{}));
} }
} }
std::size_t concatenation_axis = 1; std::size_t concatenation_axis = 1;
...@@ -140,7 +166,7 @@ namespace ngraph ...@@ -140,7 +166,7 @@ namespace ngraph
} }
else else
{ {
return ngraph::builder::quantization::QuantizedLinearConvolution( return std::make_shared<ngraph::op::QuantizedConvolution>(
data, data,
filters, filters,
strides, strides,
...@@ -149,8 +175,15 @@ namespace ngraph ...@@ -149,8 +175,15 @@ namespace ngraph
padding_above, padding_above,
data_dilations, data_dilations,
op_scale.data_scale, op_scale.data_scale,
op_zero_point.data_zero_point,
op_scale.filter_scale, op_scale.filter_scale,
op_scale.output_scale); op_zero_point.filter_zero_point,
op_scale.output_scale,
op_zero_point.output_zero_point,
output_type,
ngraph::AxisSet{},
ngraph::AxisSet{},
ngraph::AxisSet{});
} }
} }
} }
...@@ -166,8 +199,11 @@ namespace ngraph ...@@ -166,8 +199,11 @@ namespace ngraph
int64_t groups{node.get_attribute_value<int64_t>("group", 1)}; int64_t groups{node.get_attribute_value<int64_t>("group", 1)};
auto data_scale = inputs.at(1); auto data_scale = inputs.at(1);
auto data_zero_point = inputs.at(2);
auto filters_scale = inputs.at(4); auto filters_scale = inputs.at(4);
auto filters_zero_point = inputs.at(5);
auto output_scale = inputs.at(6); auto output_scale = inputs.at(6);
auto output_zero_point = inputs.at(7);
ASSERT_VALID_ARGUMENT(node, ASSERT_VALID_ARGUMENT(node,
((groups >= 0) && (groups <= data->get_shape().at(1)) && ((groups >= 0) && (groups <= data->get_shape().at(1)) &&
...@@ -197,50 +233,36 @@ namespace ngraph ...@@ -197,50 +233,36 @@ namespace ngraph
if (inputs.size() == 9 && !inputs.at(8)->is_null()) if (inputs.size() == 9 && !inputs.at(8)->is_null())
{ {
auto bias = inputs.at(8); auto bias = inputs.at(8);
conv_node = conv_node = make_ng_quant_conv(
make_ng_quant_conv(data, data,
filters, filters,
strides, strides,
filter_dilations, filter_dilations,
padding_below, padding_below,
padding_above, padding_above,
data_dilations, data_dilations,
groups, groups,
OpScale{data_scale, filters_scale, output_scale}, OpScale{data_scale, filters_scale, output_scale},
bias); OpZeroPoint{data_zero_point, filters_zero_point, output_zero_point},
bias);
} }
else else
{ {
if (filters->get_element_type() == ngraph::element::u8 && groups == 1) conv_node = make_ng_quant_conv(
{ data,
conv_node = ngraph::builder::quantization::QuantizedLinearConvolution( filters,
data, strides,
filters, filter_dilations,
strides, padding_below,
filter_dilations, padding_above,
padding_below, data_dilations,
padding_above, groups,
data_dilations, OpScale{data_scale, filters_scale, output_scale},
data_scale, OpZeroPoint{data_zero_point, filters_zero_point, output_zero_point},
inputs.at(2), output_type,
filters_scale, ngraph::AxisSet{},
inputs.at(5), ngraph::AxisSet{},
output_scale, ngraph::AxisSet{});
inputs.at(7));
}
else
{
conv_node = make_ng_quant_conv(
data,
filters,
strides,
filter_dilations,
padding_below,
padding_above,
data_dilations,
groups,
OpScale{data_scale, filters_scale, output_scale});
}
} }
return {conv_node}; return {conv_node};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment