Commit e568de2e authored by nishant.b.patel

Change the API to take input_axes, filter_axes & output_axes

parent c75f7db3
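
For context, the single `axes` argument of op::QuantizedConvolution (and of the builder that wraps it) becomes three axis sets, one for each quantized tensor. The call-site sketch below is not taken from this commit: the shapes and scale/zero-point constants are placeholders, and the ordering of the convolution arguments is inferred from the QuantizedLinearConvolution builder removed in the first hunk.

#include <ngraph/ngraph.hpp>

using namespace ngraph;

// Sketch only: per-tensor quantization everywhere, so all three axis sets are empty.
std::shared_ptr<op::QuantizedConvolution> make_qconv_sketch()
{
    auto data = std::make_shared<op::Parameter>(element::u8, Shape{1, 3, 8, 8});
    auto filters = std::make_shared<op::Parameter>(element::i8, Shape{4, 3, 3, 3});
    auto input_scale = op::Constant::create(element::f32, Shape{}, {0.5f});
    auto input_zp = op::Constant::create(element::u8, Shape{}, {0});
    auto filter_scale = op::Constant::create(element::f32, Shape{}, {0.25f});
    auto filter_zp = op::Constant::create(element::i8, Shape{}, {0});
    auto output_scale = op::Constant::create(element::f32, Shape{}, {1.0f});
    auto output_zp = op::Constant::create(element::i8, Shape{}, {0});

    return std::make_shared<op::QuantizedConvolution>(
        data,
        filters,
        Strides{1, 1},        // window_movement_strides
        Strides{1, 1},        // window_dilation_strides
        CoordinateDiff{0, 0}, // padding_below
        CoordinateDiff{0, 0}, // padding_above
        Strides{1, 1},        // data_dilation_strides
        input_scale,
        input_zp,
        filter_scale,
        filter_zp,
        output_scale,
        output_zp,
        element::i8, // output_type
        AxisSet{},   // input_axes  (previously a single `axes` argument)
        AxisSet{},   // filter_axes
        AxisSet{});  // output_axes
}
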
......@@ -36,52 +36,6 @@ namespace ngraph
{
namespace quantization
{
// TODO: this code falls back to fp32 convolution;
// we need to make this the primary builder, which means
// 1) adding support for zero point to the QuantizedConvolution op API
// 2) adding a QuantizedConvolution reference kernel, including zero point
shared_ptr<Node> QuantizedLinearConvolution(const shared_ptr<Node>& input,
const shared_ptr<Node>& filter,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const shared_ptr<Node>& input_scale,
const shared_ptr<Node>& input_zero_point,
const shared_ptr<Node>& filter_scale,
const shared_ptr<Node>& filter_zero_point,
const shared_ptr<Node>& output_scale,
const shared_ptr<Node>& output_zero_point)
{
AxisSet axes;
auto dq_input = make_shared<op::Dequantize>(
input, input_scale, input_zero_point, input_scale->get_element_type(), axes);
auto dq_filter = make_shared<op::Dequantize>(filter,
filter_scale,
filter_zero_point,
filter_scale->get_element_type(),
axes);
auto convolution = make_shared<op::Convolution>(dq_input,
dq_filter,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides);
auto q_convolution =
make_shared<op::Quantize>(convolution,
output_scale,
output_zero_point,
output_zero_point->get_element_type(),
axes,
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
return move(q_convolution);
}
shared_ptr<Node> QuantizedLinearConvolutionBias(const shared_ptr<Node>& input,
const shared_ptr<Node>& filter,
const shared_ptr<Node>& bias,
......
......@@ -25,21 +25,6 @@ namespace ngraph
{
namespace quantization
{
std::shared_ptr<Node>
QuantizedLinearConvolution(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& filter,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node>& input_scale,
const std::shared_ptr<Node>& input_zero_point,
const std::shared_ptr<Node>& filter_scale,
const std::shared_ptr<Node>& filter_zero_point,
const std::shared_ptr<Node>& output_scale,
const std::shared_ptr<Node>& output_zero_point);
std::shared_ptr<Node>
QuantizedLinearConvolutionBias(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& filter,
......
......@@ -40,7 +40,9 @@ namespace ngraph
const shared_ptr<Node>& min_output,
const shared_ptr<Node>& max_output,
const ngraph::element::Type& output_type,
const ngraph::AxisSet& axes)
const ngraph::AxisSet& input_axes,
const ngraph::AxisSet& filter_axes,
const ngraph::AxisSet& output_axes)
{
auto input_scale =
quantization_scale::get_scale(min_input, max_input, input->get_element_type());
......@@ -69,7 +71,9 @@ namespace ngraph
output_scale,
filter_zero_point, // output type will be the same as the filter's
output_type,
axes);
input_axes,
filter_axes,
output_axes);
}
}
}
......@@ -39,6 +39,8 @@ namespace ngraph
const std::shared_ptr<Node>& min_output,
const std::shared_ptr<Node>& max_output,
const ngraph::element::Type& output_type,
const ngraph::AxisSet& axes);
const ngraph::AxisSet& input_axes,
const ngraph::AxisSet& filter_axes,
const ngraph::AxisSet& output_axes);
}
}
......@@ -38,7 +38,9 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& input,
const std::shared_ptr<Node>& output_scale,
const std::shared_ptr<Node>& output_zero_point,
const ngraph::element::Type& output_type,
const ngraph::AxisSet& axes)
const ngraph::AxisSet& input_axes,
const ngraph::AxisSet& filter_axes,
const ngraph::AxisSet& output_axes)
: Op("QuantizedConvolution",
check_single_output_args({input,
filters,
......@@ -54,7 +56,9 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& input,
, m_padding_above(padding_above)
, m_data_dilation_strides(data_dilation_strides)
, m_output_type(output_type)
, m_axes(axes)
, m_input_axes(input_axes)
, m_filter_axes(filter_axes)
, m_output_axes(output_axes)
{
constructor_validate_and_infer_types();
}
......@@ -165,5 +169,7 @@ shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector&
new_args.at(6),
new_args.at(7),
m_output_type,
m_axes));
m_input_axes,
m_filter_axes,
m_output_axes));
}
......@@ -57,7 +57,9 @@ namespace ngraph
const std::shared_ptr<Node>& output_scale,
const std::shared_ptr<Node>& output_zero_point,
const ngraph::element::Type& output_type,
const ngraph::AxisSet& axes);
const ngraph::AxisSet& input_axes,
const ngraph::AxisSet& filter_axes,
const ngraph::AxisSet& output_axes);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
const CoordinateDiff& get_padding_below() const { return m_padding_below; }
......@@ -66,7 +68,9 @@ namespace ngraph
std::shared_ptr<Node> get_filters() { return get_argument(1); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); }
const ngraph::element::Type& get_output_type() const { return m_output_type; }
const ngraph::AxisSet& get_axes() const { return m_axes; }
const ngraph::AxisSet& get_input_axes() const { return m_input_axes; }
const ngraph::AxisSet& get_filter_axes() const { return m_filter_axes; }
const ngraph::AxisSet& get_output_axes() const { return m_output_axes; }
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -78,7 +82,9 @@ namespace ngraph
CoordinateDiff m_padding_above;
Strides m_data_dilation_strides;
ngraph::element::Type m_output_type;
ngraph::AxisSet m_axes;
ngraph::AxisSet m_input_axes;
ngraph::AxisSet m_filter_axes;
ngraph::AxisSet m_output_axes;
};
}
}
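
Splitting the axis sets is what lets the op express per-channel quantization for one tensor while keeping the others per-tensor. The sketch below assumes the axis sets follow op::Quantize/op::Dequantize semantics (a non-empty set marks the axes along which scale and zero point vary) and that axis 0 is the output-channel axis of an OIHW filter; whether the backends accept a non-scalar filter scale at this point is not shown in this commit.

// Placeholders as in the sketch near the top of this page.
auto data = std::make_shared<op::Parameter>(element::u8, Shape{1, 3, 8, 8});
auto filters = std::make_shared<op::Parameter>(element::i8, Shape{4, 3, 3, 3});
auto input_scale = op::Constant::create(element::f32, Shape{}, {0.5f});
auto input_zp = op::Constant::create(element::u8, Shape{}, {0});
auto output_scale = op::Constant::create(element::f32, Shape{}, {1.0f});
auto output_zp = op::Constant::create(element::i8, Shape{}, {0});

// One filter scale/zero point per output channel (axis 0 of the OIHW filter).
auto filter_scale = op::Constant::create(element::f32, Shape{4}, {0.1f, 0.2f, 0.15f, 0.3f});
auto filter_zp = op::Constant::create(element::i8, Shape{4}, {0, 0, 0, 0});

auto qconv = std::make_shared<op::QuantizedConvolution>(
    data, filters,
    Strides{1, 1}, Strides{1, 1},
    CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1},
    input_scale, input_zp, filter_scale, filter_zp, output_scale, output_zp,
    element::i8,
    AxisSet{},   // input_axes:  per-tensor
    AxisSet{0},  // filter_axes: scale varies along the output channels
    AxisSet{});  // output_axes: per-tensor
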
......@@ -1823,6 +1823,8 @@ void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconv_relu(bool with_
output_scale,
int8_zero,
element::i8,
AxisSet{},
AxisSet{},
AxisSet{});
}
auto dq =
......
......@@ -1350,7 +1350,9 @@ static shared_ptr<ngraph::Function>
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides = node_js["data_dilation_strides"];
auto output_type = read_element_type(node_js.at("output_type"));
auto axes = node_js.at("axes").get<set<size_t>>();
auto input_axes = node_js.at("input_axes").get<set<size_t>>();
auto filter_axes = node_js.at("filter_axes").get<set<size_t>>();
auto output_axes = node_js.at("output_axes").get<set<size_t>>();
node = make_shared<op::QuantizedConvolution>(
args[0],
args[1],
......@@ -1366,7 +1368,9 @@ static shared_ptr<ngraph::Function>
args[6],
args[7],
output_type,
axes);
input_axes,
filter_axes,
output_axes);
break;
}
case OP_TYPEID::QuantizedDotBias: { break;
......@@ -2298,7 +2302,9 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_above"] = tmp->get_padding_above();
node["data_dilation_strides"] = tmp->get_data_dilation_strides();
node["output_type"] = write_element_type(tmp->get_element_type());
node["axes"] = tmp->get_axes();
node["input_axes"] = tmp->get_input_axes();
node["filter_axes"] = tmp->get_filter_axes();
node["output_axes"] = tmp->get_output_axes();
break;
}
case OP_TYPEID::QuantizedDotBias: { break;
......
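
On the serializer side the old single "axes" field is gone, so graphs serialized before this change will presumably fail the node_js.at("input_axes") lookup when reloaded. A rough sketch of the three fields the writer now emits for a QuantizedConvolution node (assuming nlohmann::json, which the .at(...).get<std::set<size_t>>() calls suggest; values are illustrative):

#include <nlohmann/json.hpp>
#include <set>

nlohmann::json node;
node["input_axes"] = std::set<size_t>{};    // replaces the former node["axes"]
node["filter_axes"] = std::set<size_t>{0};  // e.g. per-output-channel filter scale
node["output_axes"] = std::set<size_t>{};
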
......@@ -7520,6 +7520,8 @@ NGRAPH_TEST(${BACKEND_NAME}, quantized_convolution)
G,
H,
element::i8,
AxisSet{},
AxisSet{},
AxisSet{});
auto f = make_shared<Function>(NodeVector{CV}, ParameterVector{A, B, C, D, E, F, G, H});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
......@@ -7653,6 +7655,8 @@ NGRAPH_TEST(${BACKEND_NAME}, quantized_conv_non_zero_zero_point)
result_scale,
result_zero_point,
element::u8,
AxisSet{},
AxisSet{},
AxisSet{});
auto f = make_shared<Function>(NodeVector{CV}, ParameterVector{A, B});
// Create some tensors for input/output
......
......@@ -3183,6 +3183,8 @@ TEST(cpu_quant_fusion, qconv_relu)
output_scale,
int8_zero,
element::i8,
AxisSet{},
AxisSet{},
AxisSet{});
auto dq = std::make_shared<op::Dequantize>(
conv, output_scale, int8_zero, element::f32, AxisSet{});
......