Commit 371b47fb authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Scott Cyphers

[SPEC] Adjust Split (#3943)

* Changed axis to Node

* Added using normalize from validation util

* refactored split

* Added typrop tests to Split

* Added set_input_is_relevant_to_shape for Split

* clang style applied

* Fixed var name

* Code refactor

* mergre from master. part.2

* Constructor to provide CI compatibility

* CI compatibility

* CI compatibility

* Updated get_outputs

* CI compitability

* Fixed get_outputs function
parent 8925436a
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/split.hpp" #include "ngraph/op/fused/split.hpp"
#include "op/split.hpp" #include "op/split.hpp"
#include "utils/common.hpp" #include "utils/common.hpp"
...@@ -32,30 +33,25 @@ namespace ngraph ...@@ -32,30 +33,25 @@ namespace ngraph
NodeVector split(const Node& node) NodeVector split(const Node& node)
{ {
const auto input = node.get_ng_inputs().at(0); const auto input = node.get_ng_inputs().at(0);
const auto outputs_number = node.get_output_names().size();
const auto axis = node.get_attribute_value<int64_t>("axis", 0); const auto axis = node.get_attribute_value<int64_t>("axis", 0);
std::size_t valid_axis = const auto axis_node =
common::validate_axis(node, axis, input->get_shape().size()); ngraph::op::Constant::create(element::i64, Shape{}, {axis});
try std::shared_ptr<ngraph::Node> fused_split;
if (node.has_attribute("split"))
{ {
const auto length_parts = const auto length_parts =
node.get_attribute_value<std::vector<std::size_t>>("split"); node.get_attribute_value<std::vector<std::size_t>>("split");
const auto fused_split = fused_split =
std::make_shared<ngraph::op::Split>(input, valid_axis, length_parts); std::make_shared<ngraph::op::Split>(input, axis_node, length_parts);
return fused_split->decompose_op();
} }
catch (const error::node::UnknownAttribute&) else
{ {
// an exception will be caught if the input node does not contain const auto outputs_number = node.get_output_names().size();
// the 'split' attribute - this means we should split the input tensor fused_split =
// into same-length parts equal to the number of node outputs std::make_shared<ngraph::op::Split>(input, axis_node, outputs_number);
const auto fused_split =
std::make_shared<ngraph::op::Split>(input, valid_axis, outputs_number);
return fused_split->decompose_op();
} }
return common::get_outputs(fused_split);
} }
} // namespace set_1 } // namespace set_1
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <onnx/onnx_pb.h> // onnx types #include <onnx/onnx_pb.h> // onnx types
#include "common.hpp" #include "common.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "validation_util.hpp" #include "validation_util.hpp"
namespace ngraph namespace ngraph
...@@ -79,6 +80,24 @@ namespace ngraph ...@@ -79,6 +80,24 @@ namespace ngraph
return new_axes; return new_axes;
} }
ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node)
{
const auto outputs_number = node->get_output_size();
ngraph::NodeVector outputs(outputs_number);
for (int i = 0; i < outputs_number; ++i)
{
if (node->output(i).get_node_shared_ptr()->get_output_size() == 1)
{
outputs[i] = node->get_output_as_single_output_node(i);
}
else
{
outputs[i] = std::make_shared<ngraph::op::GetOutputElement>(node, i);
}
}
return outputs;
}
} // namespace common } // namespace common
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -115,6 +115,13 @@ namespace ngraph ...@@ -115,6 +115,13 @@ namespace ngraph
std::vector<std::int64_t> axes, std::vector<std::int64_t> axes,
std::int64_t tensor_rank); std::int64_t tensor_rank);
/// \brief Return the outputs of the node as vector.
///
/// \param[in] node Node with multiple outputs.
///
/// \return Vector of outputs of input node.
ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node);
/// \brief Creates a shifted square identity matrix. /// \brief Creates a shifted square identity matrix.
/// \note Shifting in the context of this operator means that /// \note Shifting in the context of this operator means that
/// the matrix can be created with elements equal to 1 not only in the main /// the matrix can be created with elements equal to 1 not only in the main
......
...@@ -16,23 +16,36 @@ ...@@ -16,23 +16,36 @@
#include <numeric> #include <numeric>
#include "ngraph/builder/split.hpp" #include "ngraph/builder/split.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/split.hpp" #include "ngraph/op/fused/split.hpp"
#include "ngraph/validation_util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::Split::type_info; constexpr NodeTypeInfo op::Split::type_info;
op::Split::Split(const Output<Node>& data, const int axis, const size_t num_split) op::Split::Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_split)
: FusedOp({data}) : FusedOp({data, axis})
, m_split_evenly{true} , m_split_evenly{true}
, m_axis{axis}
, m_num_split{num_split} , m_num_split{num_split}
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::Split::Split(const Output<Node>& data, const int axis, const std::vector<size_t>& splits) op::Split::Split(const Output<Node>& data,
const Output<Node>& axis,
const std::vector<size_t>& splits)
: FusedOp({data, axis})
, m_split_evenly{false}
, m_num_split{0}
, m_splits{splits}
{
constructor_validate_and_infer_types();
}
// TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY
op::Split::Split(const Output<Node>& data, int axis, const std::vector<size_t>& splits)
: FusedOp({data}) : FusedOp({data})
, m_split_evenly{false} , m_split_evenly{false}
, m_axis{axis} , m_axis{axis}
...@@ -42,8 +55,30 @@ op::Split::Split(const Output<Node>& data, const int axis, const std::vector<siz ...@@ -42,8 +55,30 @@ op::Split::Split(const Output<Node>& data, const int axis, const std::vector<siz
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
// TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY
op::Split::Split(const Output<Node>& data, int axis, size_t num_split)
: FusedOp({data})
, m_split_evenly{true}
, m_axis{axis}
, m_num_split{num_split}
{
constructor_validate_and_infer_types();
}
void op::Split::pre_validate_and_infer_types() void op::Split::pre_validate_and_infer_types()
{ {
// TODO REMOVE IF CHECK. INTRODUCED TO PROVIDE CI COMPATIBILITY
if (get_input_size() == 2)
{
const auto axis_shape = input(1).get_shape();
NODE_VALIDATION_CHECK(this, is_scalar(axis_shape), "The 'axis' input node must be scalar");
const auto axis_node = input_value(1).get_node_shared_ptr();
NODE_VALIDATION_CHECK(
this, axis_node->is_constant(), "The 'axis' input node must be constant");
const auto axis_node_const = as_type_ptr<op::Constant>(axis_node);
m_axis = axis_node_const->get_vector<int64_t>()[0];
}
// Create dynamic-typed outputs. Actual shape/type will be computed during shape inference // Create dynamic-typed outputs. Actual shape/type will be computed during shape inference
for (size_t i = 0; i < std::max(m_splits.size(), m_num_split); i++) for (size_t i = 0; i < std::max(m_splits.size(), m_num_split); i++)
{ {
...@@ -57,11 +92,7 @@ void op::Split::pre_validate_and_infer_types() ...@@ -57,11 +92,7 @@ void op::Split::pre_validate_and_infer_types()
const auto shape = input(0).get_shape(); const auto shape = input(0).get_shape();
m_axis = adjust_axis_value(m_axis, shape.size()); m_axis = ngraph::normalize_axis(this, m_axis, shape.size());
NODE_VALIDATION_CHECK(this,
m_axis >= 0 && m_axis < static_cast<int64_t>(shape.size()),
"The 'axis' parameter for Split has to point to one of the "
"input tensor's shape dimensions.");
const auto dimension_at_axis = shape.at(m_axis); const auto dimension_at_axis = shape.at(m_axis);
if (m_split_evenly) if (m_split_evenly)
...@@ -92,6 +123,7 @@ void op::Split::pre_validate_and_infer_types() ...@@ -92,6 +123,7 @@ void op::Split::pre_validate_and_infer_types()
all_splits_positive == true, all_splits_positive == true,
"All values of the 'splits' attribute must be greater than zero"); "All values of the 'splits' attribute must be greater than zero");
} }
set_input_is_relevant_to_shape(0);
} }
NodeVector op::Split::decompose_op() const NodeVector op::Split::decompose_op() const
...@@ -101,18 +133,12 @@ NodeVector op::Split::decompose_op() const ...@@ -101,18 +133,12 @@ NodeVector op::Split::decompose_op() const
shared_ptr<Node> op::Split::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Split::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); if (new_args.size() == 2)
return make_shared<Split>(new_args.at(0), m_axis, m_splits);
}
size_t op::Split::adjust_axis_value(const int axis, const size_t input_tensor_rank) const
{
if (axis < 0)
{ {
return axis + input_tensor_rank; check_new_args_count(this, new_args);
} return make_shared<Split>(new_args.at(0), new_args.at(1), m_splits);
else
{
return axis;
} }
// TODO REMOVE THIS RETURN AND IF ABOVE. INTRODUCED TO PROVIDE CI COMPATIBILITY
return make_shared<Split>(new_args.at(0), m_axis, m_splits);
} }
...@@ -37,24 +37,32 @@ namespace ngraph ...@@ -37,24 +37,32 @@ namespace ngraph
Split() = default; Split() = default;
/// \brief Constructs a Split op that evenly divides the input tensor. /// \brief Constructs a Split op that evenly divides the input tensor.
/// ///
/// \param data - Node producing the input tensor /// \param data Node producing the input tensor
/// \param axis - indicates an axis along which the input tensor should be split. /// \param axis Node producing an axis along which the input tensor
/// Negative values mean counting from the back of the input tensor's /// should be split. Negative values mean counting from
/// shape. /// the back of the input tensor's shape.
/// \param num_split - a number of "pieces" the input tensor will be split to /// \param num_split a number of "pieces" the input tensor will be split to
Split(const Output<Node>& data, const int axis, const size_t num_split); Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_split);
/// \brief Constructs a Split op that splits the input tensor into variable length /// \brief Constructs a Split op that splits the input tensor into variable length
/// "pieces" /// "pieces"
/// ///
/// \param data - Node producing the input tensor /// \param data Node producing the input tensor
/// \param axis - indicates an axis along which the input tensor should be split. /// \param axis Node producing an axis along which the input tensor
/// Negative values mean counting from the back of the input tensor's /// should be split. Negative values mean counting from
/// shape. /// the back of the input tensor's shape.
/// \param splits - a list of lengths that the input tensor should be split to. Use /// \param splits a list of lengths that the input tensor should be
/// this /// split to. Use this constructor to split the input
/// constructor to split the input tensor to variable length chunks. /// tensor to variable length chunks.
Split(const Output<Node>& data, const int axis, const std::vector<size_t>& splits); Split(const Output<Node>& data,
const Output<Node>& axis,
const std::vector<size_t>& splits);
// TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY
Split(const Output<Node>& data, int axis, const std::vector<size_t>& splits);
// TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY
Split(const Output<Node>& data, int axis, const size_t num_split);
void pre_validate_and_infer_types() override; void pre_validate_and_infer_types() override;
...@@ -66,21 +74,9 @@ namespace ngraph ...@@ -66,21 +74,9 @@ namespace ngraph
size_t get_axis() const { return m_axis; } size_t get_axis() const { return m_axis; }
const std::vector<size_t>& get_splits() const { return m_splits; } const std::vector<size_t>& get_splits() const { return m_splits; }
private: private:
/// \brief Adjusts the axis for negative values
///
/// \note Negative values mean that the API consumer wants to point the axis
/// location
/// from the back of the tensor. This is similar to the way NumPy works.
///
/// \param axis - original axis value; negative values are accepted
/// \param input_tensor_rank - rank of the input data tensor
/// \return Returns a sum of parameters for negative axis value, or axis itself
/// otherwise
size_t adjust_axis_value(const int axis, const size_t input_tensor_rank) const;
/// used internally for validation purposes, indicates which constructor was used /// used internally for validation purposes, indicates which constructor was used
bool m_split_evenly; bool m_split_evenly;
int m_axis; int64_t m_axis;
size_t m_num_split; size_t m_num_split;
/// contains lengths of chunks that the input tensor will be split into /// contains lengths of chunks that the input tensor will be split into
std::vector<size_t> m_splits; std::vector<size_t> m_splits;
......
...@@ -2626,9 +2626,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2626,9 +2626,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
} }
case OP_TYPEID::Split: case OP_TYPEID::Split:
{ {
const auto axis = node_js.at("axis").get<size_t>();
const auto splits = node_js.at("splits").get<vector<size_t>>(); const auto splits = node_js.at("splits").get<vector<size_t>>();
node = make_shared<op::Split>(args[0], axis, splits); node = make_shared<op::Split>(args[0], args[1], splits);
break; break;
} }
case OP_TYPEID::Sqrt: case OP_TYPEID::Sqrt:
...@@ -4232,7 +4231,6 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -4232,7 +4231,6 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::Split: case OP_TYPEID::Split:
{ {
auto tmp = static_cast<const op::Split*>(&n); auto tmp = static_cast<const op::Split*>(&n);
node["axis"] = tmp->get_axis();
node["splits"] = tmp->get_splits(); node["splits"] = tmp->get_splits();
break; break;
} }
......
...@@ -1359,8 +1359,9 @@ NGRAPH_TEST(${BACKEND_NAME}, squared_difference_broadcast) ...@@ -1359,8 +1359,9 @@ NGRAPH_TEST(${BACKEND_NAME}, squared_difference_broadcast)
NGRAPH_TEST(${BACKEND_NAME}, split_3_equal_parts) NGRAPH_TEST(${BACKEND_NAME}, split_3_equal_parts)
{ {
const auto data = make_shared<op::Parameter>(element::i32, Shape{6}); const auto data = make_shared<op::Parameter>(element::i32, Shape{6});
const auto axis = op::Constant::create(element::i64, Shape{}, {0});
const auto tested_op = make_shared<op::Split>(data, 0, 3); const auto tested_op = make_shared<op::Split>(data, axis, 3);
const auto function = make_shared<Function>(tested_op->decompose_op(), ParameterVector{data}); const auto function = make_shared<Function>(tested_op->decompose_op(), ParameterVector{data});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}"); auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
...@@ -1378,7 +1379,8 @@ NGRAPH_TEST(${BACKEND_NAME}, split_var_len_parts) ...@@ -1378,7 +1379,8 @@ NGRAPH_TEST(${BACKEND_NAME}, split_var_len_parts)
const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6}); const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
const std::vector<size_t> splits = {2, 4}; const std::vector<size_t> splits = {2, 4};
const auto tested_op = make_shared<op::Split>(data, 1, splits); const auto axis = op::Constant::create(element::i64, Shape{}, {1});
const auto tested_op = make_shared<op::Split>(data, axis, splits);
const auto function = make_shared<Function>(tested_op->decompose_op(), ParameterVector{data}); const auto function = make_shared<Function>(tested_op->decompose_op(), ParameterVector{data});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}"); auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
......
...@@ -155,7 +155,8 @@ TEST(build_graph, multi_output_split) ...@@ -155,7 +155,8 @@ TEST(build_graph, multi_output_split)
{ {
const auto data = make_shared<op::Parameter>(element::f32, Shape{64, 8, 100, 150}); const auto data = make_shared<op::Parameter>(element::f32, Shape{64, 8, 100, 150});
auto filters = make_shared<op::Parameter>(element::f32, Shape{128, 2, 10, 20}); auto filters = make_shared<op::Parameter>(element::f32, Shape{128, 2, 10, 20});
const auto split = make_shared<op::Split>(data, 1, 2); const auto axis = op::Constant::create(element::i64, Shape{}, {1});
const auto split = make_shared<op::Split>(data, axis, 2);
auto conv = make_shared<op::GroupConvolution>(split->output(1), auto conv = make_shared<op::GroupConvolution>(split->output(1),
filters, filters,
Strides{1, 1}, Strides{1, 1},
...@@ -170,7 +171,8 @@ TEST(build_graph, multi_output_split) ...@@ -170,7 +171,8 @@ TEST(build_graph, multi_output_split)
TEST(build_graph, multi_output_split_dynamic) TEST(build_graph, multi_output_split_dynamic)
{ {
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic()); const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto split = make_shared<op::Split>(data, 1, 2); const auto axis = op::Constant::create(element::i64, Shape{}, {1});
const auto split = make_shared<op::Split>(data, axis, 2);
auto abs = make_shared<op::Abs>(split->output(1)); auto abs = make_shared<op::Abs>(split->output(1));
EXPECT_TRUE(abs->get_output_partial_shape(0).same_scheme(PartialShape::dynamic())); EXPECT_TRUE(abs->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
......
...@@ -28,7 +28,8 @@ TEST(type_prop, split) ...@@ -28,7 +28,8 @@ TEST(type_prop, split)
try try
{ {
const std::vector<size_t> splits = {1, 6}; // should sum up to 6 const std::vector<size_t> splits = {1, 6}; // should sum up to 6
const auto split = make_shared<op::Split>(data, 1, splits); const auto axis = op::Constant::create(element::i64, Shape{}, {1});
const auto split = make_shared<op::Split>(data, axis, splits);
FAIL() << "Split node was created with incorrect data."; FAIL() << "Split node was created with incorrect data.";
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
...@@ -40,20 +41,62 @@ TEST(type_prop, split) ...@@ -40,20 +41,62 @@ TEST(type_prop, split)
try try
{ {
const std::vector<size_t> splits = {4, 2}; const std::vector<size_t> splits = {4, 2};
const auto split = make_shared<op::Split>(data, -5, splits); // invalid axis const auto axis = op::Constant::create(element::i64, Shape{}, {-5});
const auto split = make_shared<op::Split>(data, axis, splits); // invalid axis
FAIL() << "Split node was created with incorrect data."; FAIL() << "Split node was created with incorrect data.";
} }
catch (const NodeValidationFailure& error) catch (const ngraph_error& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter axis -5 out of the tensor rank"));
std::string("The 'axis' parameter for Split has to point to one of "
"the input tensor's shape dimensions."));
} }
const auto split = make_shared<op::Split>(data, 1, 2); const auto axis = op::Constant::create(element::i64, Shape{}, {1});
const auto split = make_shared<op::Split>(data, axis, 2);
EXPECT_EQ(split->outputs().size(), 2); EXPECT_EQ(split->outputs().size(), 2);
EXPECT_EQ(split->output(0).get_shape(), (Shape{2, 3})); EXPECT_EQ(split->output(0).get_shape(), (Shape{2, 3}));
EXPECT_EQ(split->output(1).get_shape(), (Shape{2, 3})); EXPECT_EQ(split->output(1).get_shape(), (Shape{2, 3}));
EXPECT_EQ(split->output(0).get_element_type(), element::i32); EXPECT_EQ(split->output(0).get_element_type(), element::i32);
EXPECT_EQ(split->output(1).get_element_type(), element::i32); EXPECT_EQ(split->output(1).get_element_type(), element::i32);
} }
TEST(type_prop, split_axis_must_be_scalar)
{
const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
const std::vector<size_t> splits = {1, 6};
const auto axis = op::Constant::create(element::i64, Shape{2}, {0, 1});
try
{
const auto split = make_shared<op::Split>(data, axis, splits);
FAIL() << "Incorrect axis of Split not detected.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("The 'axis' input node must be scalar"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason.";
}
}
TEST(type_prop, split_axis_must_be_constant)
{
const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
const std::vector<size_t> splits = {1, 6};
const auto axis = make_shared<op::Parameter>(element::i32, Shape{});
try
{
const auto split = make_shared<op::Split>(data, axis, splits);
FAIL() << "Not constant axis of Split not detected.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("The 'axis' input node must be constant"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason.";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment