Commit 54b9d9aa authored by mbencer's avatar mbencer

Added builder opset1 split

parent 3d9004c0
......@@ -15,7 +15,9 @@
//*****************************************************************************
#include "ngraph/builder/split.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/opsets/opset1.hpp"
using namespace ngraph;
......@@ -45,6 +47,29 @@ namespace
std::make_shared<op::Slice>(output, lower_bounds, upper_bounds)
->add_provenance_group_members_above({output}));
}
/// \brief Return the outputs of the node as vector.
///
/// \param[in] node Node with multiple outputs.
///
/// \return Vector of outputs of input node.
NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node)
{
const auto outputs_number = node->get_output_size();
ngraph::NodeVector outputs(outputs_number);
for (int i = 0; i < outputs_number; ++i)
{
if (node->output(i).get_node_shared_ptr()->get_output_size() == 1)
{
outputs[i] = node->get_output_as_single_output_node(i);
}
else
{
outputs[i] = std::make_shared<op::GetOutputElement>(node, i);
}
}
return outputs;
}
}
NodeVector builder::split(const Output<ngraph::Node>& value,
......@@ -74,3 +99,24 @@ NodeVector builder::split(const Output<Node>& value, size_t split_parts, int axi
std::vector<size_t> length_parts(split_parts, length_axis_to_split / split_parts);
return split(value, length_parts, axis_to_split);
}
NodeVector builder::opset1::split(const Output<Node>& value,
const std::vector<size_t>& length_parts,
int64_t axis)
{
const auto axis_node = op::Constant::create(element::u64, Shape{}, {axis});
const auto length_parts_node =
op::Constant::create(element::u64, Shape{length_parts.size()}, length_parts);
const auto variadic_split =
std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, length_parts_node);
return get_outputs(variadic_split);
}
NodeVector builder::opset1::split(const Output<Node>& value, size_t split_parts, int64_t axis)
{
const auto axis_node = op::Constant::create(element::u64, Shape{}, {axis});
const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, split_parts);
return get_outputs(split);
}
......@@ -50,5 +50,38 @@ namespace ngraph
/// \return The vector containing multiple nodes we split input node into.
///
NodeVector split(const Output<Node>& value, size_t split_parts, int axis = 0);
namespace opset1
{
/// \brief Split value on specified axis into multiple parts.
///
/// \param[in] value The value to be split.
/// \param[in] length_parts The vector defining the lengths of each split part.
/// \param[in] axis The axis we split input node on. Default value is zero
/// axis.
///
/// \return The vector containing multiple nodes we split input node into.
///
NodeVector split(const Output<Node>& value,
const std::vector<size_t>& length_parts,
int64_t axis = 0);
/// \brief Split node on specified axis into multiple parts.
///
/// \param[in] value The value to split.
/// \param[in] split_parts The number of parts we want to split output at given
/// axis. The length of the axis to split must be divisible by
/// this value.
/// \param[in] axis The axis we split input node on. Default value is zero
/// axis.
///
/// \note This implementation supports negative `axis` values (similar to NumPy
/// indexing). This means that the axis to split on will be counted from
/// the back of the tensor (negative values are subtracted from its rank).
///
/// \return The vector containing multiple nodes we split input node into.
///
NodeVector split(const Output<Node>& value, size_t split_parts, int64_t axis = 0);
}
} // namespace builder
} // namespace ngraph
......@@ -18,8 +18,8 @@
#include <vector>
#include "default_opset.hpp"
#include "ngraph/builder/split.hpp"
#include "split.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -33,28 +33,18 @@ namespace ngraph
{
const auto input = node.get_ng_inputs().at(0);
const auto axis = node.get_attribute_value<int64_t>("axis", 0);
const auto axis_node =
default_opset::Constant::create(element::i64, Shape{}, {axis});
std::shared_ptr<ngraph::Node> split;
if (node.has_attribute("split"))
{
const auto splits =
node.get_attribute_value<std::vector<std::size_t>>("split");
const auto split_lengths = default_opset::Constant::create(
element::u64, Shape{splits.size()}, splits);
split = std::make_shared<default_opset::VariadicSplit>(
input, axis_node, split_lengths);
return ngraph::builder::opset1::split(input, splits, axis);
}
else
{
const auto outputs_number = node.get_output_names().size();
split = std::make_shared<default_opset::Split>(
input, axis_node, outputs_number);
return ngraph::builder::opset1::split(input, outputs_number, axis);
}
return common::get_outputs(split);
}
} // namespace set_1
......
......@@ -18,7 +18,6 @@
#include "common.hpp"
#include "default_opset.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "validation_util.hpp"
......@@ -51,24 +50,6 @@ namespace ngraph
static_cast<onnx::TensorProto_DataType>(onnx_type)));
}
ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node)
{
const auto outputs_number = node->get_output_size();
ngraph::NodeVector outputs(outputs_number);
for (int i = 0; i < outputs_number; ++i)
{
if (node->output(i).get_node_shared_ptr()->get_output_size() == 1)
{
outputs[i] = node->get_output_as_single_output_node(i);
}
else
{
outputs[i] = std::make_shared<ngraph::opset0::GetOutputElement>(node, i);
}
}
return outputs;
}
} // namespace common
} // namespace onnx_import
} // namespace ngraph
......@@ -69,13 +69,6 @@ namespace ngraph
return range;
}
/// \brief Return the outputs of the node as vector.
///
/// \param[in] node Node with multiple outputs.
///
/// \return Vector of outputs of input node.
ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node);
/// \brief Creates a shifted square identity matrix.
/// \note Shifting in the context of this operator means that
/// the matrix can be created with elements equal to 1 not only in the main
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment