Commit 954da585 authored by nishant.b.patel's avatar nishant.b.patel

add QuantizedConvolution::validate_and_infer_types()

parent bdf082d9
......@@ -49,22 +49,25 @@ namespace ngraph
// TODO: Check for this later
// For Builders the zero point is assumed to be zero (for now)
auto zero_point = op::Constant::create(output_type, Shape{}, {0});
auto input_zero_point = op::Constant::create(input->get_element_type(), Shape{}, {0});
auto filter_zero_point =
op::Constant::create(filters->get_element_type(), Shape{}, {0});
return make_shared<op::QuantizedConvolution>(input,
filters,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
input_scale,
zero_point,
filter_scale,
zero_point,
output_scale,
zero_point,
output_type);
return make_shared<op::QuantizedConvolution>(
input,
filters,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
input_scale,
input_zero_point,
filter_scale,
filter_zero_point,
output_scale,
filter_zero_point, // output type will be same as filter
output_type);
}
}
}
......@@ -24,7 +24,7 @@
using namespace std;
using namespace ngraph;
op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& data_batch,
op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& input,
const shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
......@@ -39,7 +39,7 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& data_batc
const std::shared_ptr<Node>& output_zero_point,
const ngraph::element::Type& output_type)
: Op("QuantizedConvolution",
check_single_output_args({data_batch,
check_single_output_args({input,
filters,
input_scale,
input_zero_point,
......@@ -55,66 +55,82 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& data_batc
, m_output_type(output_type)
{
constructor_validate_and_infer_types();
auto& data_batch_shape = data_batch->get_shape();
auto& filters_shape = filters->get_shape();
set_output_type(0,
output_type,
util::infer_convolution_output_shape(this,
data_batch_shape,
filters_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
0, /* batch_axis_data, */
1, /* input_channel_axis_data, */
1, /* input_channel_axis_filters, */
0, /* output_channel_axis_filters, */
0, /* batch_axis_result, */
1 /* output_channel_axis_result, */
));
}
#if 0
void op::QuantizedConvolution::validate_and_infer_types()
{
const Shape& data_batch_shape = get_input_shape(0);
element::Type data_batch_et = get_input_element_type(0);
const Shape& filters_shape = get_input_shape(1);
element::Type filters_et = get_input_element_type(1);
enum
{
INPUT_SCALE = 2,
INPUT_ZERO_POINT,
FILTER_SCALE,
FILTER_ZERO_POINT,
OUTPUT_SCALE,
OUTPUT_ZERO_POINT
};
NODE_VALIDATION_CHECK(this,
get_input_element_type(INPUT_SCALE).is_real() ||
get_input_element_type(FILTER_SCALE).is_real() ||
get_input_element_type(OUTPUT_SCALE).is_real(),
"Scale must be a floating point number");
NODE_VALIDATION_CHECK(this,
get_input_element_type(0) == get_input_element_type(INPUT_ZERO_POINT),
"Input Zero point element type (",
get_input_element_type(INPUT_ZERO_POINT),
") must match Input element type (",
get_input_element_type(0),
")");
NODE_VALIDATION_CHECK(this,
get_input_element_type(1) == get_input_element_type(FILTER_ZERO_POINT),
"Filter Zero point element type (",
get_input_element_type(FILTER_ZERO_POINT),
") must match Input element type (",
get_input_element_type(1),
")");
NODE_VALIDATION_CHECK(this,
get_input_element_type(1) == get_input_element_type(OUTPUT_ZERO_POINT),
"Output Zero point element type (",
get_input_element_type(OUTPUT_ZERO_POINT),
") must match Input element type (",
get_input_element_type(1),
")");
auto input_shape = get_input_shape(0);
auto filters_shape = get_input_shape(1);
if (m_data_dilation_strides.size() == 0)
{
m_data_dilation_strides = conv_default_strides(this, data_batch_shape, filters_shape);
m_data_dilation_strides = conv_default_strides(this, input_shape, filters_shape);
}
if (m_window_movement_strides.size() == 0)
{
m_window_movement_strides = conv_default_strides(this, data_batch_shape, filters_shape);
m_window_movement_strides = conv_default_strides(this, input_shape, filters_shape);
}
if (m_window_dilation_strides.size() == 0)
{
m_window_dilation_strides = conv_default_strides(this, data_batch_shape, filters_shape);
m_window_dilation_strides = conv_default_strides(this, input_shape, filters_shape);
}
if (m_padding_below.size() == 0)
{
m_padding_below = conv_default_padding(this, data_batch_shape, filters_shape);
m_padding_below = conv_default_padding(this, input_shape, filters_shape);
}
if (m_padding_above.size() == 0)
{
m_padding_above = conv_default_padding(this, data_batch_shape, filters_shape);
m_padding_above = conv_default_padding(this, input_shape, filters_shape);
}
set_output_type(0,
set_output_type(0,
m_output_type,
util::infer_convolution_output_shape(this,
data_batch_shape,
input_shape,
filters_shape,
m_window_movement_strides,
m_window_dilation_strides,
......@@ -128,9 +144,7 @@ void op::QuantizedConvolution::validate_and_infer_types()
0, /* batch_axis_result, */
1 /* output_channel_axis_result, */
));
}
#endif
shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector& new_args) const
{
......
......@@ -43,7 +43,7 @@ namespace ngraph
/// \param output_zero_point Zero point used for mapping
/// \param output_type Output element type
///
QuantizedConvolution(const std::shared_ptr<Node>& data_batch,
QuantizedConvolution(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
......@@ -65,7 +65,7 @@ namespace ngraph
std::shared_ptr<Node> get_filters() { return get_argument(1); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); }
const ngraph::element::Type& get_output_type() const { return m_output_type; }
// void validate_and_infer_types() override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment