Commit c8288d48 authored by nishant.b.patel

Make input, filter and output axes optional for QuantizedConv op

parent ada5c18f
......@@ -25,22 +25,23 @@ namespace ngraph
{
namespace builder
{
// Declaration of the quantized-convolution builder in ngraph::builder.
// Returns a std::shared_ptr<Node>.
// In this form every argument, including the three trailing AxisSet
// arguments, must be supplied by the caller (no default values).
// NOTE(review): this looks like the pre-change side of the diff hunk; a second
// declaration with defaulted axis arguments follows it - confirm which one the
// header actually contains.
std::shared_ptr<Node> QuantizedConvolutionBuilder(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node>& min_input,
const std::shared_ptr<Node>& max_input,
const std::shared_ptr<Node>& min_filter,
const std::shared_ptr<Node>& max_filter,
const std::shared_ptr<Node>& min_output,
const std::shared_ptr<Node>& max_output,
const ngraph::element::Type& output_type,
const ngraph::AxisSet& input_axes,
const ngraph::AxisSet& filter_axes,
const ngraph::AxisSet& output_axes);
// Declaration of the quantized-convolution builder in ngraph::builder.
// Returns a std::shared_ptr<Node>.
//
// Arguments (grouped by their declared names and types):
//   input, filters                      - nodes passed as std::shared_ptr<Node>
//   window_movement_strides,
//   window_dilation_strides,
//   data_dilation_strides               - Strides
//   padding_below, padding_above        - CoordinateDiff
//   min_input/max_input,
//   min_filter/max_filter,
//   min_output/max_output               - nodes; presumably the quantization
//                                         ranges for each tensor (TODO confirm)
//   output_type                         - element type requested for the result
//   input_axes, filter_axes, output_axes
//                                       - optional; each defaults to an empty
//                                         ngraph::AxisSet, so callers may stop
//                                         after output_type
//
// NOTE(review): op::QuantizedConvolution::validate_and_infer_types checks that
// its axis sets are empty; presumably non-empty values passed here would be
// rejected if they are forwarded to that op - confirm against the definition.
std::shared_ptr<Node>
QuantizedConvolutionBuilder(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node>& min_input,
const std::shared_ptr<Node>& max_input,
const std::shared_ptr<Node>& min_filter,
const std::shared_ptr<Node>& max_filter,
const std::shared_ptr<Node>& min_output,
const std::shared_ptr<Node>& max_output,
const ngraph::element::Type& output_type,
const ngraph::AxisSet& input_axes = ngraph::AxisSet{},
const ngraph::AxisSet& filter_axes = ngraph::AxisSet{},
const ngraph::AxisSet& output_axes = ngraph::AxisSet{});
}
}
......@@ -141,6 +141,12 @@ void op::QuantizedConvolution::validate_and_infer_types()
get_input_partial_shape(7).compatible(PartialShape{}),
"Output scale and output zero point shape must be same and 1");
// The axis sets must stay empty until channel-wise quantization is supported
NODE_VALIDATION_CHECK(this,
m_input_axes == AxisSet{} && m_filter_axes == AxisSet{} &&
m_output_axes == AxisSet{},
"Input, filter and output AxisSet should be empty");
const PartialShape& input_shape = get_input_partial_shape(0);
const PartialShape& filters_shape = get_input_partial_shape(1);
......
......@@ -59,9 +59,9 @@ namespace ngraph
const std::shared_ptr<Node>& output_scale,
const std::shared_ptr<Node>& output_zero_point,
const ngraph::element::Type& output_type,
const ngraph::AxisSet& input_axes,
const ngraph::AxisSet& filter_axes,
const ngraph::AxisSet& output_axes);
const ngraph::AxisSet& input_axes = ngraph::AxisSet{},
const ngraph::AxisSet& filter_axes = ngraph::AxisSet{},
const ngraph::AxisSet& output_axes = ngraph::AxisSet{});
// Read-only accessor: hands back the stored window-movement strides by
// const reference (no copy is made).
const Strides& get_window_movement_strides() const
{
    return m_window_movement_strides;
}
// Read-only accessor: hands back the stored window-dilation strides by
// const reference (no copy is made).
const Strides& get_window_dilation_strides() const
{
    return m_window_dilation_strides;
}
......
......@@ -7557,10 +7557,7 @@ NGRAPH_TEST(${BACKEND_NAME}, quantized_convolution)
F,
G,
H,
element::i8,
AxisSet{},
AxisSet{},
AxisSet{});
element::i8);
auto f = make_shared<Function>(NodeVector{CV}, ParameterVector{A, B, C, D, E, F, G, H});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
......@@ -7692,10 +7689,7 @@ NGRAPH_TEST(${BACKEND_NAME}, quantized_conv_non_zero_zero_point)
rhs_zero_point,
result_scale,
result_zero_point,
element::u8,
AxisSet{},
AxisSet{},
AxisSet{});
element::u8);
auto f = make_shared<Function>(NodeVector{CV}, ParameterVector{A, B});
// Create some tensors for input/output
auto a = backend->create_tensor(element::u8, shape_a);
......@@ -7713,7 +7707,7 @@ NGRAPH_TEST(${BACKEND_NAME}, quantized_conv_non_zero_zero_point)
}
}
TEST(${BACKEND_NAME}, quantized_conv_int32_output)
NGRAPH_TEST(${BACKEND_NAME}, quantized_conv_int32_output)
{
Shape shape_a{1, 1, 3, 4};
Shape shape_b{1, 1, 3, 3};
......@@ -7741,10 +7735,7 @@ TEST(${BACKEND_NAME}, quantized_conv_int32_output)
F,
G,
H,
element::i32,
AxisSet{},
AxisSet{},
AxisSet{});
element::i32);
auto f = make_shared<Function>(NodeVector{CV}, ParameterVector{A, B, C, E, G});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment