Unverified Commit d9242244 authored by Sang Ik Lee's avatar Sang Ik Lee Committed by GitHub

Merge pull request #4311 from NervanaSystems/mbencer/BuilderSplitV1

[ONNX] builder::split used by onnx importer should produce v1 ops
parents 6c0cf85a 9664b96f
......@@ -15,7 +15,9 @@
//*****************************************************************************
#include "ngraph/builder/split.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/opsets/opset1.hpp"
using namespace ngraph;
......@@ -45,6 +47,29 @@ namespace
std::make_shared<op::Slice>(output, lower_bounds, upper_bounds)
->add_provenance_group_members_above({output}));
}
/// \brief Return the outputs of the node as vector.
///
/// \param[in] node Node with multiple outputs.
///
/// \return Vector of outputs of input node.
NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node)
{
const auto outputs_number = node->get_output_size();
ngraph::NodeVector outputs(outputs_number);
for (int i = 0; i < outputs_number; ++i)
{
if (node->output(i).get_node_shared_ptr()->get_output_size() == 1)
{
outputs[i] = node->get_output_as_single_output_node(i);
}
else
{
outputs[i] = std::make_shared<op::GetOutputElement>(node, i);
}
}
return outputs;
}
}
NodeVector builder::split(const Output<ngraph::Node>& value,
......@@ -74,3 +99,24 @@ NodeVector builder::split(const Output<Node>& value, size_t split_parts, int axi
std::vector<size_t> length_parts(split_parts, length_axis_to_split / split_parts);
return split(value, length_parts, axis_to_split);
}
NodeVector builder::opset1::split(const Output<Node>& value,
const std::vector<size_t>& split_lengths,
int64_t axis)
{
const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
const auto split_lengths_node =
ngraph::opset1::Constant::create(element::u64, Shape{split_lengths.size()}, split_lengths);
const auto variadic_split =
std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, split_lengths_node);
return get_outputs(variadic_split);
}
NodeVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis)
{
const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits);
return get_outputs(split);
}
......@@ -23,11 +23,11 @@ namespace ngraph
{
namespace builder
{
/// \brief Split value on specified axis into multiple parts.
/// \brief Split value on specified axis into multiple parts.
///
/// \param[in] value The value to be split.
/// \param[in] length_parts The vector defining the lengths of each split part.
/// \param[in] axis The axis we split input node on. Default value is zero axis.
/// \param value The value to be split.
/// \param length_parts The vector defining the lengths of each split part.
/// \param axis The axis we split input node on. Default value is zero axis.
///
/// \return The vector containing multiple nodes we split input node into.
///
......@@ -37,11 +37,11 @@ namespace ngraph
/// \brief Split node on specified axis into multiple parts.
///
/// \param[in] value The value to split.
/// \param[in] split_parts The number of parts we want to split output at given
/// axis. The length of the axis to split must be divisible by
/// this value.
/// \param[in] axis The axis we split input node on. Default value is zero axis.
/// \param value The value to split.
/// \param split_parts The number of parts we want to split output at given
/// axis. The length of the axis to split must be divisible by
/// this value.
/// \param axis The axis we split input node on. Default value is zero axis.
///
/// \note This implementation supports negative `axis` values (similar to NumPy
/// indexing). This means that the axis to split on will be counted from
......@@ -50,5 +50,43 @@ namespace ngraph
/// \return The vector containing multiple nodes we split input node into.
///
NodeVector split(const Output<Node>& value, size_t split_parts, int axis = 0);
namespace opset1
{
/// \brief Split value on specified axis into multiple parts.
///
/// \param value The value to be split.
/// \param split_lengths The vector defining the lengths of each split part.
/// \param axis The axis we split input node on. Default value is zero
/// axis.
/// \note This implementation supports negative `axis` values (similar to NumPy
/// indexing). This means that the axis to split on will be counted from
/// the back of the tensor (negative values are subtracted from its rank).
///
/// \return The vector containing multiple nodes we split input node into.
/// The vector is output of Split:v1 op
///
NodeVector split(const Output<Node>& value,
const std::vector<size_t>& split_lengths,
int64_t axis = 0);
/// \brief Split value on specified axis into multiple parts.
///
/// \param value The value to split.
/// \param num_splits The number of parts we want to split output at given
/// axis. The length of the axis to split must be divisible by
/// this value.
/// \param axis The axis we split input node on. Default value is zero
/// axis.
///
/// \note This implementation supports negative `axis` values (similar to NumPy
/// indexing). This means that the axis to split on will be counted from
/// the back of the tensor (negative values are subtracted from its rank).
///
/// \return The vector containing multiple nodes we split input node into.
/// The vector is output of VariadicSplit:v1 op
///
NodeVector split(const Output<Node>& value, size_t num_splits, int64_t axis = 0);
}
} // namespace builder
} // namespace ngraph
......@@ -45,7 +45,8 @@ namespace ngraph
ASSERT_VALID_ARGUMENT(node, p_norm >= 0)
<< "Only positive (including zero) values are supported for 'p' attribute.";
NodeVector slices = ngraph::builder::split(data, channels_count, channel_axis);
NodeVector slices =
ngraph::builder::opset1::split(data, channels_count, channel_axis);
for (auto& slice : slices)
{
......
......@@ -91,7 +91,7 @@ namespace ngraph
if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
{
auto bias = ng_inputs.at(3);
auto split_bias = builder::split(bias, 2, 1);
auto split_bias = builder::opset1::split(bias, 2, 1);
m_map[LSTMInput::LSTM_INPUT_B] = split_bias.at(0) + split_bias.at(1);
}
else
......
......@@ -134,7 +134,7 @@ namespace ngraph
{
auto axis =
default_opset::Constant::create(element::i64, ngraph::Shape{}, {0});
NodeVector padding = builder::split(pads, 2, 0);
NodeVector padding = builder::opset1::split(pads, 2, 0);
padding_begin =
std::make_shared<default_opset::Convert>(padding.at(0), element::i64);
......
......@@ -18,8 +18,8 @@
#include <vector>
#include "default_opset.hpp"
#include "ngraph/builder/split.hpp"
#include "split.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -33,28 +33,18 @@ namespace ngraph
{
const auto input = node.get_ng_inputs().at(0);
const auto axis = node.get_attribute_value<int64_t>("axis", 0);
const auto axis_node =
default_opset::Constant::create(element::i64, Shape{}, {axis});
std::shared_ptr<ngraph::Node> split;
if (node.has_attribute("split"))
{
const auto splits =
node.get_attribute_value<std::vector<std::size_t>>("split");
const auto split_lengths = default_opset::Constant::create(
element::u64, Shape{splits.size()}, splits);
split = std::make_shared<default_opset::VariadicSplit>(
input, axis_node, split_lengths);
return ngraph::builder::opset1::split(input, splits, axis);
}
else
{
const auto outputs_number = node.get_output_names().size();
split = std::make_shared<default_opset::Split>(
input, axis_node, outputs_number);
return ngraph::builder::opset1::split(input, outputs_number, axis);
}
return common::get_outputs(split);
}
} // namespace set_1
......
......@@ -18,8 +18,6 @@
#include "common.hpp"
#include "default_opset.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/opsets/opset0.hpp"
namespace ngraph
......@@ -51,24 +49,6 @@ namespace ngraph
static_cast<onnx::TensorProto_DataType>(onnx_type)));
}
ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node)
{
const auto outputs_number = node->get_output_size();
ngraph::NodeVector outputs(outputs_number);
for (int i = 0; i < outputs_number; ++i)
{
if (node->output(i).get_node_shared_ptr()->get_output_size() == 1)
{
outputs[i] = node->get_output_as_single_output_node(i);
}
else
{
outputs[i] = std::make_shared<ngraph::opset0::GetOutputElement>(node, i);
}
}
return outputs;
}
} // namespace common
} // namespace onnx_import
} // namespace ngraph
......@@ -68,13 +68,6 @@ namespace ngraph
return range;
}
/// \brief Return the outputs of the node as vector.
///
/// \param[in] node Node with multiple outputs.
///
/// \return Vector of outputs of input node.
ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node);
/// \brief Creates a shifted square identity matrix.
/// \note Shifting in the context of this operator means that
/// the matrix can be created with elements equal to 1 not only in the main
......
......@@ -78,7 +78,7 @@ namespace ngraph
///
/// \brief Factory class which generates sub-graphs for ONNX 'local' pooling
/// operators.
/// \note Kernel shape attribute is required
/// \note For a 'local' pooling operation, the kernel shape attribute is required
class LocalPoolingFactory : public PoolingFactory
{
public:
......@@ -89,7 +89,8 @@ namespace ngraph
///
/// \brief Factory class which generates sub-graphs for ONNX 'global' pooling
/// operators.
/// \note Kernel shape is calculated based on spatial dims
/// \note In a 'global' pooling operation, the kernel shape is calculated
/// based on spatial dims
class GlobalPoolingFactory : public PoolingFactory
{
public:
......
......@@ -37,6 +37,7 @@
#include "ngraph/pass/get_output_element_elimination.hpp"
#include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_compiled_function.hpp"
......@@ -148,6 +149,7 @@ void runtime::gpu::GPUCompiledFunction::compile()
#endif
pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::Opset0Downgrade>();
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
pass_manager.register_pass<ngraph::pass::ImplicitBroadcastElimination>();
pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment