Commit e568de2e authored by nishant.b.patel's avatar nishant.b.patel

Change the API to take input_axes, filter_axes & output_axes

parent c75f7db3
...@@ -36,52 +36,6 @@ namespace ngraph ...@@ -36,52 +36,6 @@ namespace ngraph
{ {
namespace quantization namespace quantization
{ {
// TODO: this code falls back to fp32 convolution.
// We need to make this the primary builder, which means:
// 1) add support for zero point in the QuantizedConvolution op API
// 2) add a QuantizedConvolution reference kernel, including zero point
/// Builds a linear-quantized convolution as a three-stage sub-graph:
/// Dequantize(input), Dequantize(filter) -> Convolution -> Quantize.
///
/// The dequantized element type of each operand is taken from its own scale
/// node, and the quantized result type is taken from `output_zero_point`.
/// Strides, dilations and paddings are forwarded unchanged to op::Convolution.
///
/// Returns the final op::Quantize node (as a generic Node pointer).
shared_ptr<Node> QuantizedLinearConvolution(const shared_ptr<Node>& input,
                                            const shared_ptr<Node>& filter,
                                            const Strides& window_movement_strides,
                                            const Strides& window_dilation_strides,
                                            const CoordinateDiff& padding_below,
                                            const CoordinateDiff& padding_above,
                                            const Strides& data_dilation_strides,
                                            const shared_ptr<Node>& input_scale,
                                            const shared_ptr<Node>& input_zero_point,
                                            const shared_ptr<Node>& filter_scale,
                                            const shared_ptr<Node>& filter_zero_point,
                                            const shared_ptr<Node>& output_scale,
                                            const shared_ptr<Node>& output_zero_point)
{
    // One default-constructed axis set is shared by every (de)quantize node below.
    AxisSet shared_axes;

    // Stage 1: bring both operands back to the element type of their scales.
    const element::Type& input_real_type = input_scale->get_element_type();
    auto real_input = make_shared<op::Dequantize>(
        input, input_scale, input_zero_point, input_real_type, shared_axes);

    const element::Type& filter_real_type = filter_scale->get_element_type();
    auto real_filter = make_shared<op::Dequantize>(
        filter, filter_scale, filter_zero_point, filter_real_type, shared_axes);

    // Stage 2: ordinary (non-quantized) convolution on the dequantized operands.
    auto real_conv = make_shared<op::Convolution>(real_input,
                                                  real_filter,
                                                  window_movement_strides,
                                                  window_dilation_strides,
                                                  padding_below,
                                                  padding_above,
                                                  data_dilation_strides);

    // Stage 3: quantize the result; the output type follows the zero point's type.
    return make_shared<op::Quantize>(real_conv,
                                     output_scale,
                                     output_zero_point,
                                     output_zero_point->get_element_type(),
                                     shared_axes,
                                     op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
}
shared_ptr<Node> QuantizedLinearConvolutionBias(const shared_ptr<Node>& input, shared_ptr<Node> QuantizedLinearConvolutionBias(const shared_ptr<Node>& input,
const shared_ptr<Node>& filter, const shared_ptr<Node>& filter,
const shared_ptr<Node>& bias, const shared_ptr<Node>& bias,
......
...@@ -25,21 +25,6 @@ namespace ngraph ...@@ -25,21 +25,6 @@ namespace ngraph
{ {
namespace quantization namespace quantization
{ {
/// \brief Builds the sub-graph for a linear-quantized convolution.
///
/// The definition of this builder dequantizes `input` and `filter` with their
/// respective scale / zero point, runs op::Convolution on the dequantized
/// values, and quantizes the convolution result with `output_scale` /
/// `output_zero_point` using ROUND_NEAREST_TOWARD_EVEN.
///
/// \param input                    Quantized data batch node.
/// \param filter                   Quantized filter node.
/// \param window_movement_strides  Forwarded unchanged to op::Convolution.
/// \param window_dilation_strides  Forwarded unchanged to op::Convolution.
/// \param padding_below            Forwarded unchanged to op::Convolution.
/// \param padding_above            Forwarded unchanged to op::Convolution.
/// \param data_dilation_strides    Forwarded unchanged to op::Convolution.
/// \param input_scale              Scale for `input`; its element type is the dequantized input type.
/// \param input_zero_point         Zero point for `input`.
/// \param filter_scale             Scale for `filter`; its element type is the dequantized filter type.
/// \param filter_zero_point        Zero point for `filter`.
/// \param output_scale             Scale used to quantize the convolution result.
/// \param output_zero_point        Zero point for the result; its element type is the quantized output type.
/// \return The final op::Quantize node.
std::shared_ptr<Node>
QuantizedLinearConvolution(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& filter,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node>& input_scale,
const std::shared_ptr<Node>& input_zero_point,
const std::shared_ptr<Node>& filter_scale,
const std::shared_ptr<Node>& filter_zero_point,
const std::shared_ptr<Node>& output_scale,
const std::shared_ptr<Node>& output_zero_point);
std::shared_ptr<Node> std::shared_ptr<Node>
QuantizedLinearConvolutionBias(const std::shared_ptr<Node>& input, QuantizedLinearConvolutionBias(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& filter, const std::shared_ptr<Node>& filter,
......
...@@ -40,7 +40,9 @@ namespace ngraph ...@@ -40,7 +40,9 @@ namespace ngraph
const shared_ptr<Node>& min_output, const shared_ptr<Node>& min_output,
const shared_ptr<Node>& max_output, const shared_ptr<Node>& max_output,
const ngraph::element::Type& output_type, const ngraph::element::Type& output_type,
const ngraph::AxisSet& axes) const ngraph::AxisSet& input_axes,
const ngraph::AxisSet& filter_axes,
const ngraph::AxisSet& output_axes)
{ {
auto input_scale = auto input_scale =
quantization_scale::get_scale(min_input, max_input, input->get_element_type()); quantization_scale::get_scale(min_input, max_input, input->get_element_type());
...@@ -69,7 +71,9 @@ namespace ngraph ...@@ -69,7 +71,9 @@ namespace ngraph
output_scale, output_scale,
filter_zero_point, // output type will be same as filter filter_zero_point, // output type will be same as filter
output_type, output_type,
axes); input_axes,
filter_axes,
output_axes);
} }
} }
} }
...@@ -39,6 +39,8 @@ namespace ngraph ...@@ -39,6 +39,8 @@ namespace ngraph
const std::shared_ptr<Node>& min_output, const std::shared_ptr<Node>& min_output,
const std::shared_ptr<Node>& max_output, const std::shared_ptr<Node>& max_output,
const ngraph::element::Type& output_type, const ngraph::element::Type& output_type,
const ngraph::AxisSet& axes); const ngraph::AxisSet& input_axes,
const ngraph::AxisSet& filter_axes,
const ngraph::AxisSet& output_axes);
} }
} }
...@@ -38,7 +38,9 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& input, ...@@ -38,7 +38,9 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& input,
const std::shared_ptr<Node>& output_scale, const std::shared_ptr<Node>& output_scale,
const std::shared_ptr<Node>& output_zero_point, const std::shared_ptr<Node>& output_zero_point,
const ngraph::element::Type& output_type, const ngraph::element::Type& output_type,
const ngraph::AxisSet& axes) const ngraph::AxisSet& input_axes,
const ngraph::AxisSet& filter_axes,
const ngraph::AxisSet& output_axes)
: Op("QuantizedConvolution", : Op("QuantizedConvolution",
check_single_output_args({input, check_single_output_args({input,
filters, filters,
...@@ -54,7 +56,9 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& input, ...@@ -54,7 +56,9 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& input,
, m_padding_above(padding_above) , m_padding_above(padding_above)
, m_data_dilation_strides(data_dilation_strides) , m_data_dilation_strides(data_dilation_strides)
, m_output_type(output_type) , m_output_type(output_type)
, m_axes(axes) , m_input_axes(input_axes)
, m_filter_axes(filter_axes)
, m_output_axes(output_axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -165,5 +169,7 @@ shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector& ...@@ -165,5 +169,7 @@ shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector&
new_args.at(6), new_args.at(6),
new_args.at(7), new_args.at(7),
m_output_type, m_output_type,
m_axes)); m_input_axes,
m_filter_axes,
m_output_axes));
} }
...@@ -57,7 +57,9 @@ namespace ngraph ...@@ -57,7 +57,9 @@ namespace ngraph
const std::shared_ptr<Node>& output_scale, const std::shared_ptr<Node>& output_scale,
const std::shared_ptr<Node>& output_zero_point, const std::shared_ptr<Node>& output_zero_point,
const ngraph::element::Type& output_type, const ngraph::element::Type& output_type,
const ngraph::AxisSet& axes); const ngraph::AxisSet& input_axes,
const ngraph::AxisSet& filter_axes,
const ngraph::AxisSet& output_axes);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; } const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
...@@ -66,7 +68,9 @@ namespace ngraph ...@@ -66,7 +68,9 @@ namespace ngraph
std::shared_ptr<Node> get_filters() { return get_argument(1); } std::shared_ptr<Node> get_filters() { return get_argument(1); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); } std::shared_ptr<Node> get_data_batch() { return get_argument(0); }
const ngraph::element::Type& get_output_type() const { return m_output_type; } const ngraph::element::Type& get_output_type() const { return m_output_type; }
const ngraph::AxisSet& get_axes() const { return m_axes; } const ngraph::AxisSet& get_input_axes() const { return m_input_axes; }
const ngraph::AxisSet& get_filter_axes() const { return m_filter_axes; }
const ngraph::AxisSet& get_output_axes() const { return m_output_axes; }
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -78,7 +82,9 @@ namespace ngraph ...@@ -78,7 +82,9 @@ namespace ngraph
CoordinateDiff m_padding_above; CoordinateDiff m_padding_above;
Strides m_data_dilation_strides; Strides m_data_dilation_strides;
ngraph::element::Type m_output_type; ngraph::element::Type m_output_type;
ngraph::AxisSet m_axes; ngraph::AxisSet m_input_axes;
ngraph::AxisSet m_filter_axes;
ngraph::AxisSet m_output_axes;
}; };
} }
} }
...@@ -1823,6 +1823,8 @@ void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconv_relu(bool with_ ...@@ -1823,6 +1823,8 @@ void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconv_relu(bool with_
output_scale, output_scale,
int8_zero, int8_zero,
element::i8, element::i8,
AxisSet{},
AxisSet{},
AxisSet{}); AxisSet{});
} }
auto dq = auto dq =
......
...@@ -1350,7 +1350,9 @@ static shared_ptr<ngraph::Function> ...@@ -1350,7 +1350,9 @@ static shared_ptr<ngraph::Function>
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>(); auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides = node_js["data_dilation_strides"]; auto data_dilation_strides = node_js["data_dilation_strides"];
auto output_type = read_element_type(node_js.at("output_type")); auto output_type = read_element_type(node_js.at("output_type"));
auto axes = node_js.at("axes").get<set<size_t>>(); auto input_axes = node_js.at("input_axes").get<set<size_t>>();
auto filter_axes = node_js.at("filter_axes").get<set<size_t>>();
auto output_axes = node_js.at("output_axes").get<set<size_t>>();
node = make_shared<op::QuantizedConvolution>( node = make_shared<op::QuantizedConvolution>(
args[0], args[0],
args[1], args[1],
...@@ -1366,7 +1368,9 @@ static shared_ptr<ngraph::Function> ...@@ -1366,7 +1368,9 @@ static shared_ptr<ngraph::Function>
args[6], args[6],
args[7], args[7],
output_type, output_type,
axes); input_axes,
filter_axes,
output_axes);
break; break;
} }
case OP_TYPEID::QuantizedDotBias: { break; case OP_TYPEID::QuantizedDotBias: { break;
...@@ -2298,7 +2302,9 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -2298,7 +2302,9 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
node["data_dilation_strides"] = tmp->get_data_dilation_strides(); node["data_dilation_strides"] = tmp->get_data_dilation_strides();
node["output_type"] = write_element_type(tmp->get_element_type()); node["output_type"] = write_element_type(tmp->get_element_type());
node["axes"] = tmp->get_axes(); node["input_axes"] = tmp->get_input_axes();
node["filter_axes"] = tmp->get_filter_axes();
node["output_axes"] = tmp->get_output_axes();
break; break;
} }
case OP_TYPEID::QuantizedDotBias: { break; case OP_TYPEID::QuantizedDotBias: { break;
......
...@@ -7520,6 +7520,8 @@ NGRAPH_TEST(${BACKEND_NAME}, quantized_convolution) ...@@ -7520,6 +7520,8 @@ NGRAPH_TEST(${BACKEND_NAME}, quantized_convolution)
G, G,
H, H,
element::i8, element::i8,
AxisSet{},
AxisSet{},
AxisSet{}); AxisSet{});
auto f = make_shared<Function>(NodeVector{CV}, ParameterVector{A, B, C, D, E, F, G, H}); auto f = make_shared<Function>(NodeVector{CV}, ParameterVector{A, B, C, D, E, F, G, H});
auto backend = runtime::Backend::create("${BACKEND_NAME}"); auto backend = runtime::Backend::create("${BACKEND_NAME}");
...@@ -7653,6 +7655,8 @@ NGRAPH_TEST(${BACKEND_NAME}, quantized_conv_non_zero_zero_point) ...@@ -7653,6 +7655,8 @@ NGRAPH_TEST(${BACKEND_NAME}, quantized_conv_non_zero_zero_point)
result_scale, result_scale,
result_zero_point, result_zero_point,
element::u8, element::u8,
AxisSet{},
AxisSet{},
AxisSet{}); AxisSet{});
auto f = make_shared<Function>(NodeVector{CV}, ParameterVector{A, B}); auto f = make_shared<Function>(NodeVector{CV}, ParameterVector{A, B});
// Create some tensors for input/output // Create some tensors for input/output
......
This diff is collapsed.
...@@ -3183,6 +3183,8 @@ TEST(cpu_quant_fusion, qconv_relu) ...@@ -3183,6 +3183,8 @@ TEST(cpu_quant_fusion, qconv_relu)
output_scale, output_scale,
int8_zero, int8_zero,
element::i8, element::i8,
AxisSet{},
AxisSet{},
AxisSet{}); AxisSet{});
auto dq = std::make_shared<op::Dequantize>( auto dq = std::make_shared<op::Dequantize>(
conv, output_scale, int8_zero, element::f32, AxisSet{}); conv, output_scale, int8_zero, element::f32, AxisSet{});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment