Commit 778b6004 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Move GroupConv and Slice op to output-handle based constructors (#3083)

parent 13c05a47
...@@ -27,8 +27,14 @@ ...@@ -27,8 +27,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::GroupConvolution::GroupConvolution(const shared_ptr<Node>& data_batch, const string op::GroupConvolution::type_name{"GroupConvolution"};
const shared_ptr<Node>& filters,
// Default constructor: creates a GroupConvolution with no inputs and
// value-initialized attributes. Inputs/attributes are expected to be
// supplied later (e.g. by a builder or deserializer) — TODO confirm caller.
op::GroupConvolution::GroupConvolution()
{
}
op::GroupConvolution::GroupConvolution(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
...@@ -36,7 +42,7 @@ op::GroupConvolution::GroupConvolution(const shared_ptr<Node>& data_batch, ...@@ -36,7 +42,7 @@ op::GroupConvolution::GroupConvolution(const shared_ptr<Node>& data_batch,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const size_t groups, const size_t groups,
const PadType& pad_type) const PadType& pad_type)
: FusedOp("GroupConvolution", check_single_output_args({data_batch, filters})) : FusedOp({data_batch, filters})
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides) , m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
...@@ -45,7 +51,6 @@ op::GroupConvolution::GroupConvolution(const shared_ptr<Node>& data_batch, ...@@ -45,7 +51,6 @@ op::GroupConvolution::GroupConvolution(const shared_ptr<Node>& data_batch,
, m_groups(groups) , m_groups(groups)
, m_pad_type(pad_type) , m_pad_type(pad_type)
{ {
// TODO: Move this out of constructor to validate_and_infer_types()
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -129,35 +134,35 @@ shared_ptr<Node> op::GroupConvolution::copy_with_new_args(const NodeVector& new_ ...@@ -129,35 +134,35 @@ shared_ptr<Node> op::GroupConvolution::copy_with_new_args(const NodeVector& new_
NodeVector op::GroupConvolution::decompose_op() const NodeVector op::GroupConvolution::decompose_op() const
{ {
auto data = get_argument(0); auto data = input(0);
auto filters = get_argument(1); auto filters = input(1);
// Split one convolution op to N ops where N is the number of groups // Split one convolution op to N ops where N is the number of groups
// and concat results after computation. // and concat results after computation.
// reference: https://github.com/NervanaSystems/ngraph-mxnet/blob/fdd692/src/ngraph/ngraph_emitter.cc#L822-L856 // reference: https://github.com/NervanaSystems/ngraph-mxnet/blob/fdd692/src/ngraph/ngraph_emitter.cc#L822-L856
std::size_t n_data_channels{data->get_shape().at(1)}; std::size_t n_data_channels{data.get_shape().at(1)};
std::size_t n_filters_channels{filters->get_shape().at(0)}; std::size_t n_filters_channels{filters.get_shape().at(0)};
std::size_t data_group_size{n_data_channels / m_groups}; std::size_t data_group_size{n_data_channels / m_groups};
std::size_t filters_group_size{n_filters_channels / m_groups}; std::size_t filters_group_size{n_filters_channels / m_groups};
NodeVector convolution_nodes; NodeVector convolution_nodes;
// initial bounds for splice // initial bounds for splice
std::vector<std::size_t> data_lower_bounds(data->get_shape().size()); std::vector<std::size_t> data_lower_bounds(data.get_shape().size());
std::vector<std::size_t> data_upper_bounds{data->get_shape()}; std::vector<std::size_t> data_upper_bounds{data.get_shape()};
std::vector<std::size_t> filters_lower_bounds(filters->get_shape().size()); std::vector<std::size_t> filters_lower_bounds(filters.get_shape().size());
std::vector<std::size_t> filters_upper_bounds{filters->get_shape()}; std::vector<std::size_t> filters_upper_bounds{filters.get_shape()};
for (std::size_t group{0}; group < m_groups; ++group) for (std::size_t group{0}; group < m_groups; ++group)
{ {
// slice data // slice data
data_lower_bounds[1] = group * data_group_size; data_lower_bounds[1] = group * data_group_size;
data_upper_bounds[1] = (group + 1) * data_group_size; data_upper_bounds[1] = (group + 1) * data_group_size;
auto sliced_data = auto sliced_data = std::make_shared<ngraph::op::Slice>(
std::make_shared<ngraph::op::Slice>(data, data_lower_bounds, data_upper_bounds); data.get_source_output(), data_lower_bounds, data_upper_bounds);
// slice filters // slice filters
filters_lower_bounds[0] = group * filters_group_size; filters_lower_bounds[0] = group * filters_group_size;
filters_upper_bounds[0] = (group + 1) * filters_group_size; filters_upper_bounds[0] = (group + 1) * filters_group_size;
auto sliced_filters = std::make_shared<ngraph::op::Slice>( auto sliced_filters = std::make_shared<ngraph::op::Slice>(
filters, filters_lower_bounds, filters_upper_bounds); filters.get_source_output(), filters_lower_bounds, filters_upper_bounds);
convolution_nodes.push_back( convolution_nodes.push_back(
std::make_shared<ngraph::op::Convolution>(sliced_data, std::make_shared<ngraph::op::Convolution>(sliced_data,
......
...@@ -29,8 +29,12 @@ namespace ngraph ...@@ -29,8 +29,12 @@ namespace ngraph
class GroupConvolution : public ngraph::op::util::FusedOp class GroupConvolution : public ngraph::op::util::FusedOp
{ {
public: public:
GroupConvolution(const std::shared_ptr<Node>& data_batch, NGRAPH_API
const std::shared_ptr<Node>& filters, static const std::string type_name;
const std::string& description() const override { return type_name; }
GroupConvolution();
GroupConvolution(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
......
...@@ -19,11 +19,17 @@ ...@@ -19,11 +19,17 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Slice::Slice(const shared_ptr<Node>& arg, const string op::Slice::type_name{"Slice"};
// Default constructor: creates a Slice with no inputs and empty
// bounds/strides. Inputs/attributes are expected to be supplied later
// (e.g. by a builder or deserializer) — TODO confirm caller.
op::Slice::Slice()
{
}
op::Slice::Slice(const Output<Node>& arg,
const Coordinate& lower_bounds, const Coordinate& lower_bounds,
const Coordinate& upper_bounds, const Coordinate& upper_bounds,
const Strides& strides) const Strides& strides)
: Op("Slice", check_single_output_args({arg})) : Op({arg})
, m_lower_bounds(lower_bounds) , m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds) , m_upper_bounds(upper_bounds)
, m_strides(strides) , m_strides(strides)
...@@ -31,10 +37,10 @@ op::Slice::Slice(const shared_ptr<Node>& arg, ...@@ -31,10 +37,10 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::Slice::Slice(const shared_ptr<Node>& arg, op::Slice::Slice(const Output<Node>& arg,
const Coordinate& lower_bounds, const Coordinate& lower_bounds,
const Coordinate& upper_bounds) const Coordinate& upper_bounds)
: Op("Slice", check_single_output_args({arg})) : Op({arg})
, m_lower_bounds(lower_bounds) , m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds) , m_upper_bounds(upper_bounds)
, m_strides(Strides()) , m_strides(Strides())
......
...@@ -28,6 +28,11 @@ namespace ngraph ...@@ -28,6 +28,11 @@ namespace ngraph
class Slice : public Op class Slice : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a tensor slice operation
Slice();
/// \brief Constructs a tensor slice operation. /// \brief Constructs a tensor slice operation.
/// ///
/// \param arg The tensor to be sliced. /// \param arg The tensor to be sliced.
...@@ -35,17 +40,16 @@ namespace ngraph ...@@ -35,17 +40,16 @@ namespace ngraph
/// \param upper_bounds The axiswise upper bounds of the slice (exclusive). /// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
/// \param strides The slicing strides; for example, strides of `{n,m}` means to take /// \param strides The slicing strides; for example, strides of `{n,m}` means to take
/// every nth row and every mth column of the input matrix. /// every nth row and every mth column of the input matrix.
Slice(const std::shared_ptr<Node>& arg, Slice(const Output<Node>& arg,
const Coordinate& lower_bounds, const Coordinate& lower_bounds,
const Coordinate& upper_bounds, const Coordinate& upper_bounds,
const Strides& strides); const Strides& strides);
/// \brief Constructs a tensor slice operation with unit strides; i.e., every element inside the bounding box will be copied to the output slice. /// \brief Constructs a tensor slice operation with unit strides; i.e., every element inside the bounding box will be copied to the output slice.
/// ///
/// \param arg The tensor to be sliced. /// \param arg The tensor to be sliced.
/// \param lower_bounds The axiswise lower bounds of the slice (inclusive). /// \param lower_bounds The axiswise lower bounds of the slice (inclusive).
/// \param upper_bounds The axiswise upper bounds of the slice (exclusive). /// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
Slice(const std::shared_ptr<Node>& arg, Slice(const Output<Node>& arg,
const Coordinate& lower_bounds, const Coordinate& lower_bounds,
const Coordinate& upper_bounds); const Coordinate& upper_bounds);
......
...@@ -30,6 +30,11 @@ op::util::FusedOp::FusedOp(const NodeVector& args) ...@@ -30,6 +30,11 @@ op::util::FusedOp::FusedOp(const NodeVector& args)
{ {
} }
// Constructs a FusedOp whose inputs are given as output handles
// (Output<Node>) rather than whole nodes; simply forwards the handles to
// the base Op constructor. This mirrors the existing NodeVector overload.
op::util::FusedOp::FusedOp(const OutputVector& args)
: Op(args)
{
}
op::util::FusedOp::FusedOp(const std::string& node_type, const NodeVector& args) op::util::FusedOp::FusedOp(const std::string& node_type, const NodeVector& args)
: Op(node_type, args) : Op(node_type, args)
{ {
......
...@@ -51,6 +51,8 @@ namespace ngraph ...@@ -51,6 +51,8 @@ namespace ngraph
/// \param args Nodes that produce the input tensors for the fused op /// \param args Nodes that produce the input tensors for the fused op
FusedOp(const NodeVector& args); FusedOp(const NodeVector& args);
FusedOp(const OutputVector& args);
/// \brief Constructs a FusedOp /// \brief Constructs a FusedOp
/// ///
/// \param args Nodes that produce the input tensors for the fused op /// \param args Nodes that produce the input tensors for the fused op
......
...@@ -150,3 +150,19 @@ TEST(build_graph, no_arg_construction) ...@@ -150,3 +150,19 @@ TEST(build_graph, no_arg_construction)
validate_nodes_and_infer_types(ops); validate_nodes_and_infer_types(ops);
ASSERT_EQ(add1->get_output_shape(0), Shape{7}); ASSERT_EQ(add1->get_output_shape(0), Shape{7});
} }
// Regression test for the Output<Node>-based constructors: a
// GroupConvolution must accept a specific output of a multi-output node
// (here, output 1 of a two-way op::Split along axis 1) as its data input,
// and shape inference must still produce the expected output shape.
TEST(build_graph, multi_output_split)
{
// Data tensor: N=64, C=8, H=100, W=150.
const auto data = make_shared<op::Parameter>(element::f32, Shape{64, 8, 100, 150});
// Filters: 128 output channels, 2 input channels per group, 10x20 kernel.
auto filters = make_shared<op::Parameter>(element::f32, Shape{128, 2, 10, 20});
// Split the channel axis (axis 1) into 2 pieces; feed the second piece
// to the convolution via its Output<Node> handle.
const auto split = make_shared<op::Split>(data, 1, 2);
auto conv = make_shared<op::GroupConvolution>(split->output(1),
filters,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1},
2); // groups = 2
// Expected spatial dims with 10x20 kernel, unit stride, no padding:
// 100-10+1=91, 150-20+1=131.
EXPECT_EQ(conv->get_shape(), (Shape{64, 128, 91, 131}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment