Unverified Commit d9242244 authored by Sang Ik Lee's avatar Sang Ik Lee Committed by GitHub

Merge pull request #4311 from NervanaSystems/mbencer/BuilderSplitV1

[ONNX] builder::split used by onnx importer should produce v1 ops
parents 6c0cf85a 9664b96f
...@@ -15,7 +15,9 @@ ...@@ -15,7 +15,9 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/builder/split.hpp" #include "ngraph/builder/split.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/opsets/opset1.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -45,6 +47,29 @@ namespace ...@@ -45,6 +47,29 @@ namespace
std::make_shared<op::Slice>(output, lower_bounds, upper_bounds) std::make_shared<op::Slice>(output, lower_bounds, upper_bounds)
->add_provenance_group_members_above({output})); ->add_provenance_group_members_above({output}));
} }
/// \brief      Return the outputs of the node as a vector.
///
/// \param[in]  node  Node with one or more outputs.
///
/// \return     Vector of (single-output) nodes, one per output of \p node.
NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node)
{
    const auto outputs_number = node->get_output_size();
    ngraph::NodeVector outputs(outputs_number);
    // get_output_size() returns size_t; use size_t for the index to avoid
    // a signed/unsigned comparison in the loop condition.
    for (size_t i = 0; i < outputs_number; ++i)
    {
        if (node->output(i).get_node_shared_ptr()->get_output_size() == 1)
        {
            // Producer already has a single output - no wrapper needed.
            outputs[i] = node->get_output_as_single_output_node(i);
        }
        else
        {
            // Multi-output producer: wrap output i in a GetOutputElement node.
            outputs[i] = std::make_shared<op::GetOutputElement>(node, i);
        }
    }
    return outputs;
}
} }
NodeVector builder::split(const Output<ngraph::Node>& value, NodeVector builder::split(const Output<ngraph::Node>& value,
...@@ -74,3 +99,24 @@ NodeVector builder::split(const Output<Node>& value, size_t split_parts, int axi ...@@ -74,3 +99,24 @@ NodeVector builder::split(const Output<Node>& value, size_t split_parts, int axi
std::vector<size_t> length_parts(split_parts, length_axis_to_split / split_parts); std::vector<size_t> length_parts(split_parts, length_axis_to_split / split_parts);
return split(value, length_parts, axis_to_split); return split(value, length_parts, axis_to_split);
} }
/// \brief      Split value on the specified axis into parts of the given lengths.
///
/// \param[in]  value          The value to be split.
/// \param[in]  split_lengths  The lengths of each split part along \p axis.
/// \param[in]  axis           The axis to split on; negative values count from
///                            the back of the tensor's rank.
///
/// \return     Outputs of the produced VariadicSplit:v1 op, one node per part.
NodeVector builder::opset1::split(const Output<Node>& value,
                                  const std::vector<size_t>& split_lengths,
                                  int64_t axis)
{
    // The axis is signed and may be negative (NumPy-style indexing), so the
    // constant must use a signed element type - u64 would corrupt negative values.
    const auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{}, {axis});
    const auto split_lengths_node =
        ngraph::opset1::Constant::create(element::u64, Shape{split_lengths.size()}, split_lengths);
    const auto variadic_split =
        std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, split_lengths_node);

    return get_outputs(variadic_split);
}
/// \brief      Split value on the specified axis into num_splits equal parts.
///
/// \param[in]  value       The value to be split.
/// \param[in]  num_splits  Number of equal parts; the axis length must be
///                         divisible by this value.
/// \param[in]  axis        The axis to split on; negative values count from
///                         the back of the tensor's rank.
///
/// \return     Outputs of the produced Split:v1 op, one node per part.
NodeVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis)
{
    // Axis may be negative, so create the constant with a signed element type.
    const auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{}, {axis});
    const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits);

    return get_outputs(split);
}
...@@ -25,9 +25,9 @@ namespace ngraph ...@@ -25,9 +25,9 @@ namespace ngraph
{ {
/// \brief Split value on specified axis into multiple parts. /// \brief Split value on specified axis into multiple parts.
/// ///
/// \param[in] value The value to be split. /// \param value The value to be split.
/// \param[in] length_parts The vector defining the lengths of each split part. /// \param length_parts The vector defining the lengths of each split part.
/// \param[in] axis The axis we split input node on. Default value is zero axis. /// \param axis The axis we split input node on. Default value is zero axis.
/// ///
/// \return The vector containing multiple nodes we split input node into. /// \return The vector containing multiple nodes we split input node into.
/// ///
...@@ -37,11 +37,11 @@ namespace ngraph ...@@ -37,11 +37,11 @@ namespace ngraph
/// \brief Split node on specified axis into multiple parts. /// \brief Split node on specified axis into multiple parts.
/// ///
/// \param[in] value The value to split. /// \param value The value to split.
/// \param[in] split_parts The number of parts we want to split output at given /// \param split_parts The number of parts we want to split output at given
/// axis. The length of the axis to split must be divisible by /// axis. The length of the axis to split must be divisible by
/// this value. /// this value.
/// \param[in] axis The axis we split input node on. Default value is zero axis. /// \param axis The axis we split input node on. Default value is zero axis.
/// ///
/// \note This implementation supports negative `axis` values (similar to NumPy /// \note This implementation supports negative `axis` values (similar to NumPy
/// indexing). This means that the axis to split on will be counted from /// indexing). This means that the axis to split on will be counted from
...@@ -50,5 +50,43 @@ namespace ngraph ...@@ -50,5 +50,43 @@ namespace ngraph
/// \return The vector containing multiple nodes we split input node into. /// \return The vector containing multiple nodes we split input node into.
/// ///
NodeVector split(const Output<Node>& value, size_t split_parts, int axis = 0); NodeVector split(const Output<Node>& value, size_t split_parts, int axis = 0);
namespace opset1
{
/// \brief      Split value on specified axis into multiple parts.
///
/// \param  value          The value to be split.
/// \param  split_lengths  The vector defining the lengths of each split part.
/// \param  axis           The axis we split input node on. Default value is zero
///                        axis.
/// \note       This implementation supports negative `axis` values (similar to NumPy
///             indexing). This means that the axis to split on will be counted from
///             the back of the tensor (negative values are subtracted from its rank).
///
/// \return     The vector containing multiple nodes we split input node into.
///             The vector is output of VariadicSplit:v1 op
///
NodeVector split(const Output<Node>& value,
                 const std::vector<size_t>& split_lengths,
                 int64_t axis = 0);
/// \brief      Split value on specified axis into multiple parts.
///
/// \param  value       The value to split.
/// \param  num_splits  The number of parts we want to split output at given
///                     axis. The length of the axis to split must be divisible by
///                     this value.
/// \param  axis        The axis we split input node on. Default value is zero
///                     axis.
///
/// \note       This implementation supports negative `axis` values (similar to NumPy
///             indexing). This means that the axis to split on will be counted from
///             the back of the tensor (negative values are subtracted from its rank).
///
/// \return     The vector containing multiple nodes we split input node into.
///             The vector is output of Split:v1 op
///
NodeVector split(const Output<Node>& value, size_t num_splits, int64_t axis = 0);
}
} // namespace builder } // namespace builder
} // namespace ngraph } // namespace ngraph
...@@ -45,7 +45,8 @@ namespace ngraph ...@@ -45,7 +45,8 @@ namespace ngraph
ASSERT_VALID_ARGUMENT(node, p_norm >= 0) ASSERT_VALID_ARGUMENT(node, p_norm >= 0)
<< "Only positive (including zero) values are supported for 'p' attribute."; << "Only positive (including zero) values are supported for 'p' attribute.";
NodeVector slices = ngraph::builder::split(data, channels_count, channel_axis); NodeVector slices =
ngraph::builder::opset1::split(data, channels_count, channel_axis);
for (auto& slice : slices) for (auto& slice : slices)
{ {
......
...@@ -91,7 +91,7 @@ namespace ngraph ...@@ -91,7 +91,7 @@ namespace ngraph
if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null()) if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
{ {
auto bias = ng_inputs.at(3); auto bias = ng_inputs.at(3);
auto split_bias = builder::split(bias, 2, 1); auto split_bias = builder::opset1::split(bias, 2, 1);
m_map[LSTMInput::LSTM_INPUT_B] = split_bias.at(0) + split_bias.at(1); m_map[LSTMInput::LSTM_INPUT_B] = split_bias.at(0) + split_bias.at(1);
} }
else else
......
...@@ -134,7 +134,7 @@ namespace ngraph ...@@ -134,7 +134,7 @@ namespace ngraph
{ {
auto axis = auto axis =
default_opset::Constant::create(element::i64, ngraph::Shape{}, {0}); default_opset::Constant::create(element::i64, ngraph::Shape{}, {0});
NodeVector padding = builder::split(pads, 2, 0); NodeVector padding = builder::opset1::split(pads, 2, 0);
padding_begin = padding_begin =
std::make_shared<default_opset::Convert>(padding.at(0), element::i64); std::make_shared<default_opset::Convert>(padding.at(0), element::i64);
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
#include <vector> #include <vector>
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/builder/split.hpp"
#include "split.hpp" #include "split.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -33,28 +33,18 @@ namespace ngraph ...@@ -33,28 +33,18 @@ namespace ngraph
{ {
const auto input = node.get_ng_inputs().at(0); const auto input = node.get_ng_inputs().at(0);
const auto axis = node.get_attribute_value<int64_t>("axis", 0); const auto axis = node.get_attribute_value<int64_t>("axis", 0);
const auto axis_node =
default_opset::Constant::create(element::i64, Shape{}, {axis});
std::shared_ptr<ngraph::Node> split;
if (node.has_attribute("split")) if (node.has_attribute("split"))
{ {
const auto splits = const auto splits =
node.get_attribute_value<std::vector<std::size_t>>("split"); node.get_attribute_value<std::vector<std::size_t>>("split");
return ngraph::builder::opset1::split(input, splits, axis);
const auto split_lengths = default_opset::Constant::create(
element::u64, Shape{splits.size()}, splits);
split = std::make_shared<default_opset::VariadicSplit>(
input, axis_node, split_lengths);
} }
else else
{ {
const auto outputs_number = node.get_output_names().size(); const auto outputs_number = node.get_output_names().size();
split = std::make_shared<default_opset::Split>( return ngraph::builder::opset1::split(input, outputs_number, axis);
input, axis_node, outputs_number);
} }
return common::get_outputs(split);
} }
} // namespace set_1 } // namespace set_1
......
...@@ -18,8 +18,6 @@ ...@@ -18,8 +18,6 @@
#include "common.hpp" #include "common.hpp"
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/opsets/opset0.hpp" #include "ngraph/opsets/opset0.hpp"
namespace ngraph namespace ngraph
...@@ -51,24 +49,6 @@ namespace ngraph ...@@ -51,24 +49,6 @@ namespace ngraph
static_cast<onnx::TensorProto_DataType>(onnx_type))); static_cast<onnx::TensorProto_DataType>(onnx_type)));
} }
ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node)
{
const auto outputs_number = node->get_output_size();
ngraph::NodeVector outputs(outputs_number);
for (int i = 0; i < outputs_number; ++i)
{
if (node->output(i).get_node_shared_ptr()->get_output_size() == 1)
{
outputs[i] = node->get_output_as_single_output_node(i);
}
else
{
outputs[i] = std::make_shared<ngraph::opset0::GetOutputElement>(node, i);
}
}
return outputs;
}
} // namespace common } // namespace common
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -68,13 +68,6 @@ namespace ngraph ...@@ -68,13 +68,6 @@ namespace ngraph
return range; return range;
} }
/// \brief Return the outputs of the node as vector.
///
/// \param[in] node Node with multiple outputs.
///
/// \return Vector of outputs of input node.
ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node);
/// \brief Creates a shifted square identity matrix. /// \brief Creates a shifted square identity matrix.
/// \note Shifting in the context of this operator means that /// \note Shifting in the context of this operator means that
/// the matrix can be created with elements equal to 1 not only in the main /// the matrix can be created with elements equal to 1 not only in the main
......
...@@ -78,7 +78,7 @@ namespace ngraph ...@@ -78,7 +78,7 @@ namespace ngraph
/// ///
/// \brief Factory class which generates sub-graphs for ONNX 'local' pooling /// \brief Factory class which generates sub-graphs for ONNX 'local' pooling
/// operators. /// operators.
/// \note Kernel shape attribute is required /// \note For a 'local' pooling operation, the kernel shape attribute is required
class LocalPoolingFactory : public PoolingFactory class LocalPoolingFactory : public PoolingFactory
{ {
public: public:
...@@ -89,7 +89,8 @@ namespace ngraph ...@@ -89,7 +89,8 @@ namespace ngraph
/// ///
/// \brief Factory class which generates sub-graphs for ONNX 'global' pooling /// \brief Factory class which generates sub-graphs for ONNX 'global' pooling
/// operators. /// operators.
/// \note Kernel shape is calculated based on spatial dims /// \note In a 'global' pooling operation, the kernel shape is calculated
/// based on spatial dims
class GlobalPoolingFactory : public PoolingFactory class GlobalPoolingFactory : public PoolingFactory
{ {
public: public:
......
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include "ngraph/pass/get_output_element_elimination.hpp" #include "ngraph/pass/get_output_element_elimination.hpp"
#include "ngraph/pass/implicit_broadcast_elimination.hpp" #include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "ngraph/pass/like_replacement.hpp" #include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp" #include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_compiled_function.hpp" #include "ngraph/runtime/gpu/gpu_compiled_function.hpp"
...@@ -148,6 +149,7 @@ void runtime::gpu::GPUCompiledFunction::compile() ...@@ -148,6 +149,7 @@ void runtime::gpu::GPUCompiledFunction::compile()
#endif #endif
pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>(); pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>(); pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::Opset0Downgrade>();
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>(); pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
pass_manager.register_pass<ngraph::pass::ImplicitBroadcastElimination>(); pass_manager.register_pass<ngraph::pass::ImplicitBroadcastElimination>();
pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this); pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment