Commit eb2ce8e5 authored by Tomasz Socha's avatar Tomasz Socha Committed by Michał Karzyński

[SPEC] Fix output_shape input in (Group)ConvolutionBackpropData ops (#4005)

parent 206bc657
...@@ -205,13 +205,28 @@ op::v1::ConvolutionBackpropData::ConvolutionBackpropData(const Output<Node>& dat ...@@ -205,13 +205,28 @@ op::v1::ConvolutionBackpropData::ConvolutionBackpropData(const Output<Node>& dat
const PartialShape op::v1::ConvolutionBackpropData::get_output_shape() const const PartialShape op::v1::ConvolutionBackpropData::get_output_shape() const
{ {
PartialShape shape{PartialShape::dynamic()}; PartialShape shape{vector<Dimension>(m_strides.size() + 2)};
auto data_pshape = get_input_partial_shape(0);
if (data_pshape.rank().is_static())
{
shape[0] = data_pshape[0]; // N
}
auto filters_pshape = get_input_partial_shape(1);
if (filters_pshape.rank().is_static())
{
shape[1] = filters_pshape[1]; // C
}
bool is_output_shape_present = get_inputs().size() == 3; bool is_output_shape_present = get_inputs().size() == 3;
if (is_output_shape_present) if (is_output_shape_present)
{ {
if (auto const_op = as_type<op::Constant>(input_value(2).get_node())) if (auto const_op = as_type<op::Constant>(input_value(2).get_node()))
{ {
shape = const_op->get_shape_val(); auto output_shape = const_op->get_shape_val();
// Populate spatials
for (int i = 0; i < output_shape.size(); ++i)
{
shape[i + 2] = output_shape[i];
}
} }
} }
return shape; return shape;
...@@ -270,13 +285,6 @@ void op::v1::ConvolutionBackpropData::validate_and_infer_types() ...@@ -270,13 +285,6 @@ void op::v1::ConvolutionBackpropData::validate_and_infer_types()
if (is_output_shape_present) if (is_output_shape_present)
{ {
set_input_is_relevant_to_shape(2); set_input_is_relevant_to_shape(2);
if (output_pshape.is_static() && data_pshape.is_static())
{
auto data_shape = data_pshape.to_shape();
auto output_shape = output_pshape.to_shape();
output_shape.insert(output_shape.begin(), data_shape.begin(), data_shape.begin() + 1);
output_pshape = output_shape;
}
} }
else else
{ {
...@@ -295,12 +303,13 @@ void op::v1::ConvolutionBackpropData::validate_and_infer_types() ...@@ -295,12 +303,13 @@ void op::v1::ConvolutionBackpropData::validate_and_infer_types()
for (size_t i = 0; i < data_spatial_rank; ++i) for (size_t i = 0; i < data_spatial_rank; ++i)
{ {
size_t tmp = m_strides[i] * (data_shape[i + 2] - 1) + size_t tmp = m_strides[i] * (data_shape[i + 2] - 1) +
((filters_shape[i] + 2 - 1) * m_dilations[i] + 1) - m_pads_begin[i] - ((filters_shape[i + 2] - 1) * m_dilations[i] + 1) - m_pads_begin[i] -
m_pads_end[i] + output_padding[i]; m_pads_end[i] + output_padding[i];
output_shape.push_back(tmp); output_shape.push_back(tmp);
output_pshape = output_shape;
} }
output_shape.insert(output_shape.begin(), data_shape.begin(), data_shape.begin() + 1); output_shape.insert(output_shape.begin(), filters_shape.at(1));
output_shape.insert(output_shape.begin(), data_shape.at(0));
output_pshape = output_shape;
} }
} }
......
...@@ -192,13 +192,28 @@ op::v1::GroupConvolutionBackpropData::GroupConvolutionBackpropData( ...@@ -192,13 +192,28 @@ op::v1::GroupConvolutionBackpropData::GroupConvolutionBackpropData(
const PartialShape op::v1::GroupConvolutionBackpropData::get_output_shape() const const PartialShape op::v1::GroupConvolutionBackpropData::get_output_shape() const
{ {
PartialShape shape{PartialShape::dynamic()}; PartialShape shape{vector<Dimension>(m_strides.size() + 2)};
auto data_pshape = get_input_partial_shape(0);
if (data_pshape.rank().is_static())
{
shape[0] = data_pshape[0]; // N
}
auto filters_pshape = get_input_partial_shape(1);
if (filters_pshape.rank().is_static())
{
shape[1] = filters_pshape[1]; // C
}
bool is_output_shape_present = get_inputs().size() == 3; bool is_output_shape_present = get_inputs().size() == 3;
if (is_output_shape_present) if (is_output_shape_present)
{ {
if (auto const_op = as_type<op::Constant>(input_value(2).get_node())) if (auto const_op = as_type<op::Constant>(input_value(2).get_node()))
{ {
shape = const_op->get_shape_val(); auto output_shape = const_op->get_shape_val();
// Populate spatials
for (int i = 0; i < output_shape.size(); ++i)
{
shape[i + 2] = output_shape[i];
}
} }
} }
return shape; return shape;
...@@ -257,26 +272,16 @@ void op::v1::GroupConvolutionBackpropData::validate_and_infer_types() ...@@ -257,26 +272,16 @@ void op::v1::GroupConvolutionBackpropData::validate_and_infer_types()
if (is_output_shape_present) if (is_output_shape_present)
{ {
set_input_is_relevant_to_shape(2); set_input_is_relevant_to_shape(2);
if (output_pshape.is_static() && data_pshape.is_static())
{
auto data_shape = data_pshape.to_shape();
auto output_shape = output_pshape.to_shape();
output_shape.insert(output_shape.begin(), data_shape.begin(), data_shape.begin() + 1);
output_pshape = output_shape;
}
} }
else else
{ {
if (filters_pshape.is_static() && data_pshape.is_static()) if (filters_pshape.is_static() && data_pshape.is_static())
{ {
auto filters_shape = filters_pshape.to_shape(); auto filters_shape = filters_pshape.to_shape();
filters_shape.erase(filters_shape.begin(),
filters_shape.begin() + 3); // remove {G, O, I}
auto data_shape = data_pshape.to_shape(); auto data_shape = data_pshape.to_shape();
data_shape.erase(data_shape.begin(), data_shape.begin() + 2); // remove {N, C}
Shape output_shape; Shape output_shape;
auto data_spatial_rank = data_shape.size(); auto data_spatial_rank = data_shape.size() - 2;
auto output_padding = get_output_padding(); auto output_padding = get_output_padding();
if (output_padding.size() == 0) if (output_padding.size() == 0)
{ {
...@@ -284,13 +289,15 @@ void op::v1::GroupConvolutionBackpropData::validate_and_infer_types() ...@@ -284,13 +289,15 @@ void op::v1::GroupConvolutionBackpropData::validate_and_infer_types()
} }
for (size_t i = 0; i < data_spatial_rank; ++i) for (size_t i = 0; i < data_spatial_rank; ++i)
{ {
size_t tmp = m_strides[i] * (data_shape[i] - 1) + size_t tmp = m_strides[i] * (data_shape[i + 2] - 1) +
((filters_shape[i] - 1) * m_dilations[i] + 1) - m_pads_begin[i] - ((filters_shape[i + 3] - 1) * m_dilations[i] + 1) - m_pads_begin[i] -
m_pads_end[i] + output_padding[i]; m_pads_end[i] + output_padding[i];
output_shape.push_back(tmp); output_shape.push_back(tmp);
output_pshape = output_shape;
} }
output_shape.insert(output_shape.begin(), data_shape.begin(), data_shape.begin() + 1); output_shape.insert(output_shape.begin(),
filters_shape.at(0) * filters_shape.at(2)); // GROUP * C_OUTPUT
output_shape.insert(output_shape.begin(), data_shape.at(0));
output_pshape = output_shape;
} }
} }
......
...@@ -135,11 +135,12 @@ namespace ...@@ -135,11 +135,12 @@ namespace
bool op_cast(shared_ptr<op::v1::ConvolutionBackpropData> node) bool op_cast(shared_ptr<op::v1::ConvolutionBackpropData> node)
{ {
auto output_shape = as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr()); auto output_shape_node =
as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
const auto data_arg = node->input(0).get_source_output(); const auto data_arg = node->input(0).get_source_output();
const auto filters_arg = node->input(1).get_source_output(); const auto filters_arg = node->input(1).get_source_output();
const auto strides = node->get_strides(); const auto strides = node->get_strides();
NGRAPH_CHECK(output_shape, NGRAPH_CHECK(output_shape_node,
"Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 " "Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 "
"if output_shape is not constant. Node: ", "if output_shape is not constant. Node: ",
*node); *node);
...@@ -155,8 +156,22 @@ namespace ...@@ -155,8 +156,22 @@ namespace
"with output padding other than `0`. Node: ", "with output padding other than `0`. Node: ",
*node); *node);
auto data_pshape = data_arg.get_partial_shape();
auto filters_pshape = filters_arg.get_partial_shape();
NGRAPH_CHECK(data_pshape.rank().is_static() && data_pshape[0].is_static() &&
filters_pshape.rank().is_static() && filters_pshape[1].is_static(),
"Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 "
"if data shape N and filters shape C dimensions are not static. Node: ",
*node);
// Add N and C dimenstions to output_shape
auto output_shape = output_shape_node->get_shape_val();
output_shape.insert(output_shape.begin(), static_cast<size_t>(filters_pshape[1]));
output_shape.insert(output_shape.begin(), static_cast<size_t>(data_pshape[0]));
auto replacement_node = auto replacement_node =
make_shared<op::v0::ConvolutionBackpropData>(output_shape->get_shape_val(), make_shared<op::v0::ConvolutionBackpropData>(output_shape,
filters_arg, filters_arg,
data_arg, data_arg,
node->get_strides(), node->get_strides(),
......
...@@ -179,7 +179,10 @@ namespace ...@@ -179,7 +179,10 @@ namespace
auto replacement_node = make_shared<op::v1::ConvolutionBackpropData>( auto replacement_node = make_shared<op::v1::ConvolutionBackpropData>(
node->input_value(1), // data node->input_value(1), // data
node->input_value(0), // filters node->input_value(0), // filters
op::Constant::create(element::i64, Shape{data_batch_shape.size()}, data_batch_shape), op::Constant::create(
element::i64,
Shape{data_batch_shape.size() - 2},
vector<size_t>(data_batch_shape.begin() + 2, data_batch_shape.end())),
strides, strides,
pads_begin, pads_begin,
pads_end, pads_end,
......
...@@ -175,7 +175,7 @@ NGRAPH_TEST(${BACKEND_NAME}, dyn_convolution_backprop_data) ...@@ -175,7 +175,7 @@ NGRAPH_TEST(${BACKEND_NAME}, dyn_convolution_backprop_data)
for (int i = 0; i < 2 * 3 * 5 * 5; i++) for (int i = 0; i < 2 * 3 * 5 * 5; i++)
expected_result.emplace_back(i); expected_result.emplace_back(i);
vector<int64_t> shapes = {2, 3, 5, 5}; vector<int64_t> shapes = {5, 5};
// Create some tensors for input/output // Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_delta); auto a = backend->create_tensor(element::f32, shape_delta);
......
...@@ -78,7 +78,7 @@ TEST(opset_transform, opset1_convolution_downgrade_pass) ...@@ -78,7 +78,7 @@ TEST(opset_transform, opset1_convolution_downgrade_pass)
TEST(opset_transform, opset1_convolution_backprop_data_downgrade_pass) TEST(opset_transform, opset1_convolution_backprop_data_downgrade_pass)
{ {
auto data_batch_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {64, 3, 100}); auto data_batch_shape = op::Constant::create<int64_t>(element::i64, Shape{1}, {100});
auto filters = make_shared<op::Parameter>(element::f32, Shape{128, 3, 10}); auto filters = make_shared<op::Parameter>(element::f32, Shape{128, 3, 10});
auto delta = make_shared<op::Parameter>(element::f32, Shape{64, 128, 96}); auto delta = make_shared<op::Parameter>(element::f32, Shape{64, 128, 96});
auto strides = Strides{1}; auto strides = Strides{1};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment