Commit c546282d authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Scott Cyphers

[SPEC] v1::Split + onnx_importer adaptation (#3993)

* fused::v1::Split op skeleton & serialization

* Shape inference for Split and upgrade/downgrade passes

* Upgrade pass for the Split op

* Disable the failing onnx_importer Split op tests on interpreter

* Opset1 test correction for v1::Split

* Disable the failing Split op unit tests on PlaidML

* Fix of the upgrade pass for Split

* Remove one of the obsolete v0::Split constructors
parent e5e8faef
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/split.hpp" #include "ngraph/op/fused/split.hpp"
#include "ngraph/op/variadic_split.hpp"
#include "op/split.hpp" #include "op/split.hpp"
#include "utils/common.hpp" #include "utils/common.hpp"
...@@ -37,21 +38,25 @@ namespace ngraph ...@@ -37,21 +38,25 @@ namespace ngraph
const auto axis_node = const auto axis_node =
ngraph::op::Constant::create(element::i64, Shape{}, {axis}); ngraph::op::Constant::create(element::i64, Shape{}, {axis});
std::shared_ptr<ngraph::Node> fused_split; std::shared_ptr<ngraph::Node> split;
if (node.has_attribute("split")) if (node.has_attribute("split"))
{ {
const auto length_parts = const auto splits =
node.get_attribute_value<std::vector<std::size_t>>("split"); node.get_attribute_value<std::vector<std::size_t>>("split");
fused_split =
std::make_shared<ngraph::op::Split>(input, axis_node, length_parts); const auto split_lengths = ngraph::op::Constant::create(
element::u64, Shape{splits.size()}, splits);
split = std::make_shared<ngraph::op::v1::VariadicSplit>(
input, axis_node, split_lengths);
} }
else else
{ {
const auto outputs_number = node.get_output_names().size(); const auto outputs_number = node.get_output_names().size();
fused_split = split = std::make_shared<ngraph::op::v1::Split>(
std::make_shared<ngraph::op::Split>(input, axis_node, outputs_number); input, axis_node, outputs_number);
} }
return common::get_outputs(fused_split); return common::get_outputs(split);
} }
} // namespace set_1 } // namespace set_1
......
...@@ -317,6 +317,7 @@ const std::string& Node::description() const ...@@ -317,6 +317,7 @@ const std::string& Node::description() const
// type_name to const_char and virtual description() to virtual get_type_name() // type_name to const_char and virtual description() to virtual get_type_name()
const_cast<Node*>(this)->m_node_type = get_type_name(); const_cast<Node*>(this)->m_node_type = get_type_name();
} }
return m_node_type; return m_node_type;
} }
......
...@@ -23,9 +23,9 @@ ...@@ -23,9 +23,9 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::Split::type_info; constexpr NodeTypeInfo op::v0::Split::type_info;
op::Split::Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_split) op::v0::Split::Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_split)
: FusedOp({data, axis}) : FusedOp({data, axis})
, m_split_evenly{true} , m_split_evenly{true}
, m_num_split{num_split} , m_num_split{num_split}
...@@ -33,7 +33,7 @@ op::Split::Split(const Output<Node>& data, const Output<Node>& axis, const size_ ...@@ -33,7 +33,7 @@ op::Split::Split(const Output<Node>& data, const Output<Node>& axis, const size_
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::Split::Split(const Output<Node>& data, op::v0::Split::Split(const Output<Node>& data,
const Output<Node>& axis, const Output<Node>& axis,
const std::vector<size_t>& splits) const std::vector<size_t>& splits)
: FusedOp({data, axis}) : FusedOp({data, axis})
...@@ -45,7 +45,7 @@ op::Split::Split(const Output<Node>& data, ...@@ -45,7 +45,7 @@ op::Split::Split(const Output<Node>& data,
} }
// TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY // TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY
op::Split::Split(const Output<Node>& data, int axis, const std::vector<size_t>& splits) op::v0::Split::Split(const Output<Node>& data, int axis, const std::vector<size_t>& splits)
: FusedOp({data}) : FusedOp({data})
, m_split_evenly{false} , m_split_evenly{false}
, m_axis{axis} , m_axis{axis}
...@@ -55,17 +55,7 @@ op::Split::Split(const Output<Node>& data, int axis, const std::vector<size_t>& ...@@ -55,17 +55,7 @@ op::Split::Split(const Output<Node>& data, int axis, const std::vector<size_t>&
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
// TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY void op::v0::Split::pre_validate_and_infer_types()
op::Split::Split(const Output<Node>& data, int axis, size_t num_split)
: FusedOp({data})
, m_split_evenly{true}
, m_axis{axis}
, m_num_split{num_split}
{
constructor_validate_and_infer_types();
}
void op::Split::pre_validate_and_infer_types()
{ {
// TODO REMOVE IF CHECK. INTRODUCED TO PROVIDE CI COMPATIBILITY // TODO REMOVE IF CHECK. INTRODUCED TO PROVIDE CI COMPATIBILITY
if (get_input_size() == 2) if (get_input_size() == 2)
...@@ -126,12 +116,12 @@ void op::Split::pre_validate_and_infer_types() ...@@ -126,12 +116,12 @@ void op::Split::pre_validate_and_infer_types()
set_input_is_relevant_to_shape(0); set_input_is_relevant_to_shape(0);
} }
NodeVector op::Split::decompose_op() const NodeVector op::v0::Split::decompose_op() const
{ {
return builder::split(input_value(0), m_splits, m_axis); return builder::split(input_value(0), m_splits, m_axis);
} }
shared_ptr<Node> op::Split::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::Split::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() == 2) if (new_args.size() == 2)
{ {
...@@ -140,5 +130,96 @@ shared_ptr<Node> op::Split::copy_with_new_args(const NodeVector& new_args) const ...@@ -140,5 +130,96 @@ shared_ptr<Node> op::Split::copy_with_new_args(const NodeVector& new_args) const
} }
// TODO REMOVE THIS RETURN AND IF ABOVE. INTRODUCED TO PROVIDE CI COMPATIBILITY // TODO REMOVE THIS RETURN AND IF ABOVE. INTRODUCED TO PROVIDE CI COMPATIBILITY
return make_shared<Split>(new_args.at(0), m_axis, m_splits); return make_shared<op::v0::Split>(new_args.at(0), m_axis, m_splits);
}
constexpr NodeTypeInfo op::v1::Split::type_info;
op::v1::Split::Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_splits)
: FusedOp({data, axis})
, m_num_splits{num_splits}
{
constructor_validate_and_infer_types();
}
void op::v1::Split::validate_and_infer_types()
{
const auto data_ps = input_value(0).get_partial_shape();
const auto axis_ps = input_value(1).get_partial_shape();
const auto axis_et = input_value(1).get_element_type();
NODE_VALIDATION_CHECK(this,
axis_ps.rank().is_static() && (size_t)axis_ps.rank() == 0,
"The 'axis' input is expected to be a scalar. Got: ",
axis_ps);
NODE_VALIDATION_CHECK(
this, axis_et.is_integral(), "The 'axis' input only accepts integral types");
if (input_value(1).get_node_shared_ptr()->is_constant())
{
auto axis = axis_value_from_input();
if (data_ps.is_static())
{
const auto data_shape = data_ps.to_shape();
axis = ngraph::normalize_axis(this, axis, data_shape.size());
const auto dimension_at_axis = data_shape.at(axis);
NODE_VALIDATION_CHECK(this,
dimension_at_axis % m_num_splits == 0,
"The input tensor's dimension pointed by the 'axis' parameter: ",
dimension_at_axis,
" has to be a multiple of the 'num_splits' attribute value: ",
m_num_splits);
Shape each_output_shape{data_shape};
each_output_shape.at(axis) = dimension_at_axis / m_num_splits;
for (size_t i = 0; i < m_num_splits; ++i)
{
set_output_type(i, input(0).get_element_type(), each_output_shape);
}
}
}
else
{
for (size_t i = 0; i < m_num_splits; ++i)
{
set_output_type(i, input(0).get_element_type(), PartialShape::dynamic());
}
set_input_is_relevant_to_shape(0);
}
}
shared_ptr<Node> op::v1::Split::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::Split>(new_args.at(0), new_args.at(1), m_num_splits);
}
int64_t op::v1::Split::axis_value_from_input() const
{
int64_t axis_value{0};
const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wswitch-enum"
#endif
switch (static_cast<element::Type_t>(axis_input->get_element_type()))
{
case element::Type_t::i8: axis_value = axis_input->get_vector<int8_t>().at(0); break;
case element::Type_t::i32: axis_value = axis_input->get_vector<int32_t>().at(0); break;
case element::Type_t::i64: axis_value = axis_input->get_vector<int64_t>().at(0); break;
default: break;
}
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
return axis_value;
} }
...@@ -61,9 +61,6 @@ namespace ngraph ...@@ -61,9 +61,6 @@ namespace ngraph
// TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY // TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY
Split(const Output<Node>& data, int axis, const std::vector<size_t>& splits); Split(const Output<Node>& data, int axis, const std::vector<size_t>& splits);
// TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY
Split(const Output<Node>& data, int axis, const size_t num_split);
void pre_validate_and_infer_types() override; void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
...@@ -82,6 +79,40 @@ namespace ngraph ...@@ -82,6 +79,40 @@ namespace ngraph
std::vector<size_t> m_splits; std::vector<size_t> m_splits;
}; };
} }
namespace v1
{
/// \brief Splits the input tensor into a list of equal sized tensors
class NGRAPH_API Split : public ngraph::op::util::FusedOp
{
public:
static constexpr NodeTypeInfo type_info{"Split", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a split operation.
Split() = default;
/// \brief Constructs a split operation.
/// \param data The tensor to be split.
/// \param axis The index of an axis in "data" along which to perform
/// the split.
/// \param num_splits The number of pieces that the data tensor should be
/// split into.
Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_splits);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
size_t get_num_splits() const { return m_num_splits; }
void set_num_splits(const size_t num_splits) { m_num_splits = num_splits; }
bool supports_decompose() const override { return false; }
protected:
size_t m_num_splits;
private:
int64_t axis_value_from_input() const;
};
}
using v0::Split; using v0::Split;
} }
} }
...@@ -227,6 +227,7 @@ NGRAPH_OP(Softmax, ngraph::op::v1, 1) ...@@ -227,6 +227,7 @@ NGRAPH_OP(Softmax, ngraph::op::v1, 1)
NGRAPH_OP(SoftmaxCrossEntropy, ngraph::op::v0, 0) NGRAPH_OP(SoftmaxCrossEntropy, ngraph::op::v0, 0)
NGRAPH_OP(SoftmaxCrossEntropyBackprop, ngraph::op::v0, 0) NGRAPH_OP(SoftmaxCrossEntropyBackprop, ngraph::op::v0, 0)
NGRAPH_OP(SpaceToDepth, ngraph::op::v0, 0) NGRAPH_OP(SpaceToDepth, ngraph::op::v0, 0)
NGRAPH_OP(Split, ngraph::op::v1, 1)
NGRAPH_OP(Split, ngraph::op::v0, 0) NGRAPH_OP(Split, ngraph::op::v0, 0)
NGRAPH_OP(Sqrt, ngraph::op, 0) NGRAPH_OP(Sqrt, ngraph::op, 0)
NGRAPH_OP(SquaredDifference, ngraph::op::v0, 0) NGRAPH_OP(SquaredDifference, ngraph::op::v0, 0)
......
...@@ -146,7 +146,7 @@ NGRAPH_OP(Sinh, ngraph::op::v0) ...@@ -146,7 +146,7 @@ NGRAPH_OP(Sinh, ngraph::op::v0)
NGRAPH_OP(Softmax, ngraph::op::v1) NGRAPH_OP(Softmax, ngraph::op::v1)
NGRAPH_OP(Sqrt, ngraph::op::v0) NGRAPH_OP(Sqrt, ngraph::op::v0)
NGRAPH_OP(SpaceToDepth, ngraph::op::v0) NGRAPH_OP(SpaceToDepth, ngraph::op::v0)
NGRAPH_OP(Split, ngraph::op::v0) NGRAPH_OP(Split, ngraph::op::v1)
NGRAPH_OP(SquaredDifference, ngraph::op::v0) NGRAPH_OP(SquaredDifference, ngraph::op::v0)
NGRAPH_OP(Squeeze, ngraph::op::v0) NGRAPH_OP(Squeeze, ngraph::op::v0)
NGRAPH_OP(StridedSlice, ngraph::op::v1) NGRAPH_OP(StridedSlice, ngraph::op::v1)
......
...@@ -578,6 +578,17 @@ namespace ...@@ -578,6 +578,17 @@ namespace
return true; return true;
} }
bool op_cast(shared_ptr<op::v1::Split> node)
{
const auto num_splits = node->get_num_splits();
auto replacement_node =
make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), num_splits);
replace_node(node, replacement_node);
return true;
}
bool op_cast(shared_ptr<op::v1::Subtract> node) bool op_cast(shared_ptr<op::v1::Subtract> node)
{ {
op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node); op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
...@@ -643,6 +654,25 @@ namespace ...@@ -643,6 +654,25 @@ namespace
return true; return true;
} }
bool op_cast(shared_ptr<op::v1::VariadicSplit> node)
{
const auto split_lengths = node->input_value(2).get_node_shared_ptr();
NGRAPH_CHECK(split_lengths->is_constant(),
"Unable to convert VariadicSplit:v1 to Split:v0 "
"if 'split_lengths' input is not constant. Node: ",
*node);
const auto splits = as_type_ptr<op::Constant>(split_lengths)->get_vector<int64_t>();
const std::vector<size_t> splits_unsigned{splits.begin(), splits.end()};
auto replacement_node =
make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), splits_unsigned);
replace_node(node, replacement_node);
return true;
}
using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>; using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;
template <typename T> template <typename T>
......
...@@ -547,6 +547,35 @@ namespace ...@@ -547,6 +547,35 @@ namespace
return true; return true;
} }
bool op_cast(shared_ptr<op::Split> node)
{
const auto& splits_vec = node->get_splits();
const auto first_elem = splits_vec.front();
const bool split_evenly =
std::all_of(splits_vec.begin(), splits_vec.end(), [first_elem](const size_t split) {
return split == first_elem;
});
std::shared_ptr<Node> replacement_node;
if (split_evenly)
{
replacement_node = make_shared<op::v1::Split>(
node->input_value(0), node->input_value(1), splits_vec.front());
}
else
{
const auto split_lengths =
ngraph::op::Constant::create(element::u64, Shape{splits_vec.size()}, splits_vec);
replacement_node = make_shared<op::v1::VariadicSplit>(
node->input_value(0), node->input_value(1), split_lengths);
}
replace_node(node, replacement_node);
return true;
}
bool op_cast(shared_ptr<op::Subtract> node) bool op_cast(shared_ptr<op::Subtract> node)
{ {
op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node); op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
......
...@@ -22,3 +22,8 @@ top_k_opset_11_const_k_smallest ...@@ -22,3 +22,8 @@ top_k_opset_11_const_k_smallest
# Tile op case that the number of elements in "repeats" and shape of "data" are different # Tile op case that the number of elements in "repeats" and shape of "data" are different
tile_3d_small_data_rank tile_3d_small_data_rank
tile_3d_few_repeats tile_3d_few_repeats
# Another fused op decomposition pass required after the downgrade pass
model_split_equal_parts_default
model_split_equal_parts_2d
model_split_variable_parts_2d
...@@ -316,6 +316,11 @@ layer_norm_affine_stats ...@@ -316,6 +316,11 @@ layer_norm_affine_stats
layer_norm_bprop_affine_stats layer_norm_bprop_affine_stats
layer_norm_bprop_affine layer_norm_bprop_affine
# Another fused op decomposition pass required after the downgrade pass
model_split_equal_parts_default
model_split_equal_parts_2d
model_split_variable_parts_2d
# shapes with zeros dimensions like (5, 0, 5) not supported in PlaidML backend # shapes with zeros dimensions like (5, 0, 5) not supported in PlaidML backend
dyn_replace_slice dyn_replace_slice
......
...@@ -2747,6 +2747,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2747,6 +2747,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Split>(args[0], args[1], splits); node = make_shared<op::Split>(args[0], args[1], splits);
break; break;
} }
case OP_TYPEID::Split_v1:
{
const auto num_splits = node_js.at("num_splits").get<size_t>();
node = make_shared<op::Split>(args[0], args[1], num_splits);
break;
}
case OP_TYPEID::Sqrt: case OP_TYPEID::Sqrt:
{ {
node = make_shared<op::Sqrt>(args[0]); node = make_shared<op::Sqrt>(args[0]);
...@@ -4415,10 +4421,16 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -4415,10 +4421,16 @@ json JSONSerializer::serialize_node(const Node& n)
} }
case OP_TYPEID::Split: case OP_TYPEID::Split:
{ {
auto tmp = static_cast<const op::Split*>(&n); const auto tmp = static_cast<const op::Split*>(&n);
node["splits"] = tmp->get_splits(); node["splits"] = tmp->get_splits();
break; break;
} }
case OP_TYPEID::Split_v1:
{
const auto tmp = static_cast<const op::v1::Split*>(&n);
node["num_splits"] = tmp->get_num_splits();
break;
}
case OP_TYPEID::Sqrt: { break; case OP_TYPEID::Sqrt: { break;
} }
case OP_TYPEID::SquaredDifference: case OP_TYPEID::SquaredDifference:
......
...@@ -134,7 +134,7 @@ TEST(opset, check_opset1) ...@@ -134,7 +134,7 @@ TEST(opset, check_opset1)
CHECK_OPSET(op::v0::Sinh, opset1::Sinh) CHECK_OPSET(op::v0::Sinh, opset1::Sinh)
CHECK_OPSET(op::v1::Softmax, opset1::Softmax) CHECK_OPSET(op::v1::Softmax, opset1::Softmax)
CHECK_OPSET(op::v0::SpaceToDepth, opset1::SpaceToDepth) CHECK_OPSET(op::v0::SpaceToDepth, opset1::SpaceToDepth)
CHECK_OPSET(op::v0::Split, opset1::Split) CHECK_OPSET(op::v1::Split, opset1::Split)
CHECK_OPSET(op::v0::Sqrt, opset1::Sqrt) CHECK_OPSET(op::v0::Sqrt, opset1::Sqrt)
CHECK_OPSET(op::v0::SquaredDifference, opset1::SquaredDifference) CHECK_OPSET(op::v0::SquaredDifference, opset1::SquaredDifference)
CHECK_OPSET(op::v0::Squeeze, opset1::Squeeze) CHECK_OPSET(op::v0::Squeeze, opset1::Squeeze)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment