Commit 954da585 authored by nishant.b.patel's avatar nishant.b.patel

add QuantizedConvolution::validate_and_infer_types()

parent bdf082d9
...@@ -49,9 +49,12 @@ namespace ngraph ...@@ -49,9 +49,12 @@ namespace ngraph
// TODO: Check for this later // TODO: Check for this later
// For Builders the zero point is assumed to be zero (for now) // For Builders the zero point is assumed to be zero (for now)
auto zero_point = op::Constant::create(output_type, Shape{}, {0}); auto input_zero_point = op::Constant::create(input->get_element_type(), Shape{}, {0});
auto filter_zero_point =
op::Constant::create(filters->get_element_type(), Shape{}, {0});
return make_shared<op::QuantizedConvolution>(input, return make_shared<op::QuantizedConvolution>(
input,
filters, filters,
window_movement_strides, window_movement_strides,
window_dilation_strides, window_dilation_strides,
...@@ -59,11 +62,11 @@ namespace ngraph ...@@ -59,11 +62,11 @@ namespace ngraph
padding_above, padding_above,
data_dilation_strides, data_dilation_strides,
input_scale, input_scale,
zero_point, input_zero_point,
filter_scale, filter_scale,
zero_point, filter_zero_point,
output_scale, output_scale,
zero_point, filter_zero_point, // output type will be same as filter
output_type); output_type);
} }
} }
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& data_batch, op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& input,
const shared_ptr<Node>& filters, const shared_ptr<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
...@@ -39,7 +39,7 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& data_batc ...@@ -39,7 +39,7 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& data_batc
const std::shared_ptr<Node>& output_zero_point, const std::shared_ptr<Node>& output_zero_point,
const ngraph::element::Type& output_type) const ngraph::element::Type& output_type)
: Op("QuantizedConvolution", : Op("QuantizedConvolution",
check_single_output_args({data_batch, check_single_output_args({input,
filters, filters,
input_scale, input_scale,
input_zero_point, input_zero_point,
...@@ -55,66 +55,82 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& data_batc ...@@ -55,66 +55,82 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& data_batc
, m_output_type(output_type) , m_output_type(output_type)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
auto& data_batch_shape = data_batch->get_shape();
auto& filters_shape = filters->get_shape();
set_output_type(0,
output_type,
util::infer_convolution_output_shape(this,
data_batch_shape,
filters_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
0, /* batch_axis_data, */
1, /* input_channel_axis_data, */
1, /* input_channel_axis_filters, */
0, /* output_channel_axis_filters, */
0, /* batch_axis_result, */
1 /* output_channel_axis_result, */
));
} }
#if 0
void op::QuantizedConvolution::validate_and_infer_types() void op::QuantizedConvolution::validate_and_infer_types()
{ {
const Shape& data_batch_shape = get_input_shape(0); enum
element::Type data_batch_et = get_input_element_type(0); {
const Shape& filters_shape = get_input_shape(1); INPUT_SCALE = 2,
element::Type filters_et = get_input_element_type(1); INPUT_ZERO_POINT,
FILTER_SCALE,
FILTER_ZERO_POINT,
OUTPUT_SCALE,
OUTPUT_ZERO_POINT
};
NODE_VALIDATION_CHECK(this,
get_input_element_type(INPUT_SCALE).is_real() ||
get_input_element_type(FILTER_SCALE).is_real() ||
get_input_element_type(OUTPUT_SCALE).is_real(),
"Scale must be a floating point number");
NODE_VALIDATION_CHECK(this,
get_input_element_type(0) == get_input_element_type(INPUT_ZERO_POINT),
"Input Zero point element type (",
get_input_element_type(INPUT_ZERO_POINT),
") must match Input element type (",
get_input_element_type(0),
")");
NODE_VALIDATION_CHECK(this,
get_input_element_type(1) == get_input_element_type(FILTER_ZERO_POINT),
"Filter Zero point element type (",
get_input_element_type(FILTER_ZERO_POINT),
") must match Input element type (",
get_input_element_type(1),
")");
NODE_VALIDATION_CHECK(this,
get_input_element_type(1) == get_input_element_type(OUTPUT_ZERO_POINT),
"Output Zero point element type (",
get_input_element_type(OUTPUT_ZERO_POINT),
") must match Input element type (",
get_input_element_type(1),
")");
auto input_shape = get_input_shape(0);
auto filters_shape = get_input_shape(1);
if (m_data_dilation_strides.size() == 0) if (m_data_dilation_strides.size() == 0)
{ {
m_data_dilation_strides = conv_default_strides(this, data_batch_shape, filters_shape); m_data_dilation_strides = conv_default_strides(this, input_shape, filters_shape);
} }
if (m_window_movement_strides.size() == 0) if (m_window_movement_strides.size() == 0)
{ {
m_window_movement_strides = conv_default_strides(this, data_batch_shape, filters_shape); m_window_movement_strides = conv_default_strides(this, input_shape, filters_shape);
} }
if (m_window_dilation_strides.size() == 0) if (m_window_dilation_strides.size() == 0)
{ {
m_window_dilation_strides = conv_default_strides(this, data_batch_shape, filters_shape); m_window_dilation_strides = conv_default_strides(this, input_shape, filters_shape);
} }
if (m_padding_below.size() == 0) if (m_padding_below.size() == 0)
{ {
m_padding_below = conv_default_padding(this, data_batch_shape, filters_shape); m_padding_below = conv_default_padding(this, input_shape, filters_shape);
} }
if (m_padding_above.size() == 0) if (m_padding_above.size() == 0)
{ {
m_padding_above = conv_default_padding(this, data_batch_shape, filters_shape); m_padding_above = conv_default_padding(this, input_shape, filters_shape);
} }
set_output_type(0, set_output_type(0,
m_output_type, m_output_type,
util::infer_convolution_output_shape(this, util::infer_convolution_output_shape(this,
data_batch_shape, input_shape,
filters_shape, filters_shape,
m_window_movement_strides, m_window_movement_strides,
m_window_dilation_strides, m_window_dilation_strides,
...@@ -128,9 +144,7 @@ void op::QuantizedConvolution::validate_and_infer_types() ...@@ -128,9 +144,7 @@ void op::QuantizedConvolution::validate_and_infer_types()
0, /* batch_axis_result, */ 0, /* batch_axis_result, */
1 /* output_channel_axis_result, */ 1 /* output_channel_axis_result, */
)); ));
} }
#endif
shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector& new_args) const
{ {
......
...@@ -43,7 +43,7 @@ namespace ngraph ...@@ -43,7 +43,7 @@ namespace ngraph
/// \param output_zero_point Zero point used for mapping /// \param output_zero_point Zero point used for mapping
/// \param output_type Output element type /// \param output_type Output element type
/// ///
QuantizedConvolution(const std::shared_ptr<Node>& data_batch, QuantizedConvolution(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& filters, const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
...@@ -65,7 +65,7 @@ namespace ngraph ...@@ -65,7 +65,7 @@ namespace ngraph
std::shared_ptr<Node> get_filters() { return get_argument(1); } std::shared_ptr<Node> get_filters() { return get_argument(1); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); } std::shared_ptr<Node> get_data_batch() { return get_argument(0); }
const ngraph::element::Type& get_output_type() const { return m_output_type; } const ngraph::element::Type& get_output_type() const { return m_output_type; }
// void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment