Commit 47be91bd authored by nishant.b.patel's avatar nishant.b.patel

Change infer_convolution_forward method to just do shape checks

parent 351917d5
...@@ -100,16 +100,23 @@ void op::Convolution::validate_and_infer_types() ...@@ -100,16 +100,23 @@ void op::Convolution::validate_and_infer_types()
element::Type result_et; element::Type result_et;
PartialShape result_shape; PartialShape result_shape;
std::tie(result_et, result_shape) = infer_convolution_forward(this, NODE_VALIDATION_CHECK(
data_batch_et, this,
filters_et, element::Type::merge(result_et, data_batch_et, filters_et),
data_batch_shape, "Element types for data batch and filters do not match (data batch element type: ",
m_data_dilation_strides, data_batch_et,
m_padding_below, ", filters element type: ",
m_padding_above, filters_et,
filters_shape, ").");
m_window_movement_strides,
m_window_dilation_strides); result_shape = infer_convolution_forward(this,
data_batch_shape,
m_data_dilation_strides,
m_padding_below,
m_padding_above,
filters_shape,
m_window_movement_strides,
m_window_dilation_strides);
set_output_type(0, result_et, result_shape); set_output_type(0, result_et, result_shape);
} }
...@@ -255,17 +262,23 @@ void op::ConvolutionBackpropData::validate_and_infer_types() ...@@ -255,17 +262,23 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
element::Type forward_result_et; element::Type forward_result_et;
PartialShape forward_result_shape; PartialShape forward_result_shape;
std::tie(forward_result_et, forward_result_shape) = NODE_VALIDATION_CHECK(
infer_convolution_forward(this, this,
delta_et, element::Type::merge(forward_result_et, delta_et, filters_et),
filters_et, "Element types for data batch and filters do not match (data batch element type: ",
m_data_batch_shape, delta_et,
m_data_dilation_strides_forward, ", filters element type: ",
m_padding_below_forward, filters_et,
m_padding_above_forward, ").");
filters_shape,
m_window_movement_strides_forward, forward_result_shape = infer_convolution_forward(this,
m_window_dilation_strides_forward); m_data_batch_shape,
m_data_dilation_strides_forward,
m_padding_below_forward,
m_padding_above_forward,
filters_shape,
m_window_movement_strides_forward,
m_window_dilation_strides_forward);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
forward_result_shape.compatible(delta_shape), forward_result_shape.compatible(delta_shape),
...@@ -481,17 +494,23 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types() ...@@ -481,17 +494,23 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
element::Type forward_result_et; element::Type forward_result_et;
PartialShape forward_result_shape; PartialShape forward_result_shape;
std::tie(forward_result_et, forward_result_shape) = NODE_VALIDATION_CHECK(
infer_convolution_forward(this, this,
data_batch_et, element::Type::merge(forward_result_et, data_batch_et, delta_et),
delta_et, "Element types for data batch and filters do not match (data batch element type: ",
data_batch_shape, data_batch_et,
m_data_dilation_strides_forward, ", filters element type: ",
m_padding_below_forward, delta_et,
m_padding_above_forward, ").");
m_filters_shape,
m_window_movement_strides_forward, forward_result_shape = infer_convolution_forward(this,
m_window_dilation_strides_forward); data_batch_shape,
m_data_dilation_strides_forward,
m_padding_below_forward,
m_padding_above_forward,
m_filters_shape,
m_window_movement_strides_forward,
m_window_dilation_strides_forward);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
forward_result_shape.compatible(delta_shape), forward_result_shape.compatible(delta_shape),
......
...@@ -159,16 +159,23 @@ void op::ConvolutionBias::validate_and_infer_types() ...@@ -159,16 +159,23 @@ void op::ConvolutionBias::validate_and_infer_types()
element::Type result_et; element::Type result_et;
PartialShape result_shape; PartialShape result_shape;
std::tie(result_et, result_shape) = infer_convolution_forward(this, NODE_VALIDATION_CHECK(
data_batch_et, this,
filters_et, element::Type::merge(result_et, data_batch_et, filters_et),
data_batch_shape, "Element types for data batch and filters do not match (data batch element type: ",
m_data_dilation_strides, data_batch_et,
m_padding_below, ", filters element type: ",
m_padding_above, filters_et,
filters_shape, ").");
m_window_movement_strides,
m_window_dilation_strides); result_shape = infer_convolution_forward(this,
data_batch_shape,
m_data_dilation_strides,
m_padding_below,
m_padding_above,
filters_shape,
m_window_movement_strides,
m_window_dilation_strides);
set_output_type(0, result_et, result_shape); set_output_type(0, result_et, result_shape);
} }
...@@ -407,16 +414,23 @@ void op::ConvolutionBiasAdd::validate_and_infer_types() ...@@ -407,16 +414,23 @@ void op::ConvolutionBiasAdd::validate_and_infer_types()
element::Type result_et; element::Type result_et;
PartialShape result_shape; PartialShape result_shape;
std::tie(result_et, result_shape) = infer_convolution_forward(this, NODE_VALIDATION_CHECK(
data_batch_et, this,
filters_et, element::Type::merge(result_et, data_batch_et, filters_et),
data_batch_shape, "Element types for data batch and filters do not match (data batch element type: ",
m_data_dilation_strides, data_batch_et,
m_padding_below, ", filters element type: ",
m_padding_above, filters_et,
filters_shape, ").");
m_window_movement_strides,
m_window_dilation_strides); result_shape = infer_convolution_forward(this,
data_batch_shape,
m_data_dilation_strides,
m_padding_below,
m_padding_above,
filters_shape,
m_window_movement_strides,
m_window_dilation_strides);
// TODO: Check result_shape is compatible with add_input // TODO: Check result_shape is compatible with add_input
set_output_type(0, result_et, result_shape); set_output_type(0, result_et, result_shape);
} }
......
...@@ -118,8 +118,11 @@ void op::QuantizedConvolution::validate_and_infer_types() ...@@ -118,8 +118,11 @@ void op::QuantizedConvolution::validate_and_infer_types()
shape_size(get_input_shape(7)) == 1, shape_size(get_input_shape(7)) == 1,
"Output scale and output zero point shape must be same and 1"); "Output scale and output zero point shape must be same and 1");
auto input_shape = get_input_shape(0); // auto input_shape = get_input_shape(0);
auto filters_shape = get_input_shape(1); // auto filters_shape = get_input_shape(1);
const PartialShape& input_shape = get_input_partial_shape(0);
const PartialShape& filters_shape = get_input_partial_shape(1);
if (m_data_dilation_strides.size() == 0) if (m_data_dilation_strides.size() == 0)
{ {
...@@ -146,23 +149,16 @@ void op::QuantizedConvolution::validate_and_infer_types() ...@@ -146,23 +149,16 @@ void op::QuantizedConvolution::validate_and_infer_types()
m_padding_above = conv_default_padding(this, input_shape, filters_shape); m_padding_above = conv_default_padding(this, input_shape, filters_shape);
} }
set_output_type(0, PartialShape result_shape;
m_output_type,
util::infer_convolution_output_shape(this, result_shape = infer_convolution_forward(this,
input_shape, input_shape,
filters_shape, m_data_dilation_strides,
m_window_movement_strides, m_padding_below,
m_window_dilation_strides, m_padding_above,
m_padding_below, filters_shape,
m_padding_above, m_window_movement_strides,
m_data_dilation_strides, m_window_dilation_strides);
0, /* batch_axis_data, */
1, /* input_channel_axis_data, */
1, /* input_channel_axis_filters, */
0, /* output_channel_axis_filters, */
0, /* batch_axis_result, */
1 /* output_channel_axis_result, */
));
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, this,
...@@ -172,6 +168,8 @@ void op::QuantizedConvolution::validate_and_infer_types() ...@@ -172,6 +168,8 @@ void op::QuantizedConvolution::validate_and_infer_types()
") must match output element type (", ") must match output element type (",
get_output_element_type(0), get_output_element_type(0),
")"); ")");
set_output_type(0, m_output_type, result_shape);
} }
shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -96,17 +96,24 @@ void op::DeconvolutionBias::validate_and_infer_types() ...@@ -96,17 +96,24 @@ void op::DeconvolutionBias::validate_and_infer_types()
const PartialShape& fwd_filters_shape{ const PartialShape& fwd_filters_shape{
filters_shape[1], filters_shape[0], filters_shape[2], filters_shape[3]}; filters_shape[1], filters_shape[0], filters_shape[2], filters_shape[3]};
std::tie(forward_result_et, forward_result_shape) =
infer_convolution_forward(this, NODE_VALIDATION_CHECK(
delta_et, this,
filters_et, element::Type::merge(forward_result_et, delta_et, filters_et),
m_data_batch_shape, "Element types for data batch and filters do not match (data batch element type: ",
m_data_dilation_strides_forward, delta_et,
m_padding_below_forward, ", filters element type: ",
m_padding_above_forward, filters_et,
fwd_filters_shape, ").");
m_window_movement_strides_forward,
m_window_dilation_strides_forward); forward_result_shape = infer_convolution_forward(this,
m_data_batch_shape,
m_data_dilation_strides_forward,
m_padding_below_forward,
m_padding_above_forward,
fwd_filters_shape,
m_window_movement_strides_forward,
m_window_dilation_strides_forward);
NGRAPH_DEBUG << "\tpartial filter_shape: " << filters_shape << "delta_shape: " << delta_shape NGRAPH_DEBUG << "\tpartial filter_shape: " << filters_shape << "delta_shape: " << delta_shape
<< ", inferred_res_shape: " << forward_result_shape << endl; << ", inferred_res_shape: " << forward_result_shape << endl;
......
...@@ -211,29 +211,15 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node, ...@@ -211,29 +211,15 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
// //
// Infers the output batch shape and element type for convolution fprop. // Infers the output batch shape and element type for convolution fprop.
// //
std::tuple<element::Type, PartialShape> PartialShape ngraph::infer_convolution_forward(const Node* node,
ngraph::infer_convolution_forward(const Node* node, const PartialShape& data_batch_shape,
element::Type et_batch, const Strides& data_dilation,
element::Type et_filters, const CoordinateDiff& data_padding_below,
const PartialShape& data_batch_shape, const CoordinateDiff& data_padding_above,
const Strides& data_dilation, const PartialShape& filters_shape,
const CoordinateDiff& data_padding_below, const Strides& filter_strides,
const CoordinateDiff& data_padding_above, const Strides& filter_dilation)
const PartialShape& filters_shape,
const Strides& filter_strides,
const Strides& filter_dilation)
{ {
element::Type et_result;
NODE_VALIDATION_CHECK(
node,
element::Type::merge(et_result, et_batch, et_filters),
"Element types for data batch and filters do not match (data batch element type: ",
et_batch,
", filters element type: ",
et_filters,
").");
Rank data_batch_filters_rank{Rank::dynamic()}; Rank data_batch_filters_rank{Rank::dynamic()};
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
...@@ -358,7 +344,7 @@ std::tuple<element::Type, PartialShape> ...@@ -358,7 +344,7 @@ std::tuple<element::Type, PartialShape>
batch_output_shape[i + 2] = data_output_shape[i]; batch_output_shape[i + 2] = data_output_shape[i];
} }
return std::make_tuple(et_result, batch_output_shape); return batch_output_shape;
} }
// //
......
...@@ -43,17 +43,14 @@ namespace ngraph ...@@ -43,17 +43,14 @@ namespace ngraph
const Strides& window_dilation, const Strides& window_dilation,
bool is_window_all_in_padding_allowed); bool is_window_all_in_padding_allowed);
std::tuple<element::Type, PartialShape> PartialShape infer_convolution_forward(const Node* node,
infer_convolution_forward(const Node* node, const PartialShape& data_batch_shape,
element::Type et_batch, const Strides& data_dilation,
element::Type et_filters, const CoordinateDiff& data_padding_below,
const PartialShape& data_batch_shape, const CoordinateDiff& data_padding_above,
const Strides& data_dilation, const PartialShape& filters_shape,
const CoordinateDiff& data_padding_below, const Strides& filter_strides,
const CoordinateDiff& data_padding_above, const Strides& filter_dilation);
const PartialShape& filters_shape,
const Strides& filter_strides,
const Strides& filter_dilation);
PartialShape infer_batched_pooling_forward(const Node* node, PartialShape infer_batched_pooling_forward(const Node* node,
const PartialShape& data_batch_shape, const PartialShape& data_batch_shape,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment