Commit dee4a8b8 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Scott Cyphers

[ONNX] ConvTranspose with groups. (#2289)

* Enable support for group attribute.

* UT for ConvTranspose with groups.

* Validate group attribute value.

* Move helper function to unnamed namespace.

* Access values with bounds checking.

* Fix spelling.
parent 0eaa960c
...@@ -55,7 +55,6 @@ namespace ngraph ...@@ -55,7 +55,6 @@ namespace ngraph
// reference: https://github.com/NervanaSystems/ngraph-mxnet/blob/fdd692/src/ngraph/ngraph_emitter.cc#L822-L856 // reference: https://github.com/NervanaSystems/ngraph-mxnet/blob/fdd692/src/ngraph/ngraph_emitter.cc#L822-L856
std::size_t n_data_channels{data->get_shape().at(1)}; std::size_t n_data_channels{data->get_shape().at(1)};
std::size_t n_filters_channels{filters->get_shape().at(0)}; std::size_t n_filters_channels{filters->get_shape().at(0)};
// TODO: ensure n_data_channels % groups = 0
std::size_t data_group_size{n_data_channels / groups}; std::size_t data_group_size{n_data_channels / groups};
std::size_t filters_group_size{n_filters_channels / groups}; std::size_t filters_group_size{n_filters_channels / groups};
NodeVector convolution_nodes; NodeVector convolution_nodes;
...@@ -114,6 +113,16 @@ namespace ngraph ...@@ -114,6 +113,16 @@ namespace ngraph
(groups <= filters->get_shape().at(0)))) (groups <= filters->get_shape().at(0))))
<< "incorrect value of 'group' attribute: " << groups; << "incorrect value of 'group' attribute: " << groups;
std::size_t n_data_channels{data->get_shape().at(1)};
std::size_t n_filters_channels{filters->get_shape().at(0)};
ASSERT_VALID_ARGUMENT(node, n_data_channels % groups == 0)
<< "provided group attribute value must be a multiple of data channels "
"count.";
ASSERT_VALID_ARGUMENT(node, n_filters_channels % groups == 0)
<< "provided group attribute value must be a multiple of filter channels "
"count.";
auto strides = convpool::get_strides(node); auto strides = convpool::get_strides(node);
auto dilations = convpool::get_dilations(node); auto dilations = convpool::get_dilations(node);
auto paddings = convpool::get_pads(node); auto paddings = convpool::get_pads(node);
......
...@@ -14,20 +14,24 @@ ...@@ -14,20 +14,24 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <cstddef>
#include <cstdint> #include <cstdint>
#include <iterator> #include <iterator>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/coordinate_diff.hpp" #include "ngraph/coordinate_diff.hpp"
#include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/op/conv_transpose.hpp" #include "ngraph/frontend/onnx_import/op/conv_transpose.hpp"
#include "ngraph/frontend/onnx_import/utils/broadcasting.hpp" #include "ngraph/frontend/onnx_import/utils/broadcasting.hpp"
#include "ngraph/frontend/onnx_import/utils/convpool.hpp" #include "ngraph/frontend/onnx_import/utils/convpool.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp" #include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/strides.hpp" #include "ngraph/strides.hpp"
...@@ -40,6 +44,79 @@ namespace ngraph ...@@ -40,6 +44,79 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
namespace
{
std::shared_ptr<ngraph::Node>
make_ng_conv_transpose(std::int64_t groups,
const Shape& data_batch_shape,
const std::shared_ptr<ngraph::Node>& filters,
const std::shared_ptr<ngraph::Node>& data,
const Strides& strides,
const Strides& dilations,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides)
{
if (groups > 1)
{
// Split one convolution op to N ops where N is the number of groups
// and concat results after computation.
std::size_t n_data_channels{data->get_shape().at(1)};
std::size_t n_filters_channels{filters->get_shape().at(0)};
std::size_t data_group_size{n_data_channels / groups};
std::size_t filters_group_size{n_filters_channels / groups};
NodeVector conv_transpose_nodes;
// initial bounds for slice
std::vector<std::size_t> data_lower_bounds(data->get_shape().size());
std::vector<std::size_t> data_upper_bounds{data->get_shape()};
std::vector<std::size_t> filters_lower_bounds(
filters->get_shape().size());
std::vector<std::size_t> filters_upper_bounds{filters->get_shape()};
for (std::size_t group{0}; group < groups; ++group)
{
// slice data
data_lower_bounds[1] = group * data_group_size;
data_upper_bounds[1] = (group + 1) * data_group_size;
auto sliced_data = std::make_shared<ngraph::op::Slice>(
data, data_lower_bounds, data_upper_bounds);
// slice filters
filters_lower_bounds[0] = group * filters_group_size;
filters_upper_bounds[0] = (group + 1) * filters_group_size;
auto sliced_filters = std::make_shared<ngraph::op::Slice>(
filters, filters_lower_bounds, filters_upper_bounds);
conv_transpose_nodes.push_back(
std::make_shared<ngraph::op::ConvolutionBackpropData>(
data_batch_shape,
sliced_filters,
sliced_data,
strides,
dilations,
padding_below,
padding_above,
data_dilation_strides));
}
std::size_t concatenation_axis = 1;
return std::make_shared<ngraph::op::Concat>(conv_transpose_nodes,
concatenation_axis);
}
else
{
return std::make_shared<ngraph::op::ConvolutionBackpropData>(
data_batch_shape,
filters,
data,
strides,
dilations,
padding_below,
padding_above,
data_dilation_strides);
}
}
} // anonymous namespace
NodeVector conv_transpose(const Node& node) NodeVector conv_transpose(const Node& node)
{ {
const NodeVector& inputs = node.get_ng_inputs(); const NodeVector& inputs = node.get_ng_inputs();
...@@ -64,9 +141,26 @@ namespace ngraph ...@@ -64,9 +141,26 @@ namespace ngraph
node.get_attribute_value<std::vector<std::int64_t>>( node.get_attribute_value<std::vector<std::int64_t>>(
"output_padding", std::vector<std::int64_t>(num_spatial_dims, 0))}; "output_padding", std::vector<std::int64_t>(num_spatial_dims, 0))};
int64_t groups{node.get_attribute_value<int64_t>("group", 1)};
ASSERT_VALID_ARGUMENT(node,
((groups >= 0) && (groups <= data->get_shape().at(1)) &&
(groups <= filters->get_shape().at(0))))
<< "incorrect value of 'group' attribute: " << groups;
std::size_t n_data_channels{data_shape.at(1)};
std::size_t n_filters_channels{weights_shape.at(0)};
ASSERT_VALID_ARGUMENT(node, n_data_channels % groups == 0)
<< "provided group attribute value must be a multiple of data channels "
"count.";
ASSERT_VALID_ARGUMENT(node, n_filters_channels % groups == 0)
<< "provided group attribute value must be a multiple of filter channels "
"count.";
Shape data_batch_shape(data_shape.size(), 1); Shape data_batch_shape(data_shape.size(), 1);
data_batch_shape[0] = data_shape[0]; data_batch_shape.at(0) = data_shape.at(0);
data_batch_shape[1] = weights_shape[1]; data_batch_shape.at(1) = weights_shape.at(1);
if (!output_shape.empty()) if (!output_shape.empty())
{ {
...@@ -116,7 +210,8 @@ namespace ngraph ...@@ -116,7 +210,8 @@ namespace ngraph
} }
} }
auto conv_node = std::make_shared<ngraph::op::ConvolutionBackpropData>( std::shared_ptr<ngraph::Node> conv_node =
make_ng_conv_transpose(groups,
data_batch_shape, data_batch_shape,
filters, filters,
data, data,
......
...@@ -1628,3 +1628,22 @@ TEST(onnx, model_sum_opset8) ...@@ -1628,3 +1628,22 @@ TEST(onnx, model_sum_opset8)
Outputs outputs{execute(function, inputs, "INTERPRETER")}; Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front())); EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
} }
TEST(onnx, model_conv_transpose_w_groups)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/conv_transpose_w_groups.onnx"));
Inputs inputs;
inputs.emplace_back(std::vector<float>{
0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f});
inputs.emplace_back(std::vector<float>{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f,
8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f,
16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f,
24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.0f});
Outputs expected_output{
std::vector<float>{28.f, 34.f, 252.f, 274.f, 732.f, 770.f, 1468.f, 1522.f}};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment