Commit e21db881 authored by Adam Rogowiec, committed by Robert Kimball

[FUSED] Group Transpose Convolution (#3040)

* Adding GroupConvTranspose fused operator.

* Add missing header and remove commented code.

* Remove unused variable.

* Add a few more convenience constructors.

* Add more type prop UTs.

* Remove unused post validation functions.

* Style apply.

* Fix conversion of vector to CoordinateDiff

* Add GroupConvolutionTranspose to Intel GPU backend.

* Add documentation.

* Use default (python-like) divide.
parent c0b4f2d3
@@ -308,6 +308,8 @@ set (SRC
op/fused/grn.hpp
op/fused/group_conv.hpp
op/fused/group_conv.cpp
op/fused/group_conv_transpose.hpp
op/fused/group_conv_transpose.cpp
op/fused/leaky_relu.cpp
op/fused/leaky_relu.hpp
op/fused/mvn.cpp
@@ -25,13 +25,7 @@
#include "ngraph/frontend/onnx_import/op/conv_transpose.hpp"
#include "ngraph/frontend/onnx_import/utils/convpool.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
@@ -44,79 +38,6 @@ namespace ngraph
{
namespace set_1
{
namespace
{
std::shared_ptr<ngraph::Node>
make_ng_conv_transpose(std::int64_t groups,
const Shape& data_batch_shape,
const std::shared_ptr<ngraph::Node>& filters,
const std::shared_ptr<ngraph::Node>& data,
const Strides& strides,
const Strides& dilations,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides)
{
if (groups > 1)
{
// Split one convolution op to N ops where N is the number of groups
// and concat results after computation.
std::size_t n_data_channels{data->get_shape().at(1)};
std::size_t n_filters_channels{filters->get_shape().at(0)};
std::size_t data_group_size{n_data_channels / groups};
std::size_t filters_group_size{n_filters_channels / groups};
NodeVector conv_transpose_nodes;
// initial bounds for slice
std::vector<std::size_t> data_lower_bounds(data->get_shape().size());
std::vector<std::size_t> data_upper_bounds{data->get_shape()};
std::vector<std::size_t> filters_lower_bounds(
filters->get_shape().size());
std::vector<std::size_t> filters_upper_bounds{filters->get_shape()};
for (std::size_t group{0}; group < groups; ++group)
{
// slice data
data_lower_bounds[1] = group * data_group_size;
data_upper_bounds[1] = (group + 1) * data_group_size;
auto sliced_data = std::make_shared<ngraph::op::Slice>(
data, data_lower_bounds, data_upper_bounds);
// slice filters
filters_lower_bounds[0] = group * filters_group_size;
filters_upper_bounds[0] = (group + 1) * filters_group_size;
auto sliced_filters = std::make_shared<ngraph::op::Slice>(
filters, filters_lower_bounds, filters_upper_bounds);
conv_transpose_nodes.push_back(
std::make_shared<ngraph::op::ConvolutionBackpropData>(
data_batch_shape,
sliced_filters,
sliced_data,
strides,
dilations,
padding_below,
padding_above,
data_dilation_strides));
}
std::size_t concatenation_axis = 1;
return std::make_shared<ngraph::op::Concat>(conv_transpose_nodes,
concatenation_axis);
}
else
{
return std::make_shared<ngraph::op::ConvolutionBackpropData>(
data_batch_shape,
filters,
data,
strides,
dilations,
padding_below,
padding_above,
data_dilation_strides);
}
}
} // anonymous namespace
NodeVector conv_transpose(const Node& node)
{
const NodeVector& inputs = node.get_ng_inputs();
@@ -130,10 +51,9 @@ namespace ngraph
auto strides = convpool::get_strides(node);
auto dilations = convpool::get_dilations(node);
auto paddings = convpool::get_pads(node);
ngraph::CoordinateDiff padding_below = paddings.first;
ngraph::CoordinateDiff padding_above = paddings.second;
CoordinateDiff padding_below = paddings.first;
CoordinateDiff padding_above = paddings.second;
Strides data_dilation_strides(num_spatial_dims, 1);
std::vector<std::int64_t> output_shape{
node.get_attribute_value<std::vector<std::int64_t>>("output_shape", {})};
@@ -158,73 +78,31 @@ namespace ngraph
<< "provided group attribute value must be a multiple of filter channels "
"count.";
Shape data_batch_shape(data_shape.size(), 1);
data_batch_shape.at(0) = data_shape.at(0);
data_batch_shape.at(1) = weights_shape.at(1);
std::shared_ptr<ngraph::Node> conv_node;
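// When the ONNX node provides an output_shape attribute, the fused op infers
// the paddings from it; otherwise the explicit paddings computed above are forwarded.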
if (!output_shape.empty())
{
if (output_shape.size() > num_spatial_dims)
{
output_shape.erase(std::begin(output_shape),
std::begin(output_shape) + 2);
}
for (int i = 0; i < num_spatial_dims; ++i)
{
padding_below[i] = strides[i] * (data_shape[i + 2] - 1) +
dilations[i] * (weights_shape[i + 2] - 1) -
data_dilation_strides[i] *
(output_shape[i] - output_padding[i] - 1);
if (padding_below[i] < 0)
{
// (int) -9 / 2 = -5 but we need -4
// (int) -9 --> 9 / 2 = 4 --> -4
padding_below[i] = -(-padding_below[i] / 2);
}
else
{
padding_below[i] /= 2;
}
padding_above[i] = padding_below[i];
data_batch_shape[i + 2] = output_shape[i];
}
conv_node = std::make_shared<ngraph::op::GroupConvolutionTranspose>(
data,
filters,
strides,
dilations,
CoordinateDiff(std::begin(output_padding), std::end(output_padding)),
Shape(std::begin(output_shape), std::end(output_shape)),
groups);
}
else
{
for (int i = 0; i < num_spatial_dims; ++i)
{
// Calculating spatial dims of data output shape for ngraph conv backprop op
//
//   out = floor((s * (ds - 1) + d * (ws - 1) - pb - pa) / dds) + 1 + op
//
// d   - dilation
// ds  - data shape
// dds - data dilation strides
// op  - output padding
// pa  - padding above
// pb  - padding below
// s   - strides
// ws  - weights shape
data_batch_shape[i + 2] = (strides[i] * (data_shape[i + 2] - 1) +
dilations[i] * (weights_shape[i + 2] - 1) -
padding_below[i] - padding_above[i]) /
data_dilation_strides[i] +
1 + output_padding[i];
}
conv_node = std::make_shared<ngraph::op::GroupConvolutionTranspose>(
data,
filters,
strides,
dilations,
padding_below,
padding_above,
CoordinateDiff(std::begin(output_padding), std::end(output_padding)),
groups);
}
std::shared_ptr<ngraph::Node> conv_node =
make_ng_conv_transpose(groups,
data_batch_shape,
filters,
data,
strides,
dilations,
padding_below,
padding_above,
data_dilation_strides);
// no bias param
if (inputs.size() < 3)
{
@@ -103,6 +103,7 @@
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/leaky_relu.hpp"
#include "ngraph/op/fused/mvn.hpp"
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <iterator>
#include <numeric>
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/util.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
op::GroupConvolutionTranspose::GroupConvolutionTranspose(const shared_ptr<Node>& data,
const shared_ptr<Node>& filters,
const Strides& strides,
const Strides& dilations,
const CoordinateDiff& padding_begin,
const CoordinateDiff& padding_end,
const CoordinateDiff& output_padding,
const size_t groups,
const PadType& pad_type,
const Shape& output_shape)
: FusedOp("GroupConvolutionTranspose", check_single_output_args({data, filters}))
, m_strides(strides)
, m_dilations(dilations)
, m_padding_begin(padding_begin)
, m_padding_end(padding_end)
, m_output_padding(output_padding)
, m_groups(groups)
, m_pad_type(pad_type)
, m_output_shape(output_shape)
{
constructor_validate_and_infer_types();
}
op::GroupConvolutionTranspose::GroupConvolutionTranspose(const std::shared_ptr<Node>& data,
const std::shared_ptr<Node>& filters,
const std::size_t groups)
: GroupConvolutionTranspose(data,
filters,
Strides(),
Strides(),
CoordinateDiff(),
CoordinateDiff(),
CoordinateDiff(),
groups,
PadType::EXPLICIT,
Shape())
{
}
op::GroupConvolutionTranspose::GroupConvolutionTranspose(const std::shared_ptr<Node>& data,
const std::shared_ptr<Node>& filters,
const Strides& strides,
const Strides& dilations,
const CoordinateDiff& output_padding,
const Shape& output_shape,
const std::size_t groups)
: GroupConvolutionTranspose(data,
filters,
strides,
dilations,
CoordinateDiff(),
CoordinateDiff(),
output_padding,
groups,
PadType::EXPLICIT,
output_shape)
{
}
op::GroupConvolutionTranspose::GroupConvolutionTranspose(const std::shared_ptr<Node>& data,
const std::shared_ptr<Node>& filters,
const Shape& output_shape,
const std::size_t groups)
: GroupConvolutionTranspose(data,
filters,
Strides(),
Strides(),
CoordinateDiff(),
CoordinateDiff(),
CoordinateDiff(),
groups,
PadType::EXPLICIT,
output_shape)
{
}
void op::GroupConvolutionTranspose::pre_validate_and_infer_types()
{
auto data_pshape = get_input_partial_shape(0);
auto filters_pshape = get_input_partial_shape(1);
if (data_pshape.is_static() && filters_pshape.is_static())
{
const Shape& data_shape = data_pshape.to_shape();
const Shape& filters_shape = filters_pshape.to_shape();
size_t n_data_channels{data_shape.at(1)};
size_t n_filters_channels{filters_shape.at(0)};
// groups
NODE_VALIDATION_CHECK(this,
(m_groups <= n_data_channels && m_groups <= n_filters_channels),
"Incorrect value of groups: ",
m_groups);
// filter channels
NODE_VALIDATION_CHECK(
this,
n_filters_channels == n_data_channels,
"Number of filter channels must be equal to number of data channels.");
// data channels
NODE_VALIDATION_CHECK(this,
n_data_channels % m_groups == 0,
"Number of data channels must be a multiple of the number of groups.");
// padding type
NODE_VALIDATION_CHECK(
this, m_pad_type == PadType::EXPLICIT, "Currently only explicit pad type is supported.");
if (m_padding_begin.size() == 0)
{
m_padding_begin = conv_default_padding(this, data_pshape, filters_pshape);
}
if (m_padding_end.size() == 0)
{
m_padding_end = conv_default_padding(this, data_pshape, filters_pshape);
}
if (m_output_padding.size() == 0)
{
m_output_padding = conv_default_padding(this, data_pshape, filters_pshape);
}
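// Note: conv_default_padding returns all zeros here, so an absent
// output_padding means no output adjustment.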
if (m_strides.size() == 0)
{
m_strides = conv_default_strides(this, data_pshape, filters_pshape);
}
if (m_dilations.size() == 0)
{
m_dilations = conv_default_strides(this, data_pshape, filters_pshape);
}
const size_t num_spatial_dims = data_shape.size() - 2;
NODE_VALIDATION_CHECK(this,
m_strides.size() == num_spatial_dims,
"Strides should match the number of spatial dimensions of the input data.");
NODE_VALIDATION_CHECK(this,
m_dilations.size() == num_spatial_dims,
"Dilations should match the number of spatial dimensions of the input data.");
NODE_VALIDATION_CHECK(this,
m_output_padding.size() == num_spatial_dims,
"Output padding should match the number of spatial dimensions of the input data.");
// If output shape is provided, ignore current values for padding begin/end and infer them.
if (!m_output_shape.empty())
{
m_padding_begin = CoordinateDiff(num_spatial_dims);
m_padding_end = CoordinateDiff(num_spatial_dims);
Shape out_shape(m_output_shape);
if (out_shape.size() > num_spatial_dims)
{
out_shape.erase(std::begin(out_shape), std::begin(out_shape) + 2);
}
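// Invert the output-size formula (see get_data_batch_shape) to recover the total
// padding, split evenly between begin and end. Worked example (cf. the
// group_conv_transpose_output_shape type-prop test): in = 5, k = 3, stride = 1,
// dilation = 1, output_padding = 0, requested out = 3:
//   total_padding = 1 * (5 - 1) + 1 * (3 - 1) - 3 + 0 + 1 = 4  -->  2 per side.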
for (size_t i = 0; i < num_spatial_dims; ++i)
{
int total_padding = m_strides[i] * (data_shape[i + 2] - 1) +
m_dilations[i] * (filters_shape[i + 2] - 1) - out_shape[i] +
m_output_padding[i] + 1;
m_padding_begin[i] = total_padding / 2;
}
m_padding_end = m_padding_begin;
}
}
}
shared_ptr<Node> op::GroupConvolutionTranspose::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::GroupConvolutionTranspose>(new_args.at(0),
new_args.at(1),
get_strides(),
get_dilations(),
get_padding_begin(),
get_padding_end(),
get_output_padding(),
get_groups(),
get_pad_type(),
get_output_shape());
}
Shape op::GroupConvolutionTranspose::get_data_batch_shape() const
{
const auto& data_shape = get_argument(0)->get_shape();
const auto& filters_shape = get_argument(1)->get_shape();
const size_t num_spatial_dims = data_shape.size() - 2;
Shape data_batch_shape(data_shape.size(), 1);
data_batch_shape.at(0) = data_shape.at(0);
data_batch_shape.at(1) = filters_shape.at(1);
if (m_output_shape.empty())
{
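// Spatial size of the transposed convolution output:
//   out = s * (in - 1) + d * (k - 1) - pb - pe + op + 1
// e.g. in = 3, k = 3, s = 2, d = 1, pb = pe = 1, op = 1 gives
//   2 * (3 - 1) + 1 * (3 - 1) - 1 - 1 + 1 + 1 = 6
// (cf. the group_conv_transpose backend test, which expects a 6 x 6 output).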
for (size_t i = 0; i < num_spatial_dims; ++i)
{
data_batch_shape[i + 2] = m_strides[i] * (data_shape[i + 2] - 1) +
m_dilations[i] * (filters_shape[i + 2] - 1) -
m_padding_begin[i] - m_padding_end[i] + m_output_padding[i] +
1;
}
}
else
{
Shape output_shape(m_output_shape);
if (output_shape.size() > num_spatial_dims)
{
output_shape.erase(std::begin(output_shape), std::begin(output_shape) + 2);
}
for (size_t i = 0; i < num_spatial_dims; ++i)
{
data_batch_shape[i + 2] = output_shape[i];
}
}
return data_batch_shape;
}
NodeVector op::GroupConvolutionTranspose::decompose_op() const
{
auto data = get_argument(0);
auto filters = get_argument(1);
const Shape data_batch_shape = get_data_batch_shape();
const size_t num_spatial_dims = data->get_shape().size() - 2;
if (m_groups > 1)
{
// Split one convolution op to N ops where N is the number of groups
// and concat results after computation.
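// E.g. data {1, 16, 6, 6} with filters {16, 2, 3, 3} and m_groups = 2 is sliced
// into two {1, 8, 6, 6} / {8, 2, 3, 3} pairs; each backprop-data convolution
// produces 2 output channels, which concatenate to 4 channels overall.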
const size_t n_data_channels{data->get_shape().at(1)};
const size_t n_filters_channels{filters->get_shape().at(0)};
const size_t data_group_size{n_data_channels / m_groups};
const size_t filters_group_size{n_filters_channels / m_groups};
NodeVector convolution_nodes;
// initial bounds for slice
vector<size_t> data_lower_bounds(data->get_shape().size());
vector<size_t> data_upper_bounds{data->get_shape()};
vector<size_t> filters_lower_bounds(filters->get_shape().size());
vector<size_t> filters_upper_bounds{filters->get_shape()};
for (size_t group{0}; group < m_groups; ++group)
{
// slice data
data_lower_bounds[1] = group * data_group_size;
data_upper_bounds[1] = (group + 1) * data_group_size;
auto sliced_data = make_shared<op::Slice>(data, data_lower_bounds, data_upper_bounds);
// slice filters
filters_lower_bounds[0] = group * filters_group_size;
filters_upper_bounds[0] = (group + 1) * filters_group_size;
auto sliced_filters =
make_shared<op::Slice>(filters, filters_lower_bounds, filters_upper_bounds);
convolution_nodes.push_back(
make_shared<op::ConvolutionBackpropData>(data_batch_shape,
sliced_filters,
sliced_data,
m_strides,
m_dilations,
m_padding_begin,
m_padding_end,
Strides(num_spatial_dims, 1)));
}
size_t concatenation_axis = 1;
return {make_shared<op::Concat>(convolution_nodes, concatenation_axis)};
}
else
{
return {make_shared<op::ConvolutionBackpropData>(data_batch_shape,
filters,
data,
m_strides,
m_dilations,
m_padding_begin,
m_padding_end,
Strides(num_spatial_dims, 1))};
}
}
void op::GroupConvolutionTranspose::generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas)
{
throw ngraph_error(
"Generating adjoints is not yet implemented for GroupConvolutionTranspose node.");
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstdlib>
#include <memory>
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
namespace ngraph
{
namespace op
{
/// \brief Group Transpose Convolution (Deconvolution)
class GroupConvolutionTranspose : public util::FusedOp
{
public:
///
/// \brief Constructs GroupConvolutionTranspose operation.
///
/// \param[in] data The node producing input data.
/// \param[in] filters The node producing filters data.
/// \param[in] strides The strides along each feature axis.
/// \param[in] dilations The dilations along each feature axis.
/// \param[in] padding_begin The padding added at the beginning of each feature axis.
/// \param[in] padding_end The padding added at the end of each feature axis.
/// \param[in] output_padding The zero-padding (adjustment) added to one side of the output.
/// \param[in] groups The number of groups the input channels and output channels
/// are divided into.
/// \param[in] pad_type The provided padding type.
/// \param[in] output_shape The output shape. When provided, padding values are
/// automatically inferred.
///
GroupConvolutionTranspose(const std::shared_ptr<Node>& data,
const std::shared_ptr<Node>& filters,
const Strides& strides,
const Strides& dilations,
const CoordinateDiff& padding_begin,
const CoordinateDiff& padding_end,
const CoordinateDiff& output_padding,
const std::size_t groups = 1UL,
const PadType& pad_type = PadType::EXPLICIT,
const Shape& output_shape = Shape{});
///
/// \brief Constructs GroupConvolutionTranspose operation.
///
/// \param[in] data The node producing input data.
/// \param[in] filters The node producing filters data.
/// \param[in] groups The number of groups the input channels and output channels
/// are divided into.
///
GroupConvolutionTranspose(const std::shared_ptr<Node>& data,
const std::shared_ptr<Node>& filters,
const std::size_t groups = 1UL);
///
/// \brief Constructs GroupConvolutionTranspose operation.
///
/// \param[in] data The node producing input data.
/// \param[in] filters The node producing filters data.
/// \param[in] strides The strides along each feature axis.
/// \param[in] dilations The dilations along each feature axis.
/// \param[in] output_padding The zero-padding (adjustment) added to one side of the output.
/// \param[in] output_shape The output shape. When provided, padding values are
/// automatically inferred.
/// \param[in] groups The number of groups the input channels and output channels
/// are divided into.
///
GroupConvolutionTranspose(const std::shared_ptr<Node>& data,
const std::shared_ptr<Node>& filters,
const Strides& strides,
const Strides& dilations,
const CoordinateDiff& output_padding,
const Shape& output_shape,
const std::size_t groups = 1UL);
///
/// \brief Constructs GroupConvolutionTranspose operation.
///
/// \param[in] data The node producing input data.
/// \param[in] filters The node producing filters data.
/// \param[in] output_shape The output shape. When provided, padding values are
/// automatically inferred.
/// \param[in] groups The number of groups the input channels and output channels
/// are divided into.
///
GroupConvolutionTranspose(const std::shared_ptr<Node>& data,
const std::shared_ptr<Node>& filters,
const Shape& output_shape,
const std::size_t groups = 1UL);
std::shared_ptr<Node> get_filters() { return get_argument(1); }
std::shared_ptr<Node> get_data() { return get_argument(0); }
const Strides& get_strides() const { return m_strides; }
const Strides& get_dilations() const { return m_dilations; }
const CoordinateDiff& get_padding_begin() const { return m_padding_begin; }
const CoordinateDiff& get_padding_end() const { return m_padding_end; }
const CoordinateDiff& get_output_padding() const { return m_output_padding; }
std::size_t get_groups() const { return m_groups; }
const PadType& get_pad_type() const { return m_pad_type; }
const Shape& get_output_shape() const { return m_output_shape; }
virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
///
/// \brief Calculate the shape of the data batch from forward propagation.
///
/// \return The data batch shape.
///
Shape get_data_batch_shape() const;
Strides m_strides;
Strides m_dilations;
CoordinateDiff m_padding_begin;
CoordinateDiff m_padding_end;
CoordinateDiff m_output_padding;
std::size_t m_groups;
PadType m_pad_type;
Shape m_output_shape;
};
}
}
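A minimal usage sketch for the new op (illustrative only, not part of this commit; shapes and the expected result are borrowed from the type-prop tests below):

// data: N x C x H x W; filters: C x (M / groups) x kH x kW
auto data = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{1, 16, 6, 6});
auto filters = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{16, 2, 3, 3});
// groups = 2 splits the channels into two independent transposed convolutions.
auto gct = std::make_shared<ngraph::op::GroupConvolutionTranspose>(data, filters, 2);
// gct->get_shape() evaluates to {1, 4, 8, 8} (cf. TEST(type_prop, group_conv_transpose)).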
@@ -27,6 +27,7 @@ NGRAPH_OP(FakeQuantize, ngraph::op)
NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(GroupConvolution, ngraph::op)
NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(LeakyRelu, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
@@ -86,6 +86,7 @@
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/leaky_relu.hpp"
#include "ngraph/op/fused/mvn.hpp"
@@ -2063,6 +2064,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::GatherND:
case OP_TYPEID::GenerateMask:
case OP_TYPEID::GRN:
case OP_TYPEID::GroupConvolutionTranspose:
case OP_TYPEID::HardSigmoid:
case OP_TYPEID::LeakyRelu:
case OP_TYPEID::MVN:
@@ -2183,6 +2185,7 @@ bool runtime::intelgpu::IntelGPUBackend::is_supported_impl(const Node& node)
case OP_TYPEID::FakeQuantize:
case OP_TYPEID::Gemm:
case OP_TYPEID::GRN:
case OP_TYPEID::GroupConvolutionTranspose:
case OP_TYPEID::LeakyRelu:
case OP_TYPEID::MVN:
case OP_TYPEID::Normalize:
@@ -74,6 +74,7 @@
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/leaky_relu.hpp"
#include "ngraph/op/fused/mvn.hpp"
@@ -1078,6 +1079,31 @@ static shared_ptr<ngraph::Function>
pad_type);
break;
}
case OP_TYPEID::GroupConvolutionTranspose:
{
auto strides = node_js.at("strides").get<vector<size_t>>();
auto dilations = node_js.at("dilations").get<vector<size_t>>();
auto padding_begin = node_js.at("padding_begin").get<vector<ptrdiff_t>>();
auto padding_end = node_js.at("padding_end").get<vector<ptrdiff_t>>();
auto output_padding = node_js.at("output_padding").get<vector<ptrdiff_t>>();
auto groups = node_js.at("groups").get<size_t>();
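// Tolerate a missing or empty "pad_type" field (e.g. graphs serialized before
// this attribute existed) by defaulting to EXPLICIT.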
op::PadType pad_type = node_js["pad_type"].empty()
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
node = make_shared<op::GroupConvolutionTranspose>(args[0],
args[1],
strides,
dilations,
padding_begin,
padding_end,
output_padding,
groups,
pad_type,
output_shape);
break;
}
case OP_TYPEID::LeakyRelu:
{
node = make_shared<op::LeakyRelu>(args[0], args[1]);
@@ -2091,6 +2117,19 @@ static json write(const Node& n, bool binary_constant_data)
node["pad_type"] = tmp->get_pad_type();
break;
}
case OP_TYPEID::GroupConvolutionTranspose:
{
auto tmp = dynamic_cast<const op::GroupConvolutionTranspose*>(&n);
node["strides"] = tmp->get_strides();
node["dilations"] = tmp->get_dilations();
node["padding_begin"] = tmp->get_padding_begin();
node["padding_end"] = tmp->get_padding_end();
node["output_padding"] = tmp->get_output_padding();
node["groups"] = tmp->get_groups();
node["pad_type"] = tmp->get_pad_type();
node["output_shape"] = tmp->get_output_shape();
break;
}
case OP_TYPEID::LeakyRelu: { break;
}
case OP_TYPEID::Less:
@@ -1156,3 +1156,94 @@ NGRAPH_TEST(${BACKEND_NAME}, fake_quantize_with_clip_across_channels)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, group_conv_transpose)
{
const CoordinateDiff output_padding{1, 1};
const CoordinateDiff padding_begin{1, 1};
const CoordinateDiff padding_end{1, 1};
Strides strides{2, 2};
Strides dilations{1, 1};
size_t groups = 1;
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 1, 3, 3});
auto filters = make_shared<op::Parameter>(element::f32, Shape{1, 1, 3, 3});
auto gct = make_shared<op::GroupConvolutionTranspose>(
data, filters, strides, dilations, padding_begin, padding_end, output_padding, groups);
auto function = make_shared<Function>(NodeVector{gct}, ParameterVector{data, filters});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
// X
test_case.add_input<float>(vector<float>{0.16857791f,
-0.15161794f,
0.08540368f,
0.1820628f,
-0.21746576f,
0.08245695f,
0.1431433f,
-0.43156421f,
0.30591947f});
// W
test_case.add_input<float>({-0.06230065f,
0.37932432f,
-0.25388849f,
0.33878803f,
0.43709868f,
-0.22477469f,
0.04118127f,
-0.44696793f,
0.06373066f});
test_case.add_expected_output(
Shape{1, 1, 6, 6},
vector<float>{
0.07368518f, -0.08925839f, -0.06627201f, 0.06301362f, 0.03732984f, -0.01919658f,
-0.00628807f, -0.02817563f, -0.01472169f, 0.04392925f, -0.00689478f, -0.01549204f,
0.07957941f, -0.11459791f, -0.09505399f, 0.07681622f, 0.03604182f, -0.01853423f,
-0.0270785f, -0.00680824f, -0.06650258f, 0.08004665f, 0.07918708f, -0.0724144f,
0.06256775f, -0.17838378f, -0.18863615f, 0.20064656f, 0.133717f, -0.06876295f,
-0.06398046f, -0.00864975f, 0.19289537f, -0.01490572f, -0.13673618f, 0.01949645f});
test_case.set_tolerance(3);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, group_conv_transpose_output_shape)
{
const CoordinateDiff output_padding{};
const Shape output_shape{1, 1, 1, 14};
Strides strides{1, 1};
Strides dilations{1, 1};
size_t groups = 1;
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 1, 1, 10});
auto filters = make_shared<op::Parameter>(element::f32, Shape{1, 1, 1, 5});
auto gct = make_shared<op::GroupConvolutionTranspose>(
data, filters, strides, dilations, output_padding, output_shape, groups);
auto function = make_shared<Function>(NodeVector{gct}, ParameterVector{data, filters});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
// X
test_case.add_input<float>(
vector<float>{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f});
// W
test_case.add_input<float>({1.0f, 2.0f, 3.0f, 2.0f, 1.0f});
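// With zero inferred padding this is a full 1-D transposed convolution:
// output length = 10 + 5 - 1 = 14.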
test_case.add_expected_output(Shape{1, 1, 1, 14},
vector<float>{0.0f,
1.0f,
4.0f,
10.0f,
18.0f,
27.0f,
36.0f,
45.0f,
54.0f,
63.0f,
62.0f,
50.0f,
26.0f,
9.0f});
test_case.run();
}
......@@ -19,6 +19,7 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/util/attr_types.hpp"
using namespace std;
using namespace ngraph;
@@ -14950,3 +14951,193 @@ TEST(type_prop, fake_quantize_invalid_rank)
"to number of channels."));
}
}
TEST(type_prop, group_conv_transpose)
{
// C x M / group x kH x kW
auto weights = make_shared<op::Parameter>(element::f32, Shape{16, 2, 3, 3});
// N x C x H x W
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 16, 6, 6});
auto gct = make_shared<op::GroupConvolutionTranspose>(data,
weights,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
2);
EXPECT_EQ(gct->get_element_type(), element::f32);
EXPECT_EQ(gct->get_shape(), (Shape{1, 4, 8, 8}));
EXPECT_EQ(gct->get_strides(), (Strides{1, 1}));
EXPECT_EQ(gct->get_dilations(), (Strides{1, 1}));
EXPECT_EQ(gct->get_padding_begin(), (CoordinateDiff{0, 0}));
EXPECT_EQ(gct->get_padding_end(), (CoordinateDiff{0, 0}));
EXPECT_EQ(gct->get_output_padding(), (CoordinateDiff{0, 0}));
EXPECT_EQ(gct->get_groups(), size_t(2));
EXPECT_EQ(gct->get_pad_type(), op::PadType::EXPLICIT);
}
TEST(type_prop, group_conv_transpose_output_shape)
{
// N x C x H x W
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 16, 5, 5});
// C x M / group x kH x kW
auto weights = make_shared<op::Parameter>(element::f32, Shape{16, 2, 3, 3});
auto gct = make_shared<op::GroupConvolutionTranspose>(
data, weights, Strides{1, 1}, Strides{1, 1}, CoordinateDiff{0, 0}, Shape{1, 2, 3, 3}, 1);
EXPECT_EQ(gct->get_element_type(), element::f32);
EXPECT_EQ(gct->get_shape(), (Shape{1, 2, 3, 3}));
EXPECT_EQ(gct->get_strides(), (Strides{1, 1}));
EXPECT_EQ(gct->get_dilations(), (Strides{1, 1}));
EXPECT_EQ(gct->get_padding_begin(), (CoordinateDiff{2, 2}));
EXPECT_EQ(gct->get_padding_end(), (CoordinateDiff{2, 2}));
EXPECT_EQ(gct->get_output_padding(), (CoordinateDiff{0, 0}));
EXPECT_EQ(gct->get_groups(), size_t(1));
EXPECT_EQ(gct->get_pad_type(), op::PadType::EXPLICIT);
}
TEST(type_prop, group_conv_transpose_invalid_params)
{
// C x M / group x kH x kW
auto weights = make_shared<op::Parameter>(element::f32, Shape{16, 20, 3, 3});
// N x C x H x W
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 16, 5, 5});
try
{
const auto gct = make_shared<op::GroupConvolutionTranspose>(data,
weights,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{2, 2},
CoordinateDiff{2, 2},
CoordinateDiff{0, 0},
21);
EXPECT_FALSE(gct.get()) << "GroupConvolutionTranspose validation did not work. "
"Node was created with incorrect params.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Incorrect value of groups:"));
}
try
{
const auto gct = make_shared<op::GroupConvolutionTranspose>(data,
weights,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{2, 2},
CoordinateDiff{2, 2},
CoordinateDiff{0, 0},
5);
EXPECT_FALSE(gct.get()) << "GroupConvolutionTranspose validation did not work. "
"Node was created with incorrect params.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Number of data channels must be a multiple of the number of groups."));
}
try
{
// C x M / group x kH x kW
auto bad_weights = make_shared<op::Parameter>(element::f32, Shape{10, 20, 3, 3});
const auto gct = make_shared<op::GroupConvolutionTranspose>(data,
bad_weights,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{2, 2},
CoordinateDiff{2, 2},
CoordinateDiff{0, 0},
8);
EXPECT_FALSE(gct.get()) << "GroupConvolutionTranspose validation did not work. "
"Node was created with incorrect params.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Number of filter channels must be equal to number of ") +
std::string("data channels"));
}
try
{
const auto gct = make_shared<op::GroupConvolutionTranspose>(data,
weights,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{2, 2},
CoordinateDiff{2, 2},
CoordinateDiff{0, 0},
4,
op::PadType::SAME_UPPER);
EXPECT_FALSE(gct.get()) << "GroupConvolutionTranspose validation did not work. "
"Node was created with incorrect params.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Currently only explicit pad type is supported."));
}
try
{
const auto gct = make_shared<op::GroupConvolutionTranspose>(data,
weights,
Strides{1},
Strides{1, 1},
CoordinateDiff{2, 2},
CoordinateDiff{2, 2},
CoordinateDiff{0, 0},
4);
EXPECT_FALSE(gct.get()) << "GroupConvolutionTranspose validation did not work. "
"Node was created with incorrect params.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Strides should match the number of spatial dimensions of the input data."));
}
try
{
const auto gct = make_shared<op::GroupConvolutionTranspose>(data,
weights,
Strides{1, 1},
Strides{1, 1, 2},
CoordinateDiff{2, 2},
CoordinateDiff{2, 2},
CoordinateDiff{0, 0},
4);
EXPECT_FALSE(gct.get()) << "GroupConvolutionTranspose validation did not work. "
"Node was created with incorrect params.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Dilations should match the number of spatial dimensions of the input data."));
}
try
{
const auto gct = make_shared<op::GroupConvolutionTranspose>(data,
weights,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{2, 2},
CoordinateDiff{2, 2},
CoordinateDiff{0, 0, 1, 1},
4);
EXPECT_FALSE(gct.get()) << "GroupConvolutionTranspose validation did not work. "
"Node was created with incorrect params.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Output padding should match the number of spatial dimensions of the input data."));
}
}