Commit 005660f0 authored by Tomasz Socha's avatar Tomasz Socha Committed by Scott Cyphers

Add upgrade and downgrade pass for GroupConvolutionBackpropData ops (#4035)

* Add upgrade and downgrade pass for GroupConvolutionBackpropData ops

- Add up/downgrade passes for GroupConvolutionBackpropData operators
- Improve decompose operatorion of v0::GroupConvolutionBackpropData to support N-dimensional data
- Add UT for up/downgrade passes.

* Remove unused variable
parent 5d116018
......@@ -600,49 +600,34 @@ shared_ptr<Node>
NodeVector op::v0::GroupConvolutionBackpropData::decompose_op() const
{
auto data_batch = input_value(0);
auto filters = input_value(1);
auto output_delta = input_value(2);
auto data_shape = get_input_shape(0);
auto filters_shape = get_input_shape(1);
auto delta_shape = get_input_shape(2);
NodeVector sliced_inputs;
for (size_t i = 0; i < get_groups(); ++i)
{
size_t channel_step = filters_shape.at(1);
const Coordinate data_lower_bound{0, i * channel_step, 0, 0};
const Coordinate data_upper_bound{
data_shape.at(0), (i + 1) * channel_step, data_shape.at(2), data_shape.at(3)};
auto sliced_data =
std::make_shared<op::Slice>(data_batch, data_lower_bound, data_upper_bound);
size_t filters_step = filters_shape.at(0) / get_groups();
const Coordinate filters_lower_bound{i * filters_step, 0, 0, 0};
const Coordinate filters_upper_bound{
(i + 1) * filters_step, filters_shape.at(1), filters_shape.at(2), filters_shape.at(3)};
auto sliced_filters =
std::make_shared<op::Slice>(filters, filters_lower_bound, filters_upper_bound);
auto groups = get_groups();
// slice data shape
data_shape[1] /= groups;
// slice delta
auto sliced_delta = builder::split(output_delta, groups, 1);
// slice filters
auto sliced_filters = builder::split(filters, groups, 0);
const Coordinate delta_lower_bound{0, i * filters_step, 0, 0};
const Coordinate delta_upper_bound{
delta_shape.at(0), (i + 1) * filters_step, delta_shape.at(2), delta_shape.at(3)};
auto sliced_delta =
std::make_shared<op::Slice>(output_delta, delta_lower_bound, delta_upper_bound);
auto num_spatials = get_window_movement_strides().size();
auto sliced_conv =
std::make_shared<op::ConvolutionBackpropData>(sliced_data->get_shape(),
sliced_filters,
sliced_delta,
for (size_t i = 0; i < groups; ++i)
{
auto sliced_conv = std::make_shared<op::ConvolutionBackpropData>(
data_shape,
sliced_filters[i],
sliced_delta[i],
get_window_movement_strides(),
get_window_dilation_strides(),
get_padding_below(),
get_padding_above(),
Strides{1, 1});
Strides(num_spatials, 1)); // default data dilation strides
sliced_inputs.push_back(sliced_conv);
}
......
......@@ -22,12 +22,14 @@
#include "ngraph/builder/reshape.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/slice_plan.hpp"
#include "ngraph/type.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
......@@ -300,6 +302,82 @@ namespace
return true;
}
bool op_cast(shared_ptr<op::v1::GroupConvolutionBackpropData> node)
{
auto output_shape_input =
as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
const auto data_arg = node->input_value(0);
const auto filters_arg = node->input_value(1);
const auto strides = node->get_strides();
const auto dilations = node->get_dilations();
NGRAPH_CHECK(
output_shape_input,
"Unable to convert GroupConvolutionBackpropData:v1 to GroupConvolutionBackpropData:v0 "
"if output_shape is not constant. Node: ",
*node);
auto output_padding = node->get_output_padding();
bool is_op_valid = all_of(
output_padding.begin(), output_padding.end(), [](size_t value) { return value == 0; });
NGRAPH_CHECK(
is_op_valid,
"Unable to convert GroupConvolutionBackpropData:v1 to GroupConvolutionBackpropData:v0 "
"with output padding other than `0`. Node: ",
*node);
NGRAPH_CHECK(data_arg.get_partial_shape().is_static(),
"Unable to convert GroupConvolution:1 to GroupConvolution:0"
"with dynamic data shape. Node: ",
*node);
NGRAPH_CHECK(filters_arg.get_partial_shape().is_static(),
"Unable to convert GroupConvolution:1 to GroupConvolution:0"
"with dynamic filters shape. Node: ",
*node);
auto filters_shape = filters_arg.get_shape();
auto data_shape = data_arg.get_shape();
auto groups = filters_shape.at(0);
filters_shape[1] *= groups;
filters_shape.erase(filters_shape.begin());
auto reshaped_filters = builder::reshape(node->input_value(1), filters_shape);
auto pads_begin = node->get_pads_begin();
auto pads_end = node->get_pads_end();
auto auto_pad = node->get_auto_pad();
auto output_shape = output_shape_input->get_shape_val();
if (auto_pad == op::PadType::SAME_UPPER || auto_pad == op::PadType::SAME_LOWER)
{
infer_auto_padding(output_shape,
Shape(filters_shape.begin() + 2, filters_shape.end()),
strides,
dilations,
auto_pad,
pads_begin,
pads_end);
}
output_shape.insert(output_shape.begin(), filters_shape[1]);
output_shape.insert(output_shape.begin(), data_shape[0]);
auto replacement_node = make_shared<op::v0::GroupConvolutionBackpropData>(
op::Constant::create(data_arg.get_element_type(), output_shape, {0}),
reshaped_filters,
data_arg,
node->get_strides(),
node->get_dilations(),
pads_begin,
pads_end,
groups);
replace_node(node, replacement_node);
return true;
}
bool op_cast(shared_ptr<op::v1::Less> node)
{
op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
......
......@@ -327,6 +327,46 @@ namespace
return true;
}
bool op_cast(shared_ptr<op::v0::GroupConvolutionBackpropData> node)
{
auto strides = node->get_window_movement_strides();
auto dilations = node->get_window_dilation_strides();
auto pads_begin = node->get_padding_below();
auto pads_end = node->get_padding_above();
auto data_batch_pshape = node->get_input_partial_shape(0);
NGRAPH_CHECK(data_batch_pshape.is_static(),
"Unable to convert GroupConvolution:0 to GroupConvolution:1"
"with dynamic data_batch shape. Node: ",
*node);
auto data_batch_shape = data_batch_pshape.to_shape();
data_batch_shape.erase(data_batch_shape.begin(), data_batch_shape.end());
NGRAPH_CHECK(node->get_input_partial_shape(1).is_static(),
"Unable to convert GroupConvolution:0 to GroupConvolution:1"
"with dynamic filters shape. Node: ",
*node);
auto filters_shape = node->get_input_shape(1);
auto groups = node->get_groups();
filters_shape[0] /= groups;
filters_shape.insert(filters_shape.begin(), groups);
auto reshaped_filters = builder::reshape(node->input_value(1), filters_shape);
auto replacement_node = make_shared<op::v1::GroupConvolutionBackpropData>(
node->input_value(2),
reshaped_filters,
op::Constant::create(element::i64, Shape{data_batch_shape.size()}, data_batch_shape),
strides,
pads_begin,
pads_end,
dilations);
replace_node(node, replacement_node);
return true;
}
bool op_cast(shared_ptr<op::Less> node)
{
op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
......@@ -657,7 +697,7 @@ namespace
};
return dispatch_map;
}
}
} // namespace
bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
{
......
......@@ -138,3 +138,63 @@ TEST(opset_transform, opset1_convolution_backprop_filters_downgrade_pass)
EXPECT_EQ(conv_v0_node->get_padding_above_forward(), padding_end);
EXPECT_EQ(conv_v0_node->get_data_dilation_strides_forward(), (Strides{1}));
}
TEST(opset_transform, opset1_group_convolution_backprop_data_downgrade_pass)
{
auto output_shape = op::Constant::create<int64_t>(element::i64, Shape{1}, {100});
auto filters = make_shared<op::Parameter>(element::f32, Shape{1, 128, 3, 10});
auto delta = make_shared<op::Parameter>(element::f32, Shape{64, 128, 96});
auto strides = Strides{1};
auto dilations = Strides{1};
auto padding_begin = CoordinateDiff{2};
auto padding_end = CoordinateDiff{3};
auto group_conv_backprop = make_shared<op::v1::GroupConvolutionBackpropData>(
delta, filters, output_shape, strides, padding_begin, padding_end, dilations);
auto result = make_shared<op::Result>(group_conv_backprop);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{filters, delta});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
auto group_conv_backprop_s0_result = f->get_results().at(0);
auto node = group_conv_backprop_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto group_conv_backprop_v0_node = as_type_ptr<op::v0::GroupConvolutionBackpropData>(node);
ASSERT_TRUE(group_conv_backprop_v0_node);
EXPECT_EQ(group_conv_backprop_v0_node->get_window_movement_strides(), strides);
EXPECT_EQ(group_conv_backprop_v0_node->get_window_dilation_strides(), dilations);
EXPECT_EQ(group_conv_backprop_v0_node->get_padding_below(), padding_begin);
EXPECT_EQ(group_conv_backprop_v0_node->get_padding_above(), padding_end);
}
TEST(opset_transform, opset1_group_convolution_backprop_data_upgrade_pass)
{
auto data_batch_shape = op::Constant::create<int64_t>(element::f32, Shape{64, 3, 100}, {0});
auto filters = make_shared<op::Parameter>(element::f32, Shape{128, 3, 10});
auto delta = make_shared<op::Parameter>(element::f32, Shape{64, 128, 96});
auto strides = Strides{1};
auto dilations = Strides{1};
auto padding_begin = CoordinateDiff{2};
auto padding_end = CoordinateDiff{3};
auto group_conv_backprop = make_shared<op::v0::GroupConvolutionBackpropData>(
data_batch_shape, filters, delta, strides, dilations, padding_begin, padding_end, 1);
auto result = make_shared<op::Result>(group_conv_backprop);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{filters, delta});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
auto group_conv_backprop_s1_result = f->get_results().at(0);
auto node = group_conv_backprop_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto group_conv_backprop_v1_node = as_type_ptr<op::v1::GroupConvolutionBackpropData>(node);
ASSERT_TRUE(group_conv_backprop_v1_node);
EXPECT_EQ(group_conv_backprop_v1_node->get_strides(), strides);
EXPECT_EQ(group_conv_backprop_v1_node->get_dilations(), dilations);
EXPECT_EQ(group_conv_backprop_v1_node->get_pads_begin(), padding_begin);
EXPECT_EQ(group_conv_backprop_v1_node->get_pads_end(), padding_end);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment