Commit ba546455 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Robert Kimball

[FusedOps] Split (#2951)

* Split op skeleton

* Two ways to construct a fused Split to be able to use it in onnx importer

* refactor: move the util::split() helper functions to the core

* Split's decompose_op() implementation using a helper function

* Use fused Split in the onnx_importer

* Code formatting

* PR feedback

* Split helpers moved to ngraph/builder

* Basic UT - split a 1D tensor to 3 equal parts

* UT: Split 2D tensor into variable length parts

* Code formatting

* Catch the proper type of exception in the onnx_importer split()

* Initialize members in the correct order

* Type prop tests for Split

* Code formatting

* PR feedback
parent 9335e41c
......@@ -38,6 +38,8 @@ set (SRC
builder/quantization_util.hpp
builder/reduce_ops.cpp
builder/reduce_ops.hpp
builder/split.cpp
builder/split.hpp
builder/tensor_mask.hpp
check.hpp
code_writer.hpp
......@@ -310,6 +312,8 @@ set (SRC
op/fused/scale_shift.hpp
op/fused/space_to_depth.cpp
op/fused/space_to_depth.hpp
op/fused/split.cpp
op/fused/split.hpp
op/fused/squared_difference.cpp
op/fused/squared_difference.hpp
op/fused/squeeze.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/builder/split.hpp"
#include "ngraph/op/slice.hpp"
using namespace ngraph;
namespace
{
    /// \brief Clamps \p idx so that it never exceeds \p axis_size.
    inline std::size_t get_valid_array_index(std::size_t idx, std::size_t axis_size)
    {
        return std::min(idx, axis_size);
    }

    /// \brief Builds a Slice node that selects [starts, ends) along each of the given axes.
    ///
    /// Axes not listed in \p axes keep their full extent. Out-of-range bounds are
    /// clamped to the corresponding axis length.
    std::shared_ptr<op::Slice> make_ng_slice(const std::shared_ptr<ngraph::Node>& node,
                                             const std::vector<std::size_t>& axes,
                                             const std::vector<std::size_t>& starts,
                                             const std::vector<std::size_t>& ends)
    {
        const auto& data_shape = node->get_shape();
        // Default bounds select the entire tensor; selected axes are narrowed below.
        std::vector<std::size_t> upper_bounds{data_shape};
        std::vector<std::size_t> lower_bounds(upper_bounds.size());
        for (std::size_t i = 0; i < axes.size(); ++i)
        {
            const std::size_t axis = axes.at(i);
            const std::size_t axis_length = data_shape.at(axis);
            lower_bounds.at(axis) = get_valid_array_index(starts.at(i), axis_length);
            upper_bounds.at(axis) = get_valid_array_index(ends.at(i), axis_length);
        }
        return std::make_shared<op::Slice>(node, lower_bounds, upper_bounds);
    }
}
// Splits 'node' along 'axis' into consecutive slices whose lengths are given by
// 'length_parts'. Returns one Slice node per requested part, in order.
NodeVector builder::split(const std::shared_ptr<ngraph::Node>& node,
                          const std::vector<size_t>& length_parts,
                          size_t axis)
{
    NodeVector outputs;
    size_t lower_bound = 0;
    for (const auto& part_length : length_parts)
    {
        // Each part covers the half-open range [lower_bound, upper_bound) on 'axis'.
        const size_t upper_bound = lower_bound + part_length;
        outputs.push_back(make_ng_slice(node, {axis}, {lower_bound}, {upper_bound}));
        lower_bound = upper_bound;
    }
    return outputs;
}
// Splits 'node' along 'axis' into 'split_parts' equal-length pieces.
// Negative 'axis' values are counted from the back of the tensor's shape
// (NumPy-style). The axis length must be divisible by 'split_parts'
// (see the declaration in ngraph/builder/split.hpp).
NodeVector builder::split(const std::shared_ptr<ngraph::Node>& node, size_t split_parts, int axis)
{
    const auto& data_shape = node->get_shape();
    const size_t axis_to_split =
        (axis < 0) ? data_shape.size() + axis : static_cast<size_t>(axis);

    const size_t axis_length = data_shape.at(axis_to_split);
    // Every part gets the same length; delegate the slicing to the variable-length overload.
    const std::vector<size_t> length_parts(split_parts, axis_length / split_parts);
    return split(node, length_parts, axis_to_split);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node.hpp"
namespace ngraph
{
    namespace builder
    {
        /// \brief      Split node on specified axis into multiple parts.
        ///
        /// \param[in]  node          The input node.
        /// \param[in]  length_parts  The vector defining the lengths of each split part.
        /// \param[in]  axis          The axis we split input node on. Default value is zero
        ///                           axis.
        ///
        /// \return     The vector containing multiple nodes we split input node into.
        ///
        NodeVector split(const std::shared_ptr<ngraph::Node>& node,
                         const std::vector<std::size_t>& length_parts,
                         std::size_t axis = 0);

        /// \brief      Split node on specified axis into multiple parts.
        ///
        /// \param[in]  node         The input node.
        /// \param[in]  split_parts  The number of parts we want to split input node at given
        ///                          axis. The length of the axis to split must be divisible by
        ///                          this value.
        /// \param[in]  axis         The axis we split input node on. Default value is zero
        ///                          axis.
        ///
        /// \note       This implementation supports negative `axis` values (similar to NumPy
        ///             indexing). This means that the axis to split on will be counted from
        ///             the back of the tensor (negative values are subtracted from its rank).
        ///
        /// \return     The vector containing multiple nodes we split input node into.
        ///
        NodeVector
            split(const std::shared_ptr<ngraph::Node>& node, std::size_t split_parts, int axis = 0);
    } // namespace builder
} // namespace ngraph
......@@ -22,12 +22,11 @@
#include "lp_pool.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/norm.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/util.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"
namespace ngraph
{
......@@ -47,7 +46,7 @@ namespace ngraph
ASSERT_VALID_ARGUMENT(node, p_norm >= 0)
<< "Only positive (including zero) values are supported for 'p' attribute.";
NodeVector slices = reshape::split(data, channels_count, channel_axis);
NodeVector slices = ngraph::builder::split(data, channels_count, channel_axis);
for (auto& slice : slices)
{
......
......@@ -31,6 +31,7 @@
#include "lstm.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/concat.hpp"
......@@ -356,13 +357,13 @@ namespace ngraph
// step.
// Ht_R - Hidden state multiplied by weights tensor at current time step.
NodeVector p_iof = reshape::split(m_P, 3);
NodeVector p_iof = ngraph::builder::split(m_P, 3);
const auto& p_i = p_iof.at(0);
const auto& p_o = p_iof.at(1);
const auto& p_f = p_iof.at(2);
NodeVector h_list;
NodeVector b_W_R = reshape::split(m_B, 2);
NodeVector b_W_R = ngraph::builder::split(m_B, 2);
std::shared_ptr<ngraph::Node> bias = b_W_R.at(0) + b_W_R.at(1);
std::shared_ptr<ngraph::Node> H_t = m_initial_h;
std::shared_ptr<ngraph::Node> C_t = m_initial_c;
......@@ -375,7 +376,7 @@ namespace ngraph
NodeVector in_seqs{};
if (m_X->get_shape().at(0) != 1)
{
in_seqs = reshape::split(m_X, m_X->get_shape().at(0));
in_seqs = ngraph::builder::split(m_X, m_X->get_shape().at(0));
}
else
{
......@@ -403,7 +404,7 @@ namespace ngraph
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates.
auto gates = add(Xt_W, add(Ht_R, bias));
NodeVector split_gates = reshape::split(gates, 4, -1);
NodeVector split_gates = ngraph::builder::split(gates, 4, -1);
auto i = split_gates.at(0);
auto o = split_gates.at(1);
auto f = split_gates.at(2);
......@@ -602,12 +603,18 @@ namespace ngraph
if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_BIDIRECTIONAL)
{
// In bidirectional mode weights are stacked together, so we must split them.
NodeVector W{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_W), 2)};
NodeVector R{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_R), 2)};
NodeVector B{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_B), 2)};
NodeVector P{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_P), 2)};
NodeVector H{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_INIT_H), 2)};
NodeVector C{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_INIT_C), 2)};
NodeVector W{
ngraph::builder::split(input_map.at(LSTMInput::LSTM_INPUT_W), 2)};
NodeVector R{
ngraph::builder::split(input_map.at(LSTMInput::LSTM_INPUT_R), 2)};
NodeVector B{
ngraph::builder::split(input_map.at(LSTMInput::LSTM_INPUT_B), 2)};
NodeVector P{
ngraph::builder::split(input_map.at(LSTMInput::LSTM_INPUT_P), 2)};
NodeVector H{
ngraph::builder::split(input_map.at(LSTMInput::LSTM_INPUT_INIT_H), 2)};
NodeVector C{
ngraph::builder::split(input_map.at(LSTMInput::LSTM_INPUT_INIT_C), 2)};
LSTMForward lstm_fwd(input_map.at(LSTMInput::LSTM_INPUT_X),
W.at(0),
......
......@@ -18,114 +18,42 @@
#include <vector>
#include "exceptions.hpp"
#include "ngraph/op/fused/split.hpp"
#include "op/split.hpp"
#include "utils/reshape.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace error
{
namespace op
{
namespace split
{
namespace detail
{
struct Error : ngraph_error
{
explicit Error(const std::string& name, const std::string& message)
: ngraph_error{"Split node (" + name + "): " + message}
{
}
};
} // namespace detail
struct OutOfRange : detail::Error
{
explicit OutOfRange(const std::string& name)
: Error{name,
"provided split axis is out of input tensor dimensions range."}
{
}
};
struct Parts : detail::Error
{
explicit Parts(const std::string& name,
std::size_t parts,
std::size_t axis_length)
: Error{name,
"tensor cannot be split into " + std::to_string(parts) +
" equal parts, along axis of length " +
std::to_string(axis_length)}
{
}
};
struct Sum : detail::Error
{
explicit Sum(const std::string& name, std::size_t parts, std::size_t axis)
: Error{name,
"provided lengths of split parts does not sum up to "
"length of axis we split on: " +
std::to_string(parts) + " != " + std::to_string(axis)}
{
}
};
} // namespace split
} // namespace op
} // namespace error
namespace op
{
namespace set_1
{
NodeVector split(const Node& node)
{
std::shared_ptr<ngraph::Node> input = node.get_ng_inputs().at(0);
auto input_shape = input->get_shape();
std::size_t count_outputs{node.get_output_names().size()};
int64_t axis{node.get_attribute_value<int64_t>("axis", 0)};
std::size_t axis_to_split{static_cast<std::size_t>(axis)};
if (axis < 0)
{
axis_to_split = input_shape.size() + axis;
}
else if (axis_to_split >= input_shape.size())
{
throw error::op::split::OutOfRange{node.get_name()};
}
std::size_t length_axis_to_split{input_shape.at(axis_to_split)};
std::vector<std::size_t> length_parts;
const auto input = node.get_ng_inputs().at(0);
const auto outputs_number = node.get_output_names().size();
const auto axis = node.get_attribute_value<int64_t>("axis", 0);
try
{
length_parts = node.get_attribute_value<std::vector<std::size_t>>("split");
const auto length_parts =
node.get_attribute_value<std::vector<std::size_t>>("split");
const auto fused_split =
std::make_shared<ngraph::op::Split>(input, axis, length_parts);
return fused_split->decompose_op();
}
catch (const std::exception&)
catch (const error::node::UnknownAttribute&)
{
if (length_axis_to_split % count_outputs)
{
throw error::op::split::Parts{
node.get_name(), count_outputs, length_axis_to_split};
}
length_parts.assign(count_outputs, length_axis_to_split / count_outputs);
}
// an exception will be caught if the input node does not contain
// the 'split' attribute - this means we should split the input tensor
// into same-length parts equal to the number of node outputs
const auto fused_split =
std::make_shared<ngraph::op::Split>(input, axis, outputs_number);
std::size_t total_parts_length = 0;
for (auto length : length_parts)
{
ASSERT_VALID_ARGUMENT(node, length > 0)
<< "Invalid value in 'split' attribute";
total_parts_length += length;
return fused_split->decompose_op();
}
ASSERT_VALID_ARGUMENT(node, total_parts_length == input_shape.at(axis_to_split))
<< "Cannot split using values in 'split' attribute";
return reshape::split(input, length_parts, axis_to_split);
}
} // namespace set_1
......
......@@ -35,33 +35,6 @@ namespace ngraph
{
namespace reshape
{
namespace
{
inline std::size_t get_valid_array_index(std::size_t idx, std::size_t axis_size)
{
return std::min(idx, axis_size);
}
std::shared_ptr<op::Slice> make_ng_slice(const std::shared_ptr<ngraph::Node>& node,
const std::vector<std::size_t>& axes,
const std::vector<std::size_t>& starts,
const std::vector<std::size_t>& ends)
{
std::vector<std::size_t> upper_bounds{node->get_shape()};
std::vector<std::size_t> lower_bounds(upper_bounds.size());
for (std::size_t index{0}; index < axes.size(); ++index)
{
std::size_t axis{axes.at(index)};
lower_bounds.at(axis) =
get_valid_array_index(starts.at(index), node->get_shape().at(axis));
upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), node->get_shape().at(axis));
}
return std::make_shared<op::Slice>(node, lower_bounds, upper_bounds);
}
} // namespace anonymous
std::vector<std::size_t> infer_dimensions(const std::string& node_name,
const std::vector<std::size_t>& input_shape,
const std::vector<std::size_t>& output_shape)
......@@ -171,36 +144,6 @@ namespace ngraph
node, ngraph::get_default_order(node->get_shape().size()), output_shape);
}
NodeVector split(const std::shared_ptr<ngraph::Node>& node,
const std::vector<std::size_t>& length_parts,
std::size_t axis)
{
std::size_t start_index{0};
NodeVector outputs;
for (const auto& length_part : length_parts)
{
std::size_t end_index{start_index + length_part};
outputs.push_back(make_ng_slice(node, {axis}, {start_index}, {end_index}));
start_index = end_index;
}
return outputs;
}
NodeVector
split(const std::shared_ptr<ngraph::Node>& node, std::size_t split_parts, int axis)
{
std::size_t axis_to_split{static_cast<std::size_t>(axis)};
if (axis < 0)
{
axis_to_split = node->get_shape().size() + axis;
}
std::size_t length_axis_to_split{node->get_shape().at(axis_to_split)};
std::vector<std::size_t> length_parts(split_parts,
length_axis_to_split / split_parts);
return split(node, length_parts, axis_to_split);
}
std::shared_ptr<ngraph::Node>
interpret_as_scalar(const std::shared_ptr<ngraph::Node>& node)
{
......
......@@ -86,35 +86,6 @@ namespace ngraph
std::shared_ptr<ngraph::Node> expand_dims(const std::shared_ptr<ngraph::Node>& node,
std::size_t axis = 0);
/// \brief Split node on specified axis into multiple parts.
///
/// \param[in] node The input node.
/// \param[in] length_parts The vector defining the lengts of each splitted part.
/// \param[in] axis The axis we split input node on. Default value is zero axis.
///
/// \return The vector containing multiple nodes we split input node into.
///
NodeVector split(const std::shared_ptr<ngraph::Node>& node,
const std::vector<std::size_t>& length_parts,
std::size_t axis = 0);
/// \brief Split node on specified axis into multiple parts.
///
/// \param[in] node The input node.
/// \param[in] split_parts The number of parts we want to split input node at given
/// axis. The length of the axis to split must be divisible by
/// this value.
/// \param[in] axis The axis we split input node on. Default value is zero axis.
///
/// \note This implementation supports negative `axis` values (similar to NumPy
/// indexing).
///
/// \return The vector containing multiple nodes we split input node into.
///
NodeVector split(const std::shared_ptr<ngraph::Node>& node,
std::size_t split_parts,
int axis = 0);
/// \brief Handle a node which represents a scalar value.
///
/// \note Some ONNX nodes, which should provide scalar values are given as
......
......@@ -108,6 +108,7 @@
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/split.hpp"
#include "ngraph/op/fused/squared_difference.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <numeric>
#include "ngraph/builder/split.hpp"
#include "ngraph/op/fused/split.hpp"
using namespace std;
using namespace ngraph;
// Even-split constructor: the input tensor will be divided into 'num_split' equal
// parts along 'axis'. The concrete part lengths (m_splits) are computed later, in
// pre_validate_and_infer_types(), once the input shape is known.
op::Split::Split(const shared_ptr<Node>& data, const int axis, const size_t num_split)
    : FusedOp("Split", {data})
    , m_split_evenly{true}
    , m_axis{axis}
    , m_num_split{num_split}
{
    constructor_validate_and_infer_types();
}
// Variable-length-split constructor: 'splits' lists the length of each output part
// along 'axis'; the lengths must sum to the dimension at that axis (validated in
// pre_validate_and_infer_types()).
op::Split::Split(const std::shared_ptr<ngraph::Node>& data,
                 const int axis,
                 const std::vector<size_t>& splits)
    : FusedOp("Split", {data})
    , m_split_evenly{false}
    , m_axis{axis}
    , m_num_split{0} // unused in this mode, but must not be left uninitialized
    , m_splits{splits}
{
    constructor_validate_and_infer_types();
}
// Validates the op's attributes against the input shape and, for the even-split
// mode, materializes m_splits with the computed equal part lengths.
void op::Split::pre_validate_and_infer_types()
{
    const auto shape = get_argument(0)->get_shape();

    // Normalize a possibly negative axis (NumPy-style) to an index into 'shape'.
    m_axis = static_cast<int>(adjust_axis_value(m_axis, shape.size()));
    NODE_VALIDATION_CHECK(this,
                          m_axis >= 0 && static_cast<size_t>(m_axis) < shape.size(),
                          "The 'axis' parameter for Split has to point to one of the "
                          "input tensor's shape dimensions.");

    const auto dimension_at_axis = shape.at(m_axis);
    if (m_split_evenly)
    {
        // Guard the modulo/division below: num_split == 0 would be undefined behavior.
        NODE_VALIDATION_CHECK(
            this, m_num_split > 0, "The 'num_split' parameter for Split has to be positive.");
        NODE_VALIDATION_CHECK(this,
                              dimension_at_axis % m_num_split == 0,
                              "The input tensor's dimension pointed by the 'axis' parameter: ",
                              dimension_at_axis,
                              " has to be a multiple of the 'num_split' parameter value: ",
                              m_num_split);

        // Expand 'num_split' into an explicit list of equal part lengths so that
        // decompose_op() can treat both construction modes uniformly.
        m_splits.assign(m_num_split, dimension_at_axis / m_num_split);
    }
    else
    {
        // size_t{0} (not 0UL) so the accumulation type matches size_t on all platforms.
        const auto sum_splits = accumulate(begin(m_splits), end(m_splits), size_t{0});
        NODE_VALIDATION_CHECK(this,
                              sum_splits == dimension_at_axis,
                              "The input tensor's dimension pointed by the 'axis' parameter: ",
                              dimension_at_axis,
                              " has to be equal to the sum of splits passed to the op: ",
                              sum_splits);
    }
}
// Lowers the fused op to a sequence of Slice nodes, one per output part.
// m_splits has been populated by pre_validate_and_infer_types() regardless of
// which constructor was used.
NodeVector op::Split::decompose_op() const
{
    const auto data = get_argument(0);
    return builder::split(data, m_splits, m_axis);
}
// Creates a clone of this op attached to new input nodes. The clone always uses the
// variable-length constructor: after validation, m_splits holds the effective part
// lengths for both construction modes.
shared_ptr<Node> op::Split::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
    const auto& new_data = new_args.at(0);
    return make_shared<Split>(new_data, m_axis, m_splits);
}
// Maps a possibly negative axis to a non-negative index: negative values are
// counted from the back of the tensor's shape (NumPy-style); non-negative values
// are returned unchanged.
size_t op::Split::adjust_axis_value(const int axis, const size_t input_tensor_rank) const
{
    return (axis < 0) ? axis + input_tensor_rank : axis;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
    namespace op
    {
        /// \brief Splits the input tensor into a list of smaller tensors ("pieces")
        class Split : public ngraph::op::util::FusedOp
        {
        public:
            /// \brief Constructs a Split op that evenly divides the input tensor.
            ///
            /// \param data - Node producing the input tensor
            /// \param axis - indicates an axis along which the input tensor should be split.
            ///               Negative values mean counting from the back of the input
            ///               tensor's shape.
            /// \param num_split - a number of "pieces" the input tensor will be split to
            Split(const std::shared_ptr<ngraph::Node>& data,
                  const int axis,
                  const size_t num_split);

            /// \brief Constructs a Split op that splits the input tensor into variable length
            ///        "pieces"
            ///
            /// \param data - Node producing the input tensor
            /// \param axis - indicates an axis along which the input tensor should be split.
            ///               Negative values mean counting from the back of the input
            ///               tensor's shape.
            /// \param splits - a list of lengths that the input tensor should be split to.
            ///                 Use this constructor to split the input tensor to variable
            ///                 length chunks.
            Split(const std::shared_ptr<ngraph::Node>& data,
                  const int axis,
                  const std::vector<size_t>& splits);

            void pre_validate_and_infer_types() override;

            virtual NodeVector decompose_op() const override;

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;

            size_t get_axis() const { return m_axis; }
            const std::vector<size_t>& get_splits() const { return m_splits; }
        private:
            /// \brief Adjusts the axis for negative values
            ///
            /// \note  Negative values mean that the API consumer wants to point the axis
            ///        location from the back of the tensor. This is similar to the way
            ///        NumPy works.
            ///
            /// \param axis - original axis value; negative values are accepted
            /// \param input_tensor_rank - rank of the input data tensor
            /// \return A non-negative axis index (axis + rank for negative inputs, or the
            ///         axis value itself otherwise)
            size_t adjust_axis_value(const int axis, const size_t input_tensor_rank) const;

            /// used internally for validation purposes, indicates which constructor was used
            bool m_split_evenly;
            int m_axis;
            // Default-initialized because the variable-length constructor does not set it;
            // only meaningful when m_split_evenly is true.
            size_t m_num_split{0};
            /// contains lengths of chunks that the input tensor will be split into
            std::vector<size_t> m_splits;
        };
    }
}
......@@ -34,4 +34,5 @@ NGRAPH_OP(ScaleShift, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
NGRAPH_OP(SquaredDifference, ngraph::op)
NGRAPH_OP(Squeeze, ngraph::op)
NGRAPH_OP(Split, ngraph::op)
NGRAPH_OP(Unsqueeze, ngraph::op)
......@@ -2078,6 +2078,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::ScatterNDAdd:
case OP_TYPEID::ShapeOf:
case OP_TYPEID::SpaceToDepth:
case OP_TYPEID::Split:
case OP_TYPEID::SquaredDifference:
case OP_TYPEID::Squeeze:
case OP_TYPEID::StopGradient:
......@@ -2178,6 +2179,7 @@ bool runtime::intelgpu::IntelGPUBackend::is_supported_impl(const Node& node)
case OP_TYPEID::PRelu:
case OP_TYPEID::ScaleShift:
case OP_TYPEID::SpaceToDepth:
case OP_TYPEID::Split:
case OP_TYPEID::SquaredDifference:
case OP_TYPEID::Squeeze:
case OP_TYPEID::Unsqueeze: { return false;
......
......@@ -79,6 +79,7 @@
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/split.hpp"
#include "ngraph/op/fused/squared_difference.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
......@@ -1450,6 +1451,13 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::SpaceToDepth>(args[0], block_size);
break;
}
case OP_TYPEID::Split:
{
const auto axis = node_js.at("axis").get<size_t>();
const auto splits = node_js.at("splits").get<vector<size_t>>();
node = make_shared<op::Split>(args[0], axis, splits);
break;
}
case OP_TYPEID::Sqrt:
{
node = make_shared<op::Sqrt>(args[0]);
......@@ -2208,6 +2216,13 @@ static json write(const Node& n, bool binary_constant_data)
node["block_size"] = tmp->get_block_size();
break;
}
case OP_TYPEID::Split:
{
auto tmp = dynamic_cast<const op::Split*>(&n);
node["axis"] = tmp->get_axis();
node["splits"] = tmp->get_splits();
break;
}
case OP_TYPEID::Sqrt: { break;
}
case OP_TYPEID::SquaredDifference: { break;
......
......@@ -900,3 +900,37 @@ NGRAPH_TEST(${BACKEND_NAME}, squared_difference_broadcast)
test_case.add_expected_output<int32_t>(Shape{2, 2}, {0, 0, 0, 0});
test_case.run();
}
// Splits a 1D tensor of 6 elements along axis 0 into 3 equal parts of 2 elements each.
NGRAPH_TEST(${BACKEND_NAME}, split_3_equal_parts)
{
    const auto data = make_shared<op::Parameter>(element::i32, Shape{6});

    const auto tested_op = make_shared<op::Split>(data, 0, 3);
    // Execute the fused op's lowered (decomposed) subgraph rather than the op itself.
    const auto function = make_shared<Function>(tested_op->decompose_op(), ParameterVector{data});

    auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
    test_case.add_input<int32_t>({1, 2, 3, 4, 5, 6});

    // One expected output per split part, in order.
    test_case.add_expected_output<int32_t>(Shape{2}, {1, 2});
    test_case.add_expected_output<int32_t>(Shape{2}, {3, 4});
    test_case.add_expected_output<int32_t>(Shape{2}, {5, 6});

    test_case.run();
}
// Splits a 2x6 tensor along axis 1 into variable-length parts of widths 2 and 4.
NGRAPH_TEST(${BACKEND_NAME}, split_var_len_parts)
{
    const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});

    const std::vector<size_t> splits = {2, 4};
    const auto tested_op = make_shared<op::Split>(data, 1, splits);
    // Execute the fused op's lowered (decomposed) subgraph rather than the op itself.
    const auto function = make_shared<Function>(tested_op->decompose_op(), ParameterVector{data});

    auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
    // Input laid out row-major: row 0 = {0..5}, row 1 = {6..11}.
    test_case.add_input<int32_t>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});

    // First part: columns [0,2); second part: columns [2,6) — for both rows.
    test_case.add_expected_output<int32_t>(Shape{2, 2}, {0, 1, 6, 7});
    test_case.add_expected_output<int32_t>(Shape{2, 4}, {2, 3, 4, 5, 8, 9, 10, 11});

    test_case.run();
}
......@@ -14546,3 +14546,40 @@ TEST(type_prop, squared_difference)
EXPECT_EQ(clamp->get_element_type(), element::f64);
EXPECT_EQ(clamp->get_shape(), (Shape{2, 2}));
}
// Type propagation tests for the fused Split op: two invalid-construction cases
// followed by a valid even split whose output shapes and types are verified.
TEST(type_prop, split)
{
    const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});

    try
    {
        // Splits that do not sum to the axis length must be rejected at construction.
        const std::vector<size_t> splits = {1, 6}; // should sum up to 6
        const auto split = make_shared<op::Split>(data, 1, splits);
        FAIL() << "Split node was created with incorrect data.";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(), std::string("has to be equal to the sum of splits passed to the op: 7"));
    }

    try
    {
        const std::vector<size_t> splits = {4, 2};
        // Axis -5 is outside the valid range for a rank-2 tensor and must be rejected.
        const auto split = make_shared<op::Split>(data, -5, splits); //invalid axis
        FAIL() << "Split node was created with incorrect data.";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("The 'axis' parameter for Split has to point to one of "
                                         "the input tensor's shape dimensions."));
    }

    // Valid even split: 2x6 tensor split into 2 parts along axis 1 -> two 2x3 outputs.
    const auto split = make_shared<op::Split>(data, 1, 2);
    EXPECT_EQ(split->outputs().size(), 2);
    EXPECT_EQ(split->output(0).get_shape(), (Shape{2, 3}));
    EXPECT_EQ(split->output(1).get_shape(), (Shape{2, 3}));
    EXPECT_EQ(split->output(0).get_element_type(), element::i32);
    EXPECT_EQ(split->output(1).get_element_type(), element::i32);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment