Commit c546282d authored by Tomasz Dołbniak, committed by Scott Cyphers

[SPEC] v1::Split + onnx_importer adaptation (#3993)

* fused::v1::Split op skeleton & serialization

* Shape inference for Split and upgrade/downgrade passes

* Upgrade pass for the Split op

* Disable the failing onnx_importer Split op tests on interpreter

* Opset1 test correction for v1::Split

* Disable the failing Split op unit tests on PlaidML

* Fix of the upgrade pass for Split

* Remove one of the obsolete v0::Split constructors
parent e5e8faef
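
In opset 1, equal-part splitting is expressed by `v1::Split` with a `num_splits` attribute, while unequal parts move to `v1::VariadicSplit` with an explicit `split_lengths` input. A minimal sketch of the API change (illustrative shapes and values, not part of the diff below):

```cpp
#include <memory>
#include <vector>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/split.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/variadic_split.hpp"

using namespace ngraph;

int main()
{
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{6, 4});
    auto axis = op::Constant::create(element::i64, Shape{}, {0});

    // v0: a single fused op covers both cases via a vector of part lengths.
    auto v0_split = std::make_shared<op::v0::Split>(data, axis, std::vector<size_t>{2, 2, 2});

    // v1: equal splits take a plain piece count...
    auto v1_split = std::make_shared<op::v1::Split>(data, axis, 3);

    // ...and unequal splits use VariadicSplit with a 'split_lengths' input.
    auto lengths = op::Constant::create(element::u64, Shape{3}, std::vector<size_t>{1, 2, 3});
    auto v1_var = std::make_shared<op::v1::VariadicSplit>(data, axis, lengths);
    return 0;
}
```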
@@ -19,6 +19,7 @@
 #include "ngraph/op/constant.hpp"
 #include "ngraph/op/fused/split.hpp"
+#include "ngraph/op/variadic_split.hpp"
 #include "op/split.hpp"
 #include "utils/common.hpp"
@@ -37,21 +38,25 @@ namespace ngraph
                     const auto axis_node =
                         ngraph::op::Constant::create(element::i64, Shape{}, {axis});
-                    std::shared_ptr<ngraph::Node> fused_split;
+                    std::shared_ptr<ngraph::Node> split;
                     if (node.has_attribute("split"))
                     {
-                        const auto length_parts =
+                        const auto splits =
                             node.get_attribute_value<std::vector<std::size_t>>("split");
-                        fused_split =
-                            std::make_shared<ngraph::op::Split>(input, axis_node, length_parts);
+                        const auto split_lengths = ngraph::op::Constant::create(
+                            element::u64, Shape{splits.size()}, splits);
+                        split = std::make_shared<ngraph::op::v1::VariadicSplit>(
+                            input, axis_node, split_lengths);
                     }
                     else
                     {
                         const auto outputs_number = node.get_output_names().size();
-                        fused_split =
-                            std::make_shared<ngraph::op::Split>(input, axis_node, outputs_number);
+                        split = std::make_shared<ngraph::op::v1::Split>(
+                            input, axis_node, outputs_number);
                     }
-                    return common::get_outputs(fused_split);
+                    return common::get_outputs(split);
                 }
             } // namespace set_1
......
@@ -317,6 +317,7 @@ const std::string& Node::description() const
         // type_name to const_char and virtual description() to virtual get_type_name()
         const_cast<Node*>(this)->m_node_type = get_type_name();
     }
+    return m_node_type;
 }
......
@@ -23,9 +23,9 @@
 using namespace std;
 using namespace ngraph;

-constexpr NodeTypeInfo op::Split::type_info;
+constexpr NodeTypeInfo op::v0::Split::type_info;

-op::Split::Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_split)
+op::v0::Split::Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_split)
     : FusedOp({data, axis})
     , m_split_evenly{true}
     , m_num_split{num_split}
@@ -33,9 +33,9 @@ op::Split::Split(const Output<Node>& data, const Output<Node>& axis, const size_
     constructor_validate_and_infer_types();
 }

-op::Split::Split(const Output<Node>& data,
-                 const Output<Node>& axis,
-                 const std::vector<size_t>& splits)
+op::v0::Split::Split(const Output<Node>& data,
+                     const Output<Node>& axis,
+                     const std::vector<size_t>& splits)
     : FusedOp({data, axis})
     , m_split_evenly{false}
     , m_num_split{0}
@@ -45,7 +45,7 @@ op::Split::Split(const Output<Node>& data,
 }

 // TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY
-op::Split::Split(const Output<Node>& data, int axis, const std::vector<size_t>& splits)
+op::v0::Split::Split(const Output<Node>& data, int axis, const std::vector<size_t>& splits)
     : FusedOp({data})
     , m_split_evenly{false}
     , m_axis{axis}
@@ -55,17 +55,7 @@ op::Split::Split(const Output<Node>& data, int axis, const std::vector<size_t>&
     constructor_validate_and_infer_types();
 }

-// TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY
-op::Split::Split(const Output<Node>& data, int axis, size_t num_split)
-    : FusedOp({data})
-    , m_split_evenly{true}
-    , m_axis{axis}
-    , m_num_split{num_split}
-{
-    constructor_validate_and_infer_types();
-}
-
-void op::Split::pre_validate_and_infer_types()
+void op::v0::Split::pre_validate_and_infer_types()
 {
     // TODO REMOVE IF CHECK. INTRODUCED TO PROVIDE CI COMPATIBILITY
     if (get_input_size() == 2)
@@ -126,12 +116,12 @@ void op::Split::pre_validate_and_infer_types()
     set_input_is_relevant_to_shape(0);
 }

-NodeVector op::Split::decompose_op() const
+NodeVector op::v0::Split::decompose_op() const
 {
     return builder::split(input_value(0), m_splits, m_axis);
 }

-shared_ptr<Node> op::Split::copy_with_new_args(const NodeVector& new_args) const
+shared_ptr<Node> op::v0::Split::copy_with_new_args(const NodeVector& new_args) const
 {
     if (new_args.size() == 2)
     {
@@ -140,5 +130,96 @@ shared_ptr<Node> op::Split::copy_with_new_args(const NodeVector& new_args) const
     }

     // TODO REMOVE THIS RETURN AND IF ABOVE. INTRODUCED TO PROVIDE CI COMPATIBILITY
-    return make_shared<Split>(new_args.at(0), m_axis, m_splits);
+    return make_shared<op::v0::Split>(new_args.at(0), m_axis, m_splits);
 }
+
+constexpr NodeTypeInfo op::v1::Split::type_info;
+
+op::v1::Split::Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_splits)
+    : FusedOp({data, axis})
+    , m_num_splits{num_splits}
+{
+    constructor_validate_and_infer_types();
+}
+
+void op::v1::Split::validate_and_infer_types()
+{
+    const auto data_ps = input_value(0).get_partial_shape();
+    const auto axis_ps = input_value(1).get_partial_shape();
+    const auto axis_et = input_value(1).get_element_type();
+
+    NODE_VALIDATION_CHECK(this,
+                          axis_ps.rank().is_static() && (size_t)axis_ps.rank() == 0,
+                          "The 'axis' input is expected to be a scalar. Got: ",
+                          axis_ps);
+
+    NODE_VALIDATION_CHECK(
+        this, axis_et.is_integral(), "The 'axis' input only accepts integral types");
+
+    if (input_value(1).get_node_shared_ptr()->is_constant())
+    {
+        auto axis = axis_value_from_input();
+
+        if (data_ps.is_static())
+        {
+            const auto data_shape = data_ps.to_shape();
+            axis = ngraph::normalize_axis(this, axis, data_shape.size());
+            const auto dimension_at_axis = data_shape.at(axis);
+
+            NODE_VALIDATION_CHECK(this,
+                                  dimension_at_axis % m_num_splits == 0,
+                                  "The input tensor's dimension pointed by the 'axis' parameter: ",
+                                  dimension_at_axis,
+                                  " has to be a multiple of the 'num_splits' attribute value: ",
+                                  m_num_splits);
+
+            Shape each_output_shape{data_shape};
+            each_output_shape.at(axis) = dimension_at_axis / m_num_splits;
+
+            for (size_t i = 0; i < m_num_splits; ++i)
+            {
+                set_output_type(i, input(0).get_element_type(), each_output_shape);
+            }
+        }
+    }
+    else
+    {
+        for (size_t i = 0; i < m_num_splits; ++i)
+        {
+            set_output_type(i, input(0).get_element_type(), PartialShape::dynamic());
+        }

+        set_input_is_relevant_to_shape(0);
+    }
+}
+
+shared_ptr<Node> op::v1::Split::copy_with_new_args(const NodeVector& new_args) const
+{
+    check_new_args_count(this, new_args);
+    return make_shared<v1::Split>(new_args.at(0), new_args.at(1), m_num_splits);
+}
+
+int64_t op::v1::Split::axis_value_from_input() const
+{
+    int64_t axis_value{0};
+    const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
+
+#if defined(__clang__)
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wswitch-enum"
+#endif
+    switch (static_cast<element::Type_t>(axis_input->get_element_type()))
+    {
+    case element::Type_t::i8: axis_value = axis_input->get_vector<int8_t>().at(0); break;
+    case element::Type_t::i32: axis_value = axis_input->get_vector<int32_t>().at(0); break;
+    case element::Type_t::i64: axis_value = axis_input->get_vector<int64_t>().at(0); break;
+    default: break;
+    }
+#if defined(__clang__)
+#pragma clang diagnostic pop
+#endif
+    return axis_value;
+}
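
For reference, a small usage sketch (assumed shapes, not part of the commit) of the shape inference implemented above: a static `{6, 4}` input split into three parts along axis 0 yields three `{2, 4}` outputs.

```cpp
#include <memory>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/split.hpp"
#include "ngraph/op/parameter.hpp"

using namespace ngraph;

int main()
{
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{6, 4});
    auto axis = op::Constant::create(element::i64, Shape{}, {0});

    // validate_and_infer_types() runs in the constructor: the axis constant is
    // read, 6 % 3 == 0 is checked, and each output gets shape {6 / 3, 4}.
    auto split = std::make_shared<op::v1::Split>(data, axis, 3);

    for (size_t i = 0; i < split->get_num_splits(); ++i)
    {
        Shape out = split->get_output_shape(i); // {2, 4}
    }
    return 0;
}
```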
@@ -61,9 +61,6 @@
                 // TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY
                 Split(const Output<Node>& data, int axis, const std::vector<size_t>& splits);

-                // TODO REMOVE THIS CONSTRUCTOR. INTRODUCED TO PROVIDE CI COMPATIBILITY
-                Split(const Output<Node>& data, int axis, const size_t num_split);
-
                 void pre_validate_and_infer_types() override;
                 virtual NodeVector decompose_op() const override;
@@ -82,6 +79,40 @@ namespace ngraph
                 std::vector<size_t> m_splits;
             };
         }

+        namespace v1
+        {
+            /// \brief Splits the input tensor into a list of equal sized tensors
+            class NGRAPH_API Split : public ngraph::op::util::FusedOp
+            {
+            public:
+                static constexpr NodeTypeInfo type_info{"Split", 1};
+                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                /// \brief Constructs a split operation.
+                Split() = default;
+                /// \brief Constructs a split operation.
+                /// \param data       The tensor to be split.
+                /// \param axis       The index of an axis in "data" along which to perform
+                ///                   the split.
+                /// \param num_splits The number of pieces that the data tensor should be
+                ///                   split into.
+                Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_splits);
+
+                void validate_and_infer_types() override;
+
+                virtual std::shared_ptr<Node>
+                    copy_with_new_args(const NodeVector& new_args) const override;
+
+                size_t get_num_splits() const { return m_num_splits; }
+                void set_num_splits(const size_t num_splits) { m_num_splits = num_splits; }
+                bool supports_decompose() const override { return false; }
+            protected:
+                size_t m_num_splits;
+
+            private:
+                int64_t axis_value_from_input() const;
+            };
+        }

         using v0::Split;
     }
 }
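
One consequence of the `using v0::Split;` alias above: unqualified `op::Split` still names the v0 op, so existing call sites compile unchanged, and the opset-1 op must be spelled `op::v1::Split`. A compile-time illustration:

```cpp
#include <type_traits>
#include "ngraph/op/fused/split.hpp"

// With `using v0::Split;` in namespace ngraph::op, the unqualified name keeps
// resolving to the v0 class; only fully qualified op::v1::Split is the new op.
static_assert(std::is_same<ngraph::op::Split, ngraph::op::v0::Split>::value,
              "unqualified op::Split still resolves to v0::Split");
```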
@@ -227,6 +227,7 @@ NGRAPH_OP(Softmax, ngraph::op::v1, 1)
 NGRAPH_OP(SoftmaxCrossEntropy, ngraph::op::v0, 0)
 NGRAPH_OP(SoftmaxCrossEntropyBackprop, ngraph::op::v0, 0)
 NGRAPH_OP(SpaceToDepth, ngraph::op::v0, 0)
+NGRAPH_OP(Split, ngraph::op::v1, 1)
 NGRAPH_OP(Split, ngraph::op::v0, 0)
 NGRAPH_OP(Sqrt, ngraph::op, 0)
 NGRAPH_OP(SquaredDifference, ngraph::op::v0, 0)
......
@@ -146,7 +146,7 @@ NGRAPH_OP(Sinh, ngraph::op::v0)
 NGRAPH_OP(Softmax, ngraph::op::v1)
 NGRAPH_OP(Sqrt, ngraph::op::v0)
 NGRAPH_OP(SpaceToDepth, ngraph::op::v0)
-NGRAPH_OP(Split, ngraph::op::v0)
+NGRAPH_OP(Split, ngraph::op::v1)
 NGRAPH_OP(SquaredDifference, ngraph::op::v0)
 NGRAPH_OP(Squeeze, ngraph::op::v0)
 NGRAPH_OP(StridedSlice, ngraph::op::v1)
......
@@ -578,6 +578,17 @@
         return true;
     }

+    bool op_cast(shared_ptr<op::v1::Split> node)
+    {
+        const auto num_splits = node->get_num_splits();
+        auto replacement_node =
+            make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), num_splits);
+        replace_node(node, replacement_node);
+        return true;
+    }
+
     bool op_cast(shared_ptr<op::v1::Subtract> node)
     {
         op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
@@ -643,6 +654,25 @@
         return true;
     }

+    bool op_cast(shared_ptr<op::v1::VariadicSplit> node)
+    {
+        const auto split_lengths = node->input_value(2).get_node_shared_ptr();
+        NGRAPH_CHECK(split_lengths->is_constant(),
+                     "Unable to convert VariadicSplit:v1 to Split:v0 "
+                     "if 'split_lengths' input is not constant. Node: ",
+                     *node);
+        const auto splits = as_type_ptr<op::Constant>(split_lengths)->get_vector<int64_t>();
+        const std::vector<size_t> splits_unsigned{splits.begin(), splits.end()};
+
+        auto replacement_node =
+            make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), splits_unsigned);
+        replace_node(node, replacement_node);
+        return true;
+    }
+
     using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;

     template <typename T>
......
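
The constant requirement in the `VariadicSplit` branch above matters in practice; a sketch (assumed shapes, not part of the commit) of a graph the downgrade pass can convert and one it must reject:

```cpp
#include <memory>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/variadic_split.hpp"

using namespace ngraph;

int main()
{
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{6, 4});
    auto axis = op::Constant::create(element::i64, Shape{}, {0});

    // Convertible: 'split_lengths' is a Constant, so op_cast can read the
    // lengths {1, 2, 3} and build an equivalent v0::Split.
    auto lengths = op::Constant::create(element::i64, Shape{3}, {1, 2, 3});
    auto ok = std::make_shared<op::v1::VariadicSplit>(data, axis, lengths);

    // Not convertible: the lengths are unknown until runtime, so the
    // NGRAPH_CHECK above fires during the downgrade pass.
    auto dyn_lengths = std::make_shared<op::Parameter>(element::i64, Shape{3});
    auto bad = std::make_shared<op::v1::VariadicSplit>(data, axis, dyn_lengths);
    return 0;
}
```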
@@ -547,6 +547,35 @@
         return true;
     }

+    bool op_cast(shared_ptr<op::Split> node)
+    {
+        const auto& splits_vec = node->get_splits();
+        const auto first_elem = splits_vec.front();
+
+        const bool split_evenly =
+            std::all_of(splits_vec.begin(), splits_vec.end(), [first_elem](const size_t split) {
+                return split == first_elem;
+            });
+
+        std::shared_ptr<Node> replacement_node;
+        if (split_evenly)
+        {
+            replacement_node = make_shared<op::v1::Split>(
+                node->input_value(0), node->input_value(1), splits_vec.size());
+        }
+        else
+        {
+            const auto split_lengths =
+                ngraph::op::Constant::create(element::u64, Shape{splits_vec.size()}, splits_vec);
+
+            replacement_node = make_shared<op::v1::VariadicSplit>(
+                node->input_value(0), node->input_value(1), split_lengths);
+        }
+
+        replace_node(node, replacement_node);
+        return true;
+    }
+
     bool op_cast(shared_ptr<op::Subtract> node)
     {
         op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
......
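
To make the upgrade rule concrete, a sketch (assumed shapes, not part of the commit) of the two replacement paths chosen by `op_cast` above:

```cpp
#include <memory>
#include <vector>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/split.hpp"
#include "ngraph/op/parameter.hpp"

using namespace ngraph;

int main()
{
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{6, 4});
    auto axis = op::Constant::create(element::i64, Shape{}, {0});

    // splits == {2, 2, 2}: every part is equal, so the pass would replace this
    // node with v1::Split(data, axis, num_splits = 3).
    auto even = std::make_shared<op::v0::Split>(data, axis, std::vector<size_t>{2, 2, 2});

    // splits == {1, 2, 3}: parts differ, so the pass would replace this node
    // with v1::VariadicSplit(data, axis, Constant{u64, Shape{3}, {1, 2, 3}}).
    auto uneven = std::make_shared<op::v0::Split>(data, axis, std::vector<size_t>{1, 2, 3});
    return 0;
}
```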
@@ -22,3 +22,8 @@ top_k_opset_11_const_k_smallest
 # Tile op case that the number of elements in "repeats" and shape of "data" are different
 tile_3d_small_data_rank
 tile_3d_few_repeats
+
+# Another fused op decomposition pass required after the downgrade pass
+model_split_equal_parts_default
+model_split_equal_parts_2d
+model_split_variable_parts_2d
@@ -316,6 +316,11 @@ layer_norm_affine_stats
 layer_norm_bprop_affine_stats
 layer_norm_bprop_affine

+# Another fused op decomposition pass required after the downgrade pass
+model_split_equal_parts_default
+model_split_equal_parts_2d
+model_split_variable_parts_2d
+
 # shapes with zeros dimensions like (5, 0, 5) not supported in PlaidML backend
 dyn_replace_slice
......
@@ -2747,6 +2747,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
             node = make_shared<op::Split>(args[0], args[1], splits);
             break;
         }
+        case OP_TYPEID::Split_v1:
+        {
+            const auto num_splits = node_js.at("num_splits").get<size_t>();
+            node = make_shared<op::v1::Split>(args[0], args[1], num_splits);
+            break;
+        }
         case OP_TYPEID::Sqrt:
         {
             node = make_shared<op::Sqrt>(args[0]);
@@ -4415,10 +4421,16 @@ json JSONSerializer::serialize_node(const Node& n)
     }
     case OP_TYPEID::Split:
     {
-        auto tmp = static_cast<const op::Split*>(&n);
+        const auto tmp = static_cast<const op::Split*>(&n);
         node["splits"] = tmp->get_splits();
         break;
     }
+    case OP_TYPEID::Split_v1:
+    {
+        const auto tmp = static_cast<const op::v1::Split*>(&n);
+        node["num_splits"] = tmp->get_num_splits();
+        break;
+    }
     case OP_TYPEID::Sqrt: { break;
     }
     case OP_TYPEID::SquaredDifference:
......
@@ -134,7 +134,7 @@ TEST(opset, check_opset1)
     CHECK_OPSET(op::v0::Sinh, opset1::Sinh)
     CHECK_OPSET(op::v1::Softmax, opset1::Softmax)
     CHECK_OPSET(op::v0::SpaceToDepth, opset1::SpaceToDepth)
-    CHECK_OPSET(op::v0::Split, opset1::Split)
+    CHECK_OPSET(op::v1::Split, opset1::Split)
     CHECK_OPSET(op::v0::Sqrt, opset1::Sqrt)
     CHECK_OPSET(op::v0::SquaredDifference, opset1::SquaredDifference)
     CHECK_OPSET(op::v0::Squeeze, opset1::Squeeze)
......