Commit 90410792 authored by nishant.b.patel's avatar nishant.b.patel

Add axes as an argument in the QC api

parent 954da585
...@@ -53,6 +53,8 @@ namespace ngraph ...@@ -53,6 +53,8 @@ namespace ngraph
auto filter_zero_point = auto filter_zero_point =
op::Constant::create(filters->get_element_type(), Shape{}, {0}); op::Constant::create(filters->get_element_type(), Shape{}, {0});
AxisSet quantization_axes;
return make_shared<op::QuantizedConvolution>( return make_shared<op::QuantizedConvolution>(
input, input,
filters, filters,
...@@ -67,7 +69,8 @@ namespace ngraph ...@@ -67,7 +69,8 @@ namespace ngraph
filter_zero_point, filter_zero_point,
output_scale, output_scale,
filter_zero_point, // output type will be same as filter filter_zero_point, // output type will be same as filter
output_type); output_type,
quantization_axes);
} }
} }
} }
...@@ -37,7 +37,8 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& input, ...@@ -37,7 +37,8 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& input,
const std::shared_ptr<Node>& filter_zero_point, const std::shared_ptr<Node>& filter_zero_point,
const std::shared_ptr<Node>& output_scale, const std::shared_ptr<Node>& output_scale,
const std::shared_ptr<Node>& output_zero_point, const std::shared_ptr<Node>& output_zero_point,
const ngraph::element::Type& output_type) const ngraph::element::Type& output_type,
const ngraph::AxisSet& axes)
: Op("QuantizedConvolution", : Op("QuantizedConvolution",
check_single_output_args({input, check_single_output_args({input,
filters, filters,
...@@ -53,6 +54,7 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& input, ...@@ -53,6 +54,7 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& input,
, m_padding_above(padding_above) , m_padding_above(padding_above)
, m_data_dilation_strides(data_dilation_strides) , m_data_dilation_strides(data_dilation_strides)
, m_output_type(output_type) , m_output_type(output_type)
, m_axes(axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -162,5 +164,6 @@ shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector& ...@@ -162,5 +164,6 @@ shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector&
new_args.at(5), new_args.at(5),
new_args.at(6), new_args.at(6),
new_args.at(7), new_args.at(7),
m_output_type)); m_output_type,
m_axes));
} }
...@@ -56,7 +56,8 @@ namespace ngraph ...@@ -56,7 +56,8 @@ namespace ngraph
const std::shared_ptr<Node>& filter_zero_point, const std::shared_ptr<Node>& filter_zero_point,
const std::shared_ptr<Node>& output_scale, const std::shared_ptr<Node>& output_scale,
const std::shared_ptr<Node>& output_zero_point, const std::shared_ptr<Node>& output_zero_point,
const ngraph::element::Type& output_type); const ngraph::element::Type& output_type,
const ngraph::AxisSet& axes);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; } const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
...@@ -65,6 +66,7 @@ namespace ngraph ...@@ -65,6 +66,7 @@ namespace ngraph
std::shared_ptr<Node> get_filters() { return get_argument(1); } std::shared_ptr<Node> get_filters() { return get_argument(1); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); } std::shared_ptr<Node> get_data_batch() { return get_argument(0); }
const ngraph::element::Type& get_output_type() const { return m_output_type; } const ngraph::element::Type& get_output_type() const { return m_output_type; }
const ngraph::AxisSet& get_axes() const { return m_axes; }
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -76,6 +78,7 @@ namespace ngraph ...@@ -76,6 +78,7 @@ namespace ngraph
CoordinateDiff m_padding_above; CoordinateDiff m_padding_above;
Strides m_data_dilation_strides; Strides m_data_dilation_strides;
ngraph::element::Type m_output_type; ngraph::element::Type m_output_type;
ngraph::AxisSet m_axes;
}; };
} }
} }
...@@ -88,20 +88,12 @@ namespace ngraph ...@@ -88,20 +88,12 @@ namespace ngraph
{ {
vector<float> dyn_scales; vector<float> dyn_scales;
// Calculate the requantization scale // Calculate the requantization scale
for (size_t i = 0; i < scales_size; i++) dyn_scales.push_back(
{ *(static_cast<float*>(ctx->buffer_data[arg2_buffer_index])) *
dyn_scales.push_back( *(static_cast<float*>(ctx->buffer_data[arg4_buffer_index])) /
*(static_cast<float*>(ctx->buffer_data[arg2_buffer_index]) + *(static_cast<float*>(ctx->buffer_data[arg6_buffer_index])));
i) *
*(static_cast<float*>(ctx->buffer_data[arg4_buffer_index]) +
i) /
*(static_cast<float*>(ctx->buffer_data[arg6_buffer_index]) +
i));
}
std::cout << "SCALE " << dyn_scales[0] << std::endl;
// use conv channelwise (dim 1, mask=2^1) if dyn_scales is a vector // use conv channelwise (dim 1, mask=2^1) if dyn_scales is a vector
const int mask = scales_size == 1 ? 0 : 2; conv_attr.set_output_scales(0, dyn_scales);
conv_attr.set_output_scales(mask, dyn_scales);
mkldnn_emitter->build_convolution_forward<false>( mkldnn_emitter->build_convolution_forward<false>(
ctx->mkldnn_primitives, ctx->mkldnn_primitives,
conv_desc, conv_desc,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment