Commit 47be91bd authored by nishant.b.patel's avatar nishant.b.patel

Change infer_convolution_forward method to just do shape checks

parent 351917d5
......@@ -100,16 +100,23 @@ void op::Convolution::validate_and_infer_types()
element::Type result_et;
PartialShape result_shape;
std::tie(result_et, result_shape) = infer_convolution_forward(this,
data_batch_et,
filters_et,
data_batch_shape,
m_data_dilation_strides,
m_padding_below,
m_padding_above,
filters_shape,
m_window_movement_strides,
m_window_dilation_strides);
NODE_VALIDATION_CHECK(
this,
element::Type::merge(result_et, data_batch_et, filters_et),
"Element types for data batch and filters do not match (data batch element type: ",
data_batch_et,
", filters element type: ",
filters_et,
").");
result_shape = infer_convolution_forward(this,
data_batch_shape,
m_data_dilation_strides,
m_padding_below,
m_padding_above,
filters_shape,
m_window_movement_strides,
m_window_dilation_strides);
set_output_type(0, result_et, result_shape);
}
......@@ -255,17 +262,23 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
element::Type forward_result_et;
PartialShape forward_result_shape;
std::tie(forward_result_et, forward_result_shape) =
infer_convolution_forward(this,
delta_et,
filters_et,
m_data_batch_shape,
m_data_dilation_strides_forward,
m_padding_below_forward,
m_padding_above_forward,
filters_shape,
m_window_movement_strides_forward,
m_window_dilation_strides_forward);
NODE_VALIDATION_CHECK(
this,
element::Type::merge(forward_result_et, delta_et, filters_et),
"Element types for data batch and filters do not match (data batch element type: ",
delta_et,
", filters element type: ",
filters_et,
").");
forward_result_shape = infer_convolution_forward(this,
m_data_batch_shape,
m_data_dilation_strides_forward,
m_padding_below_forward,
m_padding_above_forward,
filters_shape,
m_window_movement_strides_forward,
m_window_dilation_strides_forward);
NODE_VALIDATION_CHECK(this,
forward_result_shape.compatible(delta_shape),
......@@ -481,17 +494,23 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
element::Type forward_result_et;
PartialShape forward_result_shape;
std::tie(forward_result_et, forward_result_shape) =
infer_convolution_forward(this,
data_batch_et,
delta_et,
data_batch_shape,
m_data_dilation_strides_forward,
m_padding_below_forward,
m_padding_above_forward,
m_filters_shape,
m_window_movement_strides_forward,
m_window_dilation_strides_forward);
NODE_VALIDATION_CHECK(
this,
element::Type::merge(forward_result_et, data_batch_et, delta_et),
"Element types for data batch and filters do not match (data batch element type: ",
data_batch_et,
", filters element type: ",
delta_et,
").");
forward_result_shape = infer_convolution_forward(this,
data_batch_shape,
m_data_dilation_strides_forward,
m_padding_below_forward,
m_padding_above_forward,
m_filters_shape,
m_window_movement_strides_forward,
m_window_dilation_strides_forward);
NODE_VALIDATION_CHECK(this,
forward_result_shape.compatible(delta_shape),
......
......@@ -159,16 +159,23 @@ void op::ConvolutionBias::validate_and_infer_types()
element::Type result_et;
PartialShape result_shape;
std::tie(result_et, result_shape) = infer_convolution_forward(this,
data_batch_et,
filters_et,
data_batch_shape,
m_data_dilation_strides,
m_padding_below,
m_padding_above,
filters_shape,
m_window_movement_strides,
m_window_dilation_strides);
NODE_VALIDATION_CHECK(
this,
element::Type::merge(result_et, data_batch_et, filters_et),
"Element types for data batch and filters do not match (data batch element type: ",
data_batch_et,
", filters element type: ",
filters_et,
").");
result_shape = infer_convolution_forward(this,
data_batch_shape,
m_data_dilation_strides,
m_padding_below,
m_padding_above,
filters_shape,
m_window_movement_strides,
m_window_dilation_strides);
set_output_type(0, result_et, result_shape);
}
......@@ -407,16 +414,23 @@ void op::ConvolutionBiasAdd::validate_and_infer_types()
element::Type result_et;
PartialShape result_shape;
std::tie(result_et, result_shape) = infer_convolution_forward(this,
data_batch_et,
filters_et,
data_batch_shape,
m_data_dilation_strides,
m_padding_below,
m_padding_above,
filters_shape,
m_window_movement_strides,
m_window_dilation_strides);
NODE_VALIDATION_CHECK(
this,
element::Type::merge(result_et, data_batch_et, filters_et),
"Element types for data batch and filters do not match (data batch element type: ",
data_batch_et,
", filters element type: ",
filters_et,
").");
result_shape = infer_convolution_forward(this,
data_batch_shape,
m_data_dilation_strides,
m_padding_below,
m_padding_above,
filters_shape,
m_window_movement_strides,
m_window_dilation_strides);
// TODO: Check result_shape is compatible with add_input
set_output_type(0, result_et, result_shape);
}
......
......@@ -118,8 +118,11 @@ void op::QuantizedConvolution::validate_and_infer_types()
shape_size(get_input_shape(7)) == 1,
"Output scale and output zero point shape must be same and 1");
auto input_shape = get_input_shape(0);
auto filters_shape = get_input_shape(1);
// auto input_shape = get_input_shape(0);
// auto filters_shape = get_input_shape(1);
const PartialShape& input_shape = get_input_partial_shape(0);
const PartialShape& filters_shape = get_input_partial_shape(1);
if (m_data_dilation_strides.size() == 0)
{
......@@ -146,23 +149,16 @@ void op::QuantizedConvolution::validate_and_infer_types()
m_padding_above = conv_default_padding(this, input_shape, filters_shape);
}
set_output_type(0,
m_output_type,
util::infer_convolution_output_shape(this,
input_shape,
filters_shape,
m_window_movement_strides,
m_window_dilation_strides,
m_padding_below,
m_padding_above,
m_data_dilation_strides,
0, /* batch_axis_data, */
1, /* input_channel_axis_data, */
1, /* input_channel_axis_filters, */
0, /* output_channel_axis_filters, */
0, /* batch_axis_result, */
1 /* output_channel_axis_result, */
));
PartialShape result_shape;
result_shape = infer_convolution_forward(this,
input_shape,
m_data_dilation_strides,
m_padding_below,
m_padding_above,
filters_shape,
m_window_movement_strides,
m_window_dilation_strides);
NODE_VALIDATION_CHECK(
this,
......@@ -172,6 +168,8 @@ void op::QuantizedConvolution::validate_and_infer_types()
") must match output element type (",
get_output_element_type(0),
")");
set_output_type(0, m_output_type, result_shape);
}
shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -96,17 +96,24 @@ void op::DeconvolutionBias::validate_and_infer_types()
const PartialShape& fwd_filters_shape{
filters_shape[1], filters_shape[0], filters_shape[2], filters_shape[3]};
std::tie(forward_result_et, forward_result_shape) =
infer_convolution_forward(this,
delta_et,
filters_et,
m_data_batch_shape,
m_data_dilation_strides_forward,
m_padding_below_forward,
m_padding_above_forward,
fwd_filters_shape,
m_window_movement_strides_forward,
m_window_dilation_strides_forward);
NODE_VALIDATION_CHECK(
this,
element::Type::merge(forward_result_et, delta_et, filters_et),
"Element types for data batch and filters do not match (data batch element type: ",
delta_et,
", filters element type: ",
filters_et,
").");
forward_result_shape = infer_convolution_forward(this,
m_data_batch_shape,
m_data_dilation_strides_forward,
m_padding_below_forward,
m_padding_above_forward,
fwd_filters_shape,
m_window_movement_strides_forward,
m_window_dilation_strides_forward);
NGRAPH_DEBUG << "\tpartial filter_shape: " << filters_shape << "delta_shape: " << delta_shape
<< ", inferred_res_shape: " << forward_result_shape << endl;
......
......@@ -211,29 +211,15 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
//
// Infers the output batch shape and element type for convolution fprop.
//
std::tuple<element::Type, PartialShape>
ngraph::infer_convolution_forward(const Node* node,
element::Type et_batch,
element::Type et_filters,
const PartialShape& data_batch_shape,
const Strides& data_dilation,
const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above,
const PartialShape& filters_shape,
const Strides& filter_strides,
const Strides& filter_dilation)
PartialShape ngraph::infer_convolution_forward(const Node* node,
const PartialShape& data_batch_shape,
const Strides& data_dilation,
const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above,
const PartialShape& filters_shape,
const Strides& filter_strides,
const Strides& filter_dilation)
{
element::Type et_result;
NODE_VALIDATION_CHECK(
node,
element::Type::merge(et_result, et_batch, et_filters),
"Element types for data batch and filters do not match (data batch element type: ",
et_batch,
", filters element type: ",
et_filters,
").");
Rank data_batch_filters_rank{Rank::dynamic()};
NODE_VALIDATION_CHECK(
......@@ -358,7 +344,7 @@ std::tuple<element::Type, PartialShape>
batch_output_shape[i + 2] = data_output_shape[i];
}
return std::make_tuple(et_result, batch_output_shape);
return batch_output_shape;
}
//
......
......@@ -43,17 +43,14 @@ namespace ngraph
const Strides& window_dilation,
bool is_window_all_in_padding_allowed);
std::tuple<element::Type, PartialShape>
infer_convolution_forward(const Node* node,
element::Type et_batch,
element::Type et_filters,
const PartialShape& data_batch_shape,
const Strides& data_dilation,
const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above,
const PartialShape& filters_shape,
const Strides& filter_strides,
const Strides& filter_dilation);
PartialShape infer_convolution_forward(const Node* node,
const PartialShape& data_batch_shape,
const Strides& data_dilation,
const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above,
const PartialShape& filters_shape,
const Strides& filter_strides,
const Strides& filter_dilation);
PartialShape infer_batched_pooling_forward(const Node* node,
const PartialShape& data_batch_shape,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment