Commit 778b6004 authored by Jayaram Bobba, committed by Scott Cyphers

Move GroupConv and Slice op to output-handle based constructors (#3083)

parent 13c05a47
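In practical terms, the commit switches these constructors from shared_ptr<Node> arguments to Output<Node> handles, so a specific output of a multi-output node can feed an op directly. A minimal call-site sketch under that assumption (shapes and variable names below are illustrative, not taken from the commit):

#include "ngraph/ngraph.hpp"

using namespace ngraph;

void output_handle_example()
{
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 4, 8, 8});

    // A shared_ptr<Node> still converts to an Output<Node> referring to the
    // node's first output, so existing single-output call sites keep working.
    auto whole = std::make_shared<op::Slice>(
        data, Coordinate{0, 0, 0, 0}, Coordinate{2, 4, 8, 8});

    // With the output-handle constructors, a specific output of a
    // multi-output node (here the second half of a two-way split along
    // axis 1) can be passed without an intermediate wrapper node.
    auto split = std::make_shared<op::Split>(data, 1, 2);
    auto second_half = std::make_shared<op::Slice>(
        split->output(1), Coordinate{0, 0, 0, 0}, Coordinate{2, 2, 8, 8});
}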
@@ -27,8 +27,14 @@
 using namespace std;
 using namespace ngraph;
 
-op::GroupConvolution::GroupConvolution(const shared_ptr<Node>& data_batch,
-                                       const shared_ptr<Node>& filters,
+const string op::GroupConvolution::type_name{"GroupConvolution"};
+
+op::GroupConvolution::GroupConvolution()
+{
+}
+
+op::GroupConvolution::GroupConvolution(const Output<Node>& data_batch,
+                                       const Output<Node>& filters,
                                        const Strides& window_movement_strides,
                                        const Strides& window_dilation_strides,
                                        const CoordinateDiff& padding_below,
@@ -36,7 +42,7 @@ op::GroupConvolution::GroupConvolution(const shared_ptr<Node>& data_batch,
                                        const Strides& data_dilation_strides,
                                        const size_t groups,
                                        const PadType& pad_type)
-    : FusedOp("GroupConvolution", check_single_output_args({data_batch, filters}))
+    : FusedOp({data_batch, filters})
     , m_window_movement_strides(window_movement_strides)
     , m_window_dilation_strides(window_dilation_strides)
     , m_padding_below(padding_below)
@@ -45,7 +51,6 @@ op::GroupConvolution::GroupConvolution(const shared_ptr<Node>& data_batch,
     , m_groups(groups)
     , m_pad_type(pad_type)
 {
-    // TODO: Move this out of constructor to validate_and_infer_types()
     constructor_validate_and_infer_types();
 }
@@ -129,35 +134,35 @@ shared_ptr<Node> op::GroupConvolution::copy_with_new_args(const NodeVector& new_
 NodeVector op::GroupConvolution::decompose_op() const
 {
-    auto data = get_argument(0);
-    auto filters = get_argument(1);
+    auto data = input(0);
+    auto filters = input(1);
     // Split one convolution op to N ops where N is the number of groups
     // and concat results after computation.
     // reference: https://github.com/NervanaSystems/ngraph-mxnet/blob/fdd692/src/ngraph/ngraph_emitter.cc#L822-L856
-    std::size_t n_data_channels{data->get_shape().at(1)};
-    std::size_t n_filters_channels{filters->get_shape().at(0)};
+    std::size_t n_data_channels{data.get_shape().at(1)};
+    std::size_t n_filters_channels{filters.get_shape().at(0)};
     std::size_t data_group_size{n_data_channels / m_groups};
     std::size_t filters_group_size{n_filters_channels / m_groups};
     NodeVector convolution_nodes;
 
     // initial bounds for splice
-    std::vector<std::size_t> data_lower_bounds(data->get_shape().size());
-    std::vector<std::size_t> data_upper_bounds{data->get_shape()};
-    std::vector<std::size_t> filters_lower_bounds(filters->get_shape().size());
-    std::vector<std::size_t> filters_upper_bounds{filters->get_shape()};
+    std::vector<std::size_t> data_lower_bounds(data.get_shape().size());
+    std::vector<std::size_t> data_upper_bounds{data.get_shape()};
+    std::vector<std::size_t> filters_lower_bounds(filters.get_shape().size());
+    std::vector<std::size_t> filters_upper_bounds{filters.get_shape()};
     for (std::size_t group{0}; group < m_groups; ++group)
     {
         // slice data
         data_lower_bounds[1] = group * data_group_size;
         data_upper_bounds[1] = (group + 1) * data_group_size;
-        auto sliced_data =
-            std::make_shared<ngraph::op::Slice>(data, data_lower_bounds, data_upper_bounds);
+        auto sliced_data = std::make_shared<ngraph::op::Slice>(
+            data.get_source_output(), data_lower_bounds, data_upper_bounds);
 
         // slice filters
         filters_lower_bounds[0] = group * filters_group_size;
         filters_upper_bounds[0] = (group + 1) * filters_group_size;
         auto sliced_filters = std::make_shared<ngraph::op::Slice>(
-            filters, filters_lower_bounds, filters_upper_bounds);
+            filters.get_source_output(), filters_lower_bounds, filters_upper_bounds);
 
         convolution_nodes.push_back(
             std::make_shared<ngraph::op::Convolution>(sliced_data,
......
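For context, decompose_op slices the data tensor along its channel axis and the filters along their output-channel axis, runs a plain Convolution per group, and, per the comment in the code, concatenates the per-group results. A condensed sketch of the same idea, with the concluding Concat assumed because the hunk is truncated before it (names are illustrative, not the library code):

#include "ngraph/ngraph.hpp"

using namespace ngraph;

std::shared_ptr<Node> grouped_convolution(const Output<Node>& data,
                                          const Output<Node>& filters,
                                          size_t groups,
                                          const Strides& strides)
{
    // Channels per group on the data (axis 1) and filter (axis 0) sides.
    const size_t data_group_size = data.get_shape().at(1) / groups;
    const size_t filters_group_size = filters.get_shape().at(0) / groups;

    NodeVector convolution_nodes;
    std::vector<size_t> data_lo(data.get_shape().size()), data_hi{data.get_shape()};
    std::vector<size_t> filters_lo(filters.get_shape().size()), filters_hi{filters.get_shape()};
    for (size_t group = 0; group < groups; ++group)
    {
        // Slice out this group's channels from data and filters.
        data_lo[1] = group * data_group_size;
        data_hi[1] = (group + 1) * data_group_size;
        filters_lo[0] = group * filters_group_size;
        filters_hi[0] = (group + 1) * filters_group_size;
        auto sliced_data = std::make_shared<op::Slice>(data, data_lo, data_hi);
        auto sliced_filters = std::make_shared<op::Slice>(filters, filters_lo, filters_hi);
        convolution_nodes.push_back(
            std::make_shared<op::Convolution>(sliced_data, sliced_filters, strides));
    }
    // Stitch the per-group results back together along the channel axis.
    return std::make_shared<op::Concat>(convolution_nodes, 1);
}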
@@ -29,8 +29,12 @@
         class GroupConvolution : public ngraph::op::util::FusedOp
         {
         public:
-            GroupConvolution(const std::shared_ptr<Node>& data_batch,
-                             const std::shared_ptr<Node>& filters,
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            GroupConvolution();
+            GroupConvolution(const Output<Node>& data_batch,
+                             const Output<Node>& filters,
                              const Strides& window_movement_strides,
                              const Strides& window_dilation_strides,
                              const CoordinateDiff& padding_below,
......
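The header above now pairs a static type_name with a description() override instead of passing the name string through the constructor. A minimal sketch of the same declaration pattern for a hypothetical op (MyOp is not part of this commit; only the pattern is being illustrated):

#include <string>
#include "ngraph/op/op.hpp"

// Header side: the op advertises its type via a static member.
class MyOp : public ngraph::op::Op
{
public:
    NGRAPH_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    MyOp() = default;
    // ... value constructors, copy_with_new_args, etc. as for the ops above ...
};

// Source side: the single out-of-line definition of the name.
const std::string MyOp::type_name{"MyOp"};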
@@ -19,11 +19,17 @@
 using namespace std;
 using namespace ngraph;
 
-op::Slice::Slice(const shared_ptr<Node>& arg,
+const string op::Slice::type_name{"Slice"};
+
+op::Slice::Slice()
+{
+}
+
+op::Slice::Slice(const Output<Node>& arg,
                  const Coordinate& lower_bounds,
                  const Coordinate& upper_bounds,
                  const Strides& strides)
-    : Op("Slice", check_single_output_args({arg}))
+    : Op({arg})
     , m_lower_bounds(lower_bounds)
     , m_upper_bounds(upper_bounds)
     , m_strides(strides)
@@ -31,10 +37,10 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
     constructor_validate_and_infer_types();
 }
 
-op::Slice::Slice(const shared_ptr<Node>& arg,
+op::Slice::Slice(const Output<Node>& arg,
                  const Coordinate& lower_bounds,
                  const Coordinate& upper_bounds)
-    : Op("Slice", check_single_output_args({arg}))
+    : Op({arg})
     , m_lower_bounds(lower_bounds)
     , m_upper_bounds(upper_bounds)
     , m_strides(Strides())
......
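Slice usage is unchanged apart from the argument type; a small sketch of building a strided slice with the new signature (shapes and names are illustrative only):

#include "ngraph/ngraph.hpp"

using namespace ngraph;

std::shared_ptr<Function> make_slice_graph()
{
    auto arg = std::make_shared<op::Parameter>(element::f32, Shape{4, 6});

    // Rows 1..2 (upper bound is exclusive) and every second column.
    auto sliced = std::make_shared<op::Slice>(
        arg->output(0), Coordinate{1, 0}, Coordinate{3, 6}, Strides{1, 2});

    // Resulting shape: ceil((3 - 1) / 1) = 2 rows, ceil((6 - 0) / 2) = 3 columns.
    return std::make_shared<Function>(NodeVector{sliced}, ParameterVector{arg});
}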
@@ -28,6 +28,11 @@
         class Slice : public Op
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a tensor slice operation
+            Slice();
             /// \brief Constructs a tensor slice operation.
             ///
             /// \param arg The tensor to be sliced.
@@ -35,17 +40,16 @@
             /// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
             /// \param strides The slicing strides; for example, strides of `{n,m}` means to take
             /// every nth row and every mth column of the input matrix.
-            Slice(const std::shared_ptr<Node>& arg,
+            Slice(const Output<Node>& arg,
                   const Coordinate& lower_bounds,
                   const Coordinate& upper_bounds,
                   const Strides& strides);
             /// \brief Constructs a tensor slice operation with unit strides; i.e., every element inside the bounding box will be copied to the output slice.
             ///
             /// \param arg The tensor to be sliced.
             /// \param lower_bounds The axiswise lower bounds of the slice (inclusive).
             /// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
-            Slice(const std::shared_ptr<Node>& arg,
+            Slice(const Output<Node>& arg,
                   const Coordinate& lower_bounds,
                   const Coordinate& upper_bounds);
......
@@ -30,6 +30,11 @@ op::util::FusedOp::FusedOp(const NodeVector& args)
 {
 }
 
+op::util::FusedOp::FusedOp(const OutputVector& args)
+    : Op(args)
+{
+}
+
 op::util::FusedOp::FusedOp(const std::string& node_type, const NodeVector& args)
     : Op(node_type, args)
 {
......
@@ -51,6 +51,8 @@
                 /// \param args Nodes that produce the input tensors for the fused op
                 FusedOp(const NodeVector& args);
 
+                FusedOp(const OutputVector& args);
+
                 /// \brief Constructs a FusedOp
                 ///
                 /// \param args Nodes that produce the input tensors for the fused op
......
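With the new FusedOp(const OutputVector&) constructor, a fused op forwards output handles directly and identifies itself via description() rather than a type string passed to the base class. A hypothetical subclass sketch under those assumptions (PairwiseAdd is not in the commit):

#include "ngraph/ngraph.hpp"

// Hypothetical fused op: adds its two inputs, decomposing to a core Add.
class PairwiseAdd : public ngraph::op::util::FusedOp
{
public:
    NGRAPH_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }

    PairwiseAdd(const ngraph::Output<ngraph::Node>& a, const ngraph::Output<ngraph::Node>& b)
        : FusedOp({a, b})
    {
        constructor_validate_and_infer_types();
    }

    // Fused ops describe themselves in terms of core ops for backends
    // that do not implement them directly.
    ngraph::NodeVector decompose_op() const override
    {
        return {std::make_shared<ngraph::op::Add>(get_argument(0), get_argument(1))};
    }

    std::shared_ptr<ngraph::Node>
        copy_with_new_args(const ngraph::NodeVector& new_args) const override
    {
        return std::make_shared<PairwiseAdd>(new_args.at(0), new_args.at(1));
    }
};

const std::string PairwiseAdd::type_name{"PairwiseAdd"};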
@@ -150,3 +150,19 @@ TEST(build_graph, no_arg_construction)
     validate_nodes_and_infer_types(ops);
     ASSERT_EQ(add1->get_output_shape(0), Shape{7});
 }
+
+TEST(build_graph, multi_output_split)
+{
+    const auto data = make_shared<op::Parameter>(element::f32, Shape{64, 8, 100, 150});
+    auto filters = make_shared<op::Parameter>(element::f32, Shape{128, 2, 10, 20});
+    const auto split = make_shared<op::Split>(data, 1, 2);
+    auto conv = make_shared<op::GroupConvolution>(split->output(1),
+                                                  filters,
+                                                  Strides{1, 1},
+                                                  Strides{1, 1},
+                                                  CoordinateDiff{0, 0},
+                                                  CoordinateDiff{0, 0},
+                                                  Strides{1, 1},
+                                                  2);
+    EXPECT_EQ(conv->get_shape(), (Shape{64, 128, 91, 131}));
+}
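For reference, the expected shape in the new test follows from the decomposition above: split->output(1) has shape {64, 4, 100, 150}; with groups = 2 each group convolves a {64, 2, 100, 150} data slice against a {64, 2, 10, 20} filter slice, giving spatial dimensions 100 - 10 + 1 = 91 and 150 - 20 + 1 = 131; concatenating the two 64-channel results along the channel axis yields {64, 128, 91, 131}.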