Commit 90410792 authored by nishant.b.patel

Add axes as an argument in the QuantizedConvolution (QC) API

parent 954da585
......@@ -53,6 +53,8 @@ namespace ngraph
auto filter_zero_point =
op::Constant::create(filters->get_element_type(), Shape{}, {0});
AxisSet quantization_axes;
return make_shared<op::QuantizedConvolution>(
input,
filters,
......@@ -67,7 +69,8 @@ namespace ngraph
filter_zero_point,
output_scale,
filter_zero_point, // output type will be same as filter
output_type);
output_type,
quantization_axes);
}
}
}
......@@ -37,7 +37,8 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& input,
const std::shared_ptr<Node>& filter_zero_point,
const std::shared_ptr<Node>& output_scale,
const std::shared_ptr<Node>& output_zero_point,
const ngraph::element::Type& output_type)
const ngraph::element::Type& output_type,
const ngraph::AxisSet& axes)
: Op("QuantizedConvolution",
check_single_output_args({input,
filters,
......@@ -53,6 +54,7 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& input,
, m_padding_above(padding_above)
, m_data_dilation_strides(data_dilation_strides)
, m_output_type(output_type)
, m_axes(axes)
{
constructor_validate_and_infer_types();
}
......@@ -162,5 +164,6 @@ shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector&
new_args.at(5),
new_args.at(6),
new_args.at(7),
m_output_type));
m_output_type,
m_axes));
}
......@@ -56,7 +56,8 @@ namespace ngraph
const std::shared_ptr<Node>& filter_zero_point,
const std::shared_ptr<Node>& output_scale,
const std::shared_ptr<Node>& output_zero_point,
const ngraph::element::Type& output_type);
const ngraph::element::Type& output_type,
const ngraph::AxisSet& axes);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
// Returns a const reference to the stored window dilation strides (m_window_dilation_strides).
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
// Returns a const reference to the stored below-padding (m_padding_below).
const CoordinateDiff& get_padding_below() const { return m_padding_below; }
......@@ -65,6 +66,7 @@ namespace ngraph
// Returns this node's argument at index 1; the constructor passes `filters` as the second argument.
std::shared_ptr<Node> get_filters() { return get_argument(1); }
// Returns this node's argument at index 0; the constructor passes `input` as the first argument.
std::shared_ptr<Node> get_data_batch() { return get_argument(0); }
// Returns a const reference to the element type supplied to the constructor as `output_type`.
const ngraph::element::Type& get_output_type() const { return m_output_type; }
// Returns a const reference to the AxisSet supplied to the constructor as `axes`.
// NOTE(review): presumably the quantization axes (the builder passes a variable named
// `quantization_axes`) — confirm against the op's documentation.
const ngraph::AxisSet& get_axes() const { return m_axes; }
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -76,6 +78,7 @@ namespace ngraph
CoordinateDiff m_padding_above;
Strides m_data_dilation_strides;
ngraph::element::Type m_output_type;
ngraph::AxisSet m_axes;
};
}
}
......@@ -88,20 +88,12 @@ namespace ngraph
{
vector<float> dyn_scales;
// Calculate the requantization scale
for (size_t i = 0; i < scales_size; i++)
{
dyn_scales.push_back(
*(static_cast<float*>(ctx->buffer_data[arg2_buffer_index]) +
i) *
*(static_cast<float*>(ctx->buffer_data[arg4_buffer_index]) +
i) /
*(static_cast<float*>(ctx->buffer_data[arg6_buffer_index]) +
i));
}
std::cout << "SCALE " << dyn_scales[0] << std::endl;
*(static_cast<float*>(ctx->buffer_data[arg2_buffer_index])) *
*(static_cast<float*>(ctx->buffer_data[arg4_buffer_index])) /
*(static_cast<float*>(ctx->buffer_data[arg6_buffer_index])));
// use conv channelwise (dim 1, mask=2^1) if dyn_scales is a vector
const int mask = scales_size == 1 ? 0 : 2;
conv_attr.set_output_scales(mask, dyn_scales);
conv_attr.set_output_scales(0, dyn_scales);
mkldnn_emitter->build_convolution_forward<false>(
ctx->mkldnn_primitives,
conv_desc,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment