Commit 47be91bd authored by nishant.b.patel's avatar nishant.b.patel

Change infer_convolution_forward method to just do shape checks

parent 351917d5
...@@ -100,9 +100,16 @@ void op::Convolution::validate_and_infer_types() ...@@ -100,9 +100,16 @@ void op::Convolution::validate_and_infer_types()
element::Type result_et; element::Type result_et;
PartialShape result_shape; PartialShape result_shape;
std::tie(result_et, result_shape) = infer_convolution_forward(this, NODE_VALIDATION_CHECK(
this,
element::Type::merge(result_et, data_batch_et, filters_et),
"Element types for data batch and filters do not match (data batch element type: ",
data_batch_et, data_batch_et,
", filters element type: ",
filters_et, filters_et,
").");
result_shape = infer_convolution_forward(this,
data_batch_shape, data_batch_shape,
m_data_dilation_strides, m_data_dilation_strides,
m_padding_below, m_padding_below,
...@@ -255,10 +262,16 @@ void op::ConvolutionBackpropData::validate_and_infer_types() ...@@ -255,10 +262,16 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
element::Type forward_result_et; element::Type forward_result_et;
PartialShape forward_result_shape; PartialShape forward_result_shape;
std::tie(forward_result_et, forward_result_shape) = NODE_VALIDATION_CHECK(
infer_convolution_forward(this, this,
element::Type::merge(forward_result_et, delta_et, filters_et),
"Element types for data batch and filters do not match (data batch element type: ",
delta_et, delta_et,
", filters element type: ",
filters_et, filters_et,
").");
forward_result_shape = infer_convolution_forward(this,
m_data_batch_shape, m_data_batch_shape,
m_data_dilation_strides_forward, m_data_dilation_strides_forward,
m_padding_below_forward, m_padding_below_forward,
...@@ -481,10 +494,16 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types() ...@@ -481,10 +494,16 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
element::Type forward_result_et; element::Type forward_result_et;
PartialShape forward_result_shape; PartialShape forward_result_shape;
std::tie(forward_result_et, forward_result_shape) = NODE_VALIDATION_CHECK(
infer_convolution_forward(this, this,
element::Type::merge(forward_result_et, data_batch_et, delta_et),
"Element types for data batch and filters do not match (data batch element type: ",
data_batch_et, data_batch_et,
", filters element type: ",
delta_et, delta_et,
").");
forward_result_shape = infer_convolution_forward(this,
data_batch_shape, data_batch_shape,
m_data_dilation_strides_forward, m_data_dilation_strides_forward,
m_padding_below_forward, m_padding_below_forward,
......
...@@ -159,9 +159,16 @@ void op::ConvolutionBias::validate_and_infer_types() ...@@ -159,9 +159,16 @@ void op::ConvolutionBias::validate_and_infer_types()
element::Type result_et; element::Type result_et;
PartialShape result_shape; PartialShape result_shape;
std::tie(result_et, result_shape) = infer_convolution_forward(this, NODE_VALIDATION_CHECK(
this,
element::Type::merge(result_et, data_batch_et, filters_et),
"Element types for data batch and filters do not match (data batch element type: ",
data_batch_et, data_batch_et,
", filters element type: ",
filters_et, filters_et,
").");
result_shape = infer_convolution_forward(this,
data_batch_shape, data_batch_shape,
m_data_dilation_strides, m_data_dilation_strides,
m_padding_below, m_padding_below,
...@@ -407,9 +414,16 @@ void op::ConvolutionBiasAdd::validate_and_infer_types() ...@@ -407,9 +414,16 @@ void op::ConvolutionBiasAdd::validate_and_infer_types()
element::Type result_et; element::Type result_et;
PartialShape result_shape; PartialShape result_shape;
std::tie(result_et, result_shape) = infer_convolution_forward(this, NODE_VALIDATION_CHECK(
this,
element::Type::merge(result_et, data_batch_et, filters_et),
"Element types for data batch and filters do not match (data batch element type: ",
data_batch_et, data_batch_et,
", filters element type: ",
filters_et, filters_et,
").");
result_shape = infer_convolution_forward(this,
data_batch_shape, data_batch_shape,
m_data_dilation_strides, m_data_dilation_strides,
m_padding_below, m_padding_below,
......
...@@ -118,8 +118,11 @@ void op::QuantizedConvolution::validate_and_infer_types() ...@@ -118,8 +118,11 @@ void op::QuantizedConvolution::validate_and_infer_types()
shape_size(get_input_shape(7)) == 1, shape_size(get_input_shape(7)) == 1,
"Output scale and output zero point shape must be same and 1"); "Output scale and output zero point shape must be same and 1");
auto input_shape = get_input_shape(0); // auto input_shape = get_input_shape(0);
auto filters_shape = get_input_shape(1); // auto filters_shape = get_input_shape(1);
const PartialShape& input_shape = get_input_partial_shape(0);
const PartialShape& filters_shape = get_input_partial_shape(1);
if (m_data_dilation_strides.size() == 0) if (m_data_dilation_strides.size() == 0)
{ {
...@@ -146,23 +149,16 @@ void op::QuantizedConvolution::validate_and_infer_types() ...@@ -146,23 +149,16 @@ void op::QuantizedConvolution::validate_and_infer_types()
m_padding_above = conv_default_padding(this, input_shape, filters_shape); m_padding_above = conv_default_padding(this, input_shape, filters_shape);
} }
set_output_type(0, PartialShape result_shape;
m_output_type,
util::infer_convolution_output_shape(this, result_shape = infer_convolution_forward(this,
input_shape, input_shape,
filters_shape, m_data_dilation_strides,
m_window_movement_strides,
m_window_dilation_strides,
m_padding_below, m_padding_below,
m_padding_above, m_padding_above,
m_data_dilation_strides, filters_shape,
0, /* batch_axis_data, */ m_window_movement_strides,
1, /* input_channel_axis_data, */ m_window_dilation_strides);
1, /* input_channel_axis_filters, */
0, /* output_channel_axis_filters, */
0, /* batch_axis_result, */
1 /* output_channel_axis_result, */
));
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, this,
...@@ -172,6 +168,8 @@ void op::QuantizedConvolution::validate_and_infer_types() ...@@ -172,6 +168,8 @@ void op::QuantizedConvolution::validate_and_infer_types()
") must match output element type (", ") must match output element type (",
get_output_element_type(0), get_output_element_type(0),
")"); ")");
set_output_type(0, m_output_type, result_shape);
} }
shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -96,10 +96,17 @@ void op::DeconvolutionBias::validate_and_infer_types() ...@@ -96,10 +96,17 @@ void op::DeconvolutionBias::validate_and_infer_types()
const PartialShape& fwd_filters_shape{ const PartialShape& fwd_filters_shape{
filters_shape[1], filters_shape[0], filters_shape[2], filters_shape[3]}; filters_shape[1], filters_shape[0], filters_shape[2], filters_shape[3]};
std::tie(forward_result_et, forward_result_shape) =
infer_convolution_forward(this, NODE_VALIDATION_CHECK(
this,
element::Type::merge(forward_result_et, delta_et, filters_et),
"Element types for data batch and filters do not match (data batch element type: ",
delta_et, delta_et,
", filters element type: ",
filters_et, filters_et,
").");
forward_result_shape = infer_convolution_forward(this,
m_data_batch_shape, m_data_batch_shape,
m_data_dilation_strides_forward, m_data_dilation_strides_forward,
m_padding_below_forward, m_padding_below_forward,
......
...@@ -211,10 +211,7 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node, ...@@ -211,10 +211,7 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
// //
// Infers the output batch shape and element type for convolution fprop. // Infers the output batch shape and element type for convolution fprop.
// //
std::tuple<element::Type, PartialShape> PartialShape ngraph::infer_convolution_forward(const Node* node,
ngraph::infer_convolution_forward(const Node* node,
element::Type et_batch,
element::Type et_filters,
const PartialShape& data_batch_shape, const PartialShape& data_batch_shape,
const Strides& data_dilation, const Strides& data_dilation,
const CoordinateDiff& data_padding_below, const CoordinateDiff& data_padding_below,
...@@ -223,17 +220,6 @@ std::tuple<element::Type, PartialShape> ...@@ -223,17 +220,6 @@ std::tuple<element::Type, PartialShape>
const Strides& filter_strides, const Strides& filter_strides,
const Strides& filter_dilation) const Strides& filter_dilation)
{ {
element::Type et_result;
NODE_VALIDATION_CHECK(
node,
element::Type::merge(et_result, et_batch, et_filters),
"Element types for data batch and filters do not match (data batch element type: ",
et_batch,
", filters element type: ",
et_filters,
").");
Rank data_batch_filters_rank{Rank::dynamic()}; Rank data_batch_filters_rank{Rank::dynamic()};
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
...@@ -358,7 +344,7 @@ std::tuple<element::Type, PartialShape> ...@@ -358,7 +344,7 @@ std::tuple<element::Type, PartialShape>
batch_output_shape[i + 2] = data_output_shape[i]; batch_output_shape[i + 2] = data_output_shape[i];
} }
return std::make_tuple(et_result, batch_output_shape); return batch_output_shape;
} }
// //
......
...@@ -43,10 +43,7 @@ namespace ngraph ...@@ -43,10 +43,7 @@ namespace ngraph
const Strides& window_dilation, const Strides& window_dilation,
bool is_window_all_in_padding_allowed); bool is_window_all_in_padding_allowed);
std::tuple<element::Type, PartialShape> PartialShape infer_convolution_forward(const Node* node,
infer_convolution_forward(const Node* node,
element::Type et_batch,
element::Type et_filters,
const PartialShape& data_batch_shape, const PartialShape& data_batch_shape,
const Strides& data_dilation, const Strides& data_dilation,
const CoordinateDiff& data_padding_below, const CoordinateDiff& data_padding_below,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment