Commit 371b47fb authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Scott Cyphers

[SPEC] Adjust Split (#3943)

* Changed axis to Node

* Added using normalize from validation util

* refactored split

* Added type_prop tests to Split

* Added set_input_is_relevant_to_shape for Split

* clang style applied

* Fixed var name

* Code refactor

* merge from master, part 2

* Constructor to provide CI compatibility

* CI compatibility

* CI compatibility

* Updated get_outputs

* CI compatibility

* Fixed get_outputs function
parent 8925436a
......@@ -17,6 +17,7 @@
#include <cstdint>
#include <vector>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/split.hpp"
#include "op/split.hpp"
#include "utils/common.hpp"
......@@ -32,30 +33,25 @@ namespace ngraph
NodeVector split(const Node& node)
{
const auto input = node.get_ng_inputs().at(0);
const auto outputs_number = node.get_output_names().size();
const auto axis = node.get_attribute_value<int64_t>("axis", 0);
std::size_t valid_axis =
common::validate_axis(node, axis, input->get_shape().size());
const auto axis_node =
ngraph::op::Constant::create(element::i64, Shape{}, {axis});
try
std::shared_ptr<ngraph::Node> fused_split;
if (node.has_attribute("split"))
{
const auto length_parts =
node.get_attribute_value<std::vector<std::size_t>>("split");
const auto fused_split =
std::make_shared<ngraph::op::Split>(input, valid_axis, length_parts);
return fused_split->decompose_op();
fused_split =
std::make_shared<ngraph::op::Split>(input, axis_node, length_parts);
}
catch (const error::node::UnknownAttribute&)
else
{
// an exception will be caught if the input node does not contain
// the 'split' attribute - this means we should split the input tensor
// into same-length parts equal to the number of node outputs
const auto fused_split =
std::make_shared<ngraph::op::Split>(input, valid_axis, outputs_number);
return fused_split->decompose_op();
const auto outputs_number = node.get_output_names().size();
fused_split =
std::make_shared<ngraph::op::Split>(input, axis_node, outputs_number);
}
return common::get_outputs(fused_split);
}
} // namespace set_1
......
......@@ -16,6 +16,7 @@
#include <onnx/onnx_pb.h> // onnx types
#include "common.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "validation_util.hpp"
namespace ngraph
......@@ -79,6 +80,24 @@ namespace ngraph
return new_axes;
}
/// \brief Return the outputs of the node as a vector of single-output nodes.
///
/// \param[in] node Node with (possibly) multiple outputs.
///
/// \return Vector with one entry per output of \p node; multi-output
///         producers are wrapped in GetOutputElement.
ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node)
{
    const auto outputs_number = node->get_output_size();
    ngraph::NodeVector outputs(outputs_number);
    // Fix: use std::size_t for the index so it matches the type returned by
    // get_output_size() and avoids a signed/unsigned comparison in the loop.
    for (std::size_t i = 0; i < outputs_number; ++i)
    {
        // NOTE(review): node->output(i).get_node_shared_ptr() is `node` itself,
        // so this condition effectively tests whether `node` has exactly one
        // output (invariant across the loop) - verify that this is the intent.
        if (node->output(i).get_node_shared_ptr()->get_output_size() == 1)
        {
            outputs[i] = node->get_output_as_single_output_node(i);
        }
        else
        {
            outputs[i] = std::make_shared<ngraph::op::GetOutputElement>(node, i);
        }
    }
    return outputs;
}
} // namespace common
} // namespace onnx_import
} // namespace ngraph
......@@ -115,6 +115,13 @@ namespace ngraph
std::vector<std::int64_t> axes,
std::int64_t tensor_rank);
/// \brief Return the outputs of the node as vector.
///
/// \param[in] node Node with multiple outputs.
///
/// \return Vector of outputs of input node.
ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node);
/// \brief Creates a shifted square identity matrix.
/// \note Shifting in the context of this operator means that
/// the matrix can be created with elements equal to 1 not only in the main
......
......@@ -16,23 +16,36 @@
#include <numeric>
#include "ngraph/builder/split.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/split.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Split::type_info;
op::Split::Split(const Output<Node>& data, const int axis, const size_t num_split)
: FusedOp({data})
op::Split::Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_split)
: FusedOp({data, axis})
, m_split_evenly{true}
, m_axis{axis}
, m_num_split{num_split}
{
constructor_validate_and_infer_types();
}
op::Split::Split(const Output<Node>& data, const int axis, const std::vector<size_t>& splits)
// Constructs a Split that cuts the input into variable-length pieces along
// the axis supplied by the (constant) `axis` node. `m_axis` itself is filled
// in later, during pre_validate_and_infer_types, from that node's value.
op::Split::Split(const Output<Node>& data,
                 const Output<Node>& axis,
                 const std::vector<size_t>& splits)
    : FusedOp({data, axis})
{
    m_split_evenly = false; // variable-length chunks, not an even split
    m_num_split = 0;
    m_splits = splits;
    constructor_validate_and_infer_types();
}
// TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY
op::Split::Split(const Output<Node>& data, int axis, const std::vector<size_t>& splits)
: FusedOp({data})
, m_split_evenly{false}
, m_axis{axis}
......@@ -42,8 +55,30 @@ op::Split::Split(const Output<Node>& data, const int axis, const std::vector<siz
constructor_validate_and_infer_types();
}
// TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY
// Legacy overload: takes the axis as a plain integer instead of a node.
op::Split::Split(const Output<Node>& data, int axis, size_t num_split)
    : FusedOp({data})
{
    m_split_evenly = true; // equal-sized pieces requested
    m_axis = axis;
    m_num_split = num_split;
    constructor_validate_and_infer_types();
}
void op::Split::pre_validate_and_infer_types()
{
// TODO REMOVE IF CHECK. INTRODUCED TO PROVIDE CI COMPATIBILITY
if (get_input_size() == 2)
{
const auto axis_shape = input(1).get_shape();
NODE_VALIDATION_CHECK(this, is_scalar(axis_shape), "The 'axis' input node must be scalar");
const auto axis_node = input_value(1).get_node_shared_ptr();
NODE_VALIDATION_CHECK(
this, axis_node->is_constant(), "The 'axis' input node must be constant");
const auto axis_node_const = as_type_ptr<op::Constant>(axis_node);
m_axis = axis_node_const->get_vector<int64_t>()[0];
}
// Create dynamic-typed outputs. Actual shape/type will be computed during shape inference
for (size_t i = 0; i < std::max(m_splits.size(), m_num_split); i++)
{
......@@ -57,11 +92,7 @@ void op::Split::pre_validate_and_infer_types()
const auto shape = input(0).get_shape();
m_axis = adjust_axis_value(m_axis, shape.size());
NODE_VALIDATION_CHECK(this,
m_axis >= 0 && m_axis < static_cast<int64_t>(shape.size()),
"The 'axis' parameter for Split has to point to one of the "
"input tensor's shape dimensions.");
m_axis = ngraph::normalize_axis(this, m_axis, shape.size());
const auto dimension_at_axis = shape.at(m_axis);
if (m_split_evenly)
......@@ -92,6 +123,7 @@ void op::Split::pre_validate_and_infer_types()
all_splits_positive == true,
"All values of the 'splits' attribute must be greater than zero");
}
set_input_is_relevant_to_shape(0);
}
NodeVector op::Split::decompose_op() const
......@@ -101,18 +133,12 @@ NodeVector op::Split::decompose_op() const
shared_ptr<Node> op::Split::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Split>(new_args.at(0), m_axis, m_splits);
}
size_t op::Split::adjust_axis_value(const int axis, const size_t input_tensor_rank) const
{
if (axis < 0)
if (new_args.size() == 2)
{
return axis + input_tensor_rank;
}
else
{
return axis;
check_new_args_count(this, new_args);
return make_shared<Split>(new_args.at(0), new_args.at(1), m_splits);
}
// TODO REMOVE THIS RETURN AND IF ABOVE. INTRODUCED TO PROVIDE CI COMPATIBILITY
return make_shared<Split>(new_args.at(0), m_axis, m_splits);
}
......@@ -37,24 +37,32 @@ namespace ngraph
Split() = default;
/// \brief Constructs a Split op that evenly divides the input tensor.
///
/// \param data - Node producing the input tensor
/// \param axis - indicates an axis along which the input tensor should be split.
/// Negative values mean counting from the back of the input tensor's
/// shape.
/// \param num_split - a number of "pieces" the input tensor will be split to
Split(const Output<Node>& data, const int axis, const size_t num_split);
/// \param data Node producing the input tensor
/// \param axis Node producing an axis along which the input tensor
/// should be split. Negative values mean counting from
/// the back of the input tensor's shape.
/// \param num_split a number of "pieces" the input tensor will be split to
Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_split);
/// \brief Constructs a Split op that splits the input tensor into variable length
/// "pieces"
///
/// \param data - Node producing the input tensor
/// \param axis - indicates an axis along which the input tensor should be split.
/// Negative values mean counting from the back of the input tensor's
/// shape.
/// \param splits - a list of lengths that the input tensor should be split to. Use
/// this
/// constructor to split the input tensor to variable length chunks.
Split(const Output<Node>& data, const int axis, const std::vector<size_t>& splits);
/// \param data Node producing the input tensor
/// \param axis Node producing an axis along which the input tensor
/// should be split. Negative values mean counting from
/// the back of the input tensor's shape.
/// \param splits a list of lengths that the input tensor should be
/// split to. Use this constructor to split the input
/// tensor to variable length chunks.
Split(const Output<Node>& data,
const Output<Node>& axis,
const std::vector<size_t>& splits);
// TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY
Split(const Output<Node>& data, int axis, const std::vector<size_t>& splits);
// TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY
Split(const Output<Node>& data, int axis, const size_t num_split);
void pre_validate_and_infer_types() override;
......@@ -66,21 +74,9 @@ namespace ngraph
size_t get_axis() const { return m_axis; }
const std::vector<size_t>& get_splits() const { return m_splits; }
private:
/// \brief Adjusts the axis for negative values
///
/// \note Negative values mean that the API consumer wants to point the axis
/// location
/// from the back of the tensor. This is similar to the way NumPy works.
///
/// \param axis - original axis value; negative values are accepted
/// \param input_tensor_rank - rank of the input data tensor
/// \return Returns a sum of parameters for negative axis value, or axis itself
/// otherwise
size_t adjust_axis_value(const int axis, const size_t input_tensor_rank) const;
/// used internally for validation purposes, indicates which constructor was used
bool m_split_evenly;
int m_axis;
int64_t m_axis;
size_t m_num_split;
/// contains lengths of chunks that the input tensor will be split into
std::vector<size_t> m_splits;
......
......@@ -2626,9 +2626,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Split:
{
const auto axis = node_js.at("axis").get<size_t>();
const auto splits = node_js.at("splits").get<vector<size_t>>();
node = make_shared<op::Split>(args[0], axis, splits);
node = make_shared<op::Split>(args[0], args[1], splits);
break;
}
case OP_TYPEID::Sqrt:
......@@ -4232,7 +4231,6 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::Split:
{
auto tmp = static_cast<const op::Split*>(&n);
node["axis"] = tmp->get_axis();
node["splits"] = tmp->get_splits();
break;
}
......
......@@ -1359,8 +1359,9 @@ NGRAPH_TEST(${BACKEND_NAME}, squared_difference_broadcast)
NGRAPH_TEST(${BACKEND_NAME}, split_3_equal_parts)
{
const auto data = make_shared<op::Parameter>(element::i32, Shape{6});
const auto axis = op::Constant::create(element::i64, Shape{}, {0});
const auto tested_op = make_shared<op::Split>(data, 0, 3);
const auto tested_op = make_shared<op::Split>(data, axis, 3);
const auto function = make_shared<Function>(tested_op->decompose_op(), ParameterVector{data});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
......@@ -1378,7 +1379,8 @@ NGRAPH_TEST(${BACKEND_NAME}, split_var_len_parts)
const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
const std::vector<size_t> splits = {2, 4};
const auto tested_op = make_shared<op::Split>(data, 1, splits);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
const auto tested_op = make_shared<op::Split>(data, axis, splits);
const auto function = make_shared<Function>(tested_op->decompose_op(), ParameterVector{data});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
......
......@@ -155,7 +155,8 @@ TEST(build_graph, multi_output_split)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{64, 8, 100, 150});
auto filters = make_shared<op::Parameter>(element::f32, Shape{128, 2, 10, 20});
const auto split = make_shared<op::Split>(data, 1, 2);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
const auto split = make_shared<op::Split>(data, axis, 2);
auto conv = make_shared<op::GroupConvolution>(split->output(1),
filters,
Strides{1, 1},
......@@ -170,7 +171,8 @@ TEST(build_graph, multi_output_split)
TEST(build_graph, multi_output_split_dynamic)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto split = make_shared<op::Split>(data, 1, 2);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
const auto split = make_shared<op::Split>(data, axis, 2);
auto abs = make_shared<op::Abs>(split->output(1));
EXPECT_TRUE(abs->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
......
......@@ -28,7 +28,8 @@ TEST(type_prop, split)
try
{
const std::vector<size_t> splits = {1, 6}; // should sum up to 6
const auto split = make_shared<op::Split>(data, 1, splits);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
const auto split = make_shared<op::Split>(data, axis, splits);
FAIL() << "Split node was created with incorrect data.";
}
catch (const NodeValidationFailure& error)
......@@ -40,20 +41,62 @@ TEST(type_prop, split)
try
{
const std::vector<size_t> splits = {4, 2};
const auto split = make_shared<op::Split>(data, -5, splits); // invalid axis
const auto axis = op::Constant::create(element::i64, Shape{}, {-5});
const auto split = make_shared<op::Split>(data, axis, splits); // invalid axis
FAIL() << "Split node was created with incorrect data.";
}
catch (const NodeValidationFailure& error)
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("The 'axis' parameter for Split has to point to one of "
"the input tensor's shape dimensions."));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter axis -5 out of the tensor rank"));
}
const auto split = make_shared<op::Split>(data, 1, 2);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
const auto split = make_shared<op::Split>(data, axis, 2);
EXPECT_EQ(split->outputs().size(), 2);
EXPECT_EQ(split->output(0).get_shape(), (Shape{2, 3}));
EXPECT_EQ(split->output(1).get_shape(), (Shape{2, 3}));
EXPECT_EQ(split->output(0).get_element_type(), element::i32);
EXPECT_EQ(split->output(1).get_element_type(), element::i32);
}
// Split must reject an 'axis' input whose shape is not scalar.
TEST(type_prop, split_axis_must_be_scalar)
{
    const auto input_data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
    const std::vector<size_t> part_lengths{1, 6};
    // A rank-1, two-element axis tensor is invalid: Split requires a scalar.
    const auto non_scalar_axis = op::Constant::create(element::i64, Shape{2}, {0, 1});
    try
    {
        const auto split_node = make_shared<op::Split>(input_data, non_scalar_axis, part_lengths);
        FAIL() << "Incorrect axis of Split not detected.";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("The 'axis' input node must be scalar"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason.";
    }
}
// Split must reject an 'axis' input that is not a Constant node.
TEST(type_prop, split_axis_must_be_constant)
{
    const auto input_data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
    const std::vector<size_t> part_lengths{1, 6};
    // A Parameter is not a Constant, so Split cannot read the axis value.
    const auto dynamic_axis = make_shared<op::Parameter>(element::i32, Shape{});
    try
    {
        const auto split_node = make_shared<op::Split>(input_data, dynamic_axis, part_lengths);
        FAIL() << "Not constant axis of Split not detected.";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("The 'axis' input node must be constant"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason.";
    }
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment