Commit 2a75c961 authored by Tomasz Socha's avatar Tomasz Socha Committed by Sang Ik Lee

[SPEC] Add new v1::VariadicSplit operator (#3868)

* [SPEC] Add new v1::VariadicSplit operator

* Add missing namespace, fix a typo in doc

* Apply suggestions from code review
Co-Authored-By: 's avatarMichał Karzyński <postrational@users.noreply.github.com>

* Style fix

* Set all of the inputs to be relevant to output shape

* Set output type if numer of outputs is known

* Add node validation for known input
parent 3eb99596
......@@ -323,6 +323,8 @@ set (SRC
op/subtract.hpp
op/sum.cpp
op/sum.hpp
op/variadic_split.cpp
op/variadic_split.hpp
op/tan.cpp
op/tan.hpp
op/tanh.cpp
......
......@@ -221,6 +221,7 @@ namespace ngraph
#include "ngraph/op/tensor_iterator.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/variadic_split.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/runtime/backend.hpp"
......
......@@ -167,4 +167,5 @@ NGRAPH_OP(Tanh, ngraph::op)
NGRAPH_OP(Tile, ngraph::op)
NGRAPH_OP(TopK, ngraph::op)
NGRAPH_OP(Transpose, ngraph::op)
NGRAPH_OP(VariadicSplit, ngraph::op)
NGRAPH_OP(Xor, ngraph::op)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <numeric>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/variadic_split.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v1::VariadicSplit::type_info;
op::v1::VariadicSplit::VariadicSplit(const Output<Node>& data,
const Output<Node>& axis,
const Output<Node>& split_lengths)
: Op({data, axis, split_lengths})
{
constructor_validate_and_infer_types();
}
void ngraph::op::v1::VariadicSplit::validate_and_infer_types()
{
set_input_is_relevant_to_value(0);
set_input_is_relevant_to_value(1);
set_input_is_relevant_to_value(2);
auto split_lengths_pshape_rank = get_input_partial_shape(2).rank();
if (split_lengths_pshape_rank.is_static())
{
auto num_outputs = static_cast<size_t>(split_lengths_pshape_rank);
auto data = input_value(0);
auto axis_input = input_value(1).get_node_shared_ptr();
auto split_lengths_input = input_value(2).get_node_shared_ptr();
auto data_shape = data.get_partial_shape();
auto data_type = data.get_element_type();
set_output_size(num_outputs);
if (data_shape.is_static() && axis_input->is_constant() &&
split_lengths_input->is_constant())
{
auto axis = as_type_ptr<op::Constant>(axis_input)->get_vector<size_t>()[0];
auto split_lengths = as_type_ptr<op::Constant>(axis_input)->get_vector<size_t>();
auto splits_length = std::accumulate(split_lengths.begin(), split_lengths.end(), 0UL);
NODE_VALIDATION_CHECK(this, axis > 0, "Provided axis:", axis, " can not be negative");
auto data_rank = static_cast<size_t>(data_shape.rank());
NODE_VALIDATION_CHECK(this,
axis < data_rank,
"Provided axis:",
axis,
" can not be higher than input data rank: ",
data_rank);
NODE_VALIDATION_CHECK(this,
splits_length == static_cast<size_t>(data_shape[axis]),
"Total length of splits:",
splits_length,
" does not sum to length of the choosen axis: ",
static_cast<size_t>(data_shape[axis]));
for (size_t output{0}; output < num_outputs; ++output)
{
auto tmp_shape = data_shape.to_shape();
tmp_shape.at(axis) = split_lengths.at(axis);
set_output_type(output, data_type, tmp_shape);
}
}
else
{
for (size_t output{0}; output < num_outputs; ++output)
{
set_output_type(output, data_type, PartialShape::dynamic());
}
}
}
}
shared_ptr<Node> op::v1::VariadicSplit::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::VariadicSplit>(new_args.at(0), new_args.at(1), new_args.at(2));
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/coordinate.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/strides.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
/// \brief VariadicSplit operation splits an input tensor into pieces along some axis.
/// The pieces may have variadic lengths depending on "split_lengths" attribute.
class VariadicSplit : public Op
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"VariadicSplit", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a variadic split operation.
VariadicSplit() = default;
/// \brief Constructs a variadic split operation.
///
/// \param data The tensor to be split.
/// \param axis The index of an axis in "data" along which to perform the
/// split.
/// \param split_lengths A list containing the sizes of each output tensor
/// along the split "axis". Size of "split_lengths" should be equal to the number of
///
/// outputs. The sum of split_lengths must match data.shape[axis]
VariadicSplit(const Output<Node>& data,
const Output<Node>& axis,
const Output<Node>& split_lengths);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
} // namespace v1
using v1::VariadicSplit;
} // namespace op
} // namespace ngraph
......@@ -1868,6 +1868,7 @@ private:
case OP_TYPEID::Tile:
case OP_TYPEID::DynReplaceSlice:
case OP_TYPEID::FloorMod:
case OP_TYPEID::VariadicSplit:
throw unsupported_op("Unsupported op '" + node.description() + "'");
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
......
......@@ -161,6 +161,7 @@
#include "ngraph/op/tensor_iterator.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/variadic_split.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/provenance.hpp"
#include "ngraph/serializer.hpp"
......@@ -2787,6 +2788,11 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Unsqueeze>(args[0], args[1]);
break;
}
case OP_TYPEID::VariadicSplit:
{
node = make_shared<op::v1::VariadicSplit>(args[0], args[1], args[2]);
break;
}
case OP_TYPEID::Xor:
{
node = make_shared<op::v0::Xor>(
......@@ -4278,6 +4284,8 @@ json JSONSerializer::serialize_node(const Node& n)
}
break;
}
case OP_TYPEID::VariadicSplit: { break;
}
case OP_TYPEID::UnknownOp: { break;
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment