Commit 68698d51 authored by mbencer's avatar mbencer

Code review remarks introduced

parent 768a8e13
...@@ -101,22 +101,22 @@ NodeVector builder::split(const Output<Node>& value, size_t split_parts, int axi ...@@ -101,22 +101,22 @@ NodeVector builder::split(const Output<Node>& value, size_t split_parts, int axi
} }
NodeVector builder::opset1::split(const Output<Node>& value, NodeVector builder::opset1::split(const Output<Node>& value,
const std::vector<size_t>& length_parts, const std::vector<size_t>& split_lengths,
int64_t axis) int64_t axis)
{ {
const auto axis_node = op::Constant::create(element::u64, Shape{}, {axis}); const auto axis_node = op::Constant::create(element::u64, Shape{}, {axis});
const auto length_parts_node = const auto split_lengths_node =
op::Constant::create(element::u64, Shape{length_parts.size()}, length_parts); ngraph::opset1::Constant::create(element::u64, Shape{split_lengths.size()}, split_lengths);
const auto variadic_split = const auto variadic_split =
std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, length_parts_node); std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, split_lengths_node);
return get_outputs(variadic_split); return get_outputs(variadic_split);
} }
NodeVector builder::opset1::split(const Output<Node>& value, size_t split_parts, int64_t axis) NodeVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis)
{ {
const auto axis_node = op::Constant::create(element::u64, Shape{}, {axis}); const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, split_parts); const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits);
return get_outputs(split); return get_outputs(split);
} }
...@@ -56,7 +56,7 @@ namespace ngraph ...@@ -56,7 +56,7 @@ namespace ngraph
/// \brief Split value on specified axis into multiple parts. /// \brief Split value on specified axis into multiple parts.
/// ///
/// \param[in] value The value to be split. /// \param[in] value The value to be split.
/// \param[in] length_parts The vector defining the lengths of each split part. /// \param[in] split_lengths The vector defining the lengths of each split part.
/// \param[in] axis The axis we split input node on. Default value is zero /// \param[in] axis The axis we split input node on. Default value is zero
/// axis. /// axis.
/// \note This implementation supports negative `axis` values (similar to NumPy /// \note This implementation supports negative `axis` values (similar to NumPy
...@@ -67,13 +67,13 @@ namespace ngraph ...@@ -67,13 +67,13 @@ namespace ngraph
/// The vector is output of Split:v1 op /// The vector is output of Split:v1 op
/// ///
NodeVector split(const Output<Node>& value, NodeVector split(const Output<Node>& value,
const std::vector<size_t>& length_parts, const std::vector<size_t>& split_lengths,
int64_t axis = 0); int64_t axis = 0);
/// \brief Split node on specified axis into multiple parts. /// \brief Split node on specified axis into multiple parts.
/// ///
/// \param[in] value The value to split. /// \param[in] value The value to split.
/// \param[in] split_parts The number of parts we want to split output at given /// \param[in] num_splits The number of parts we want to split output at given
/// axis. The length of the axis to split must be divisible by /// axis. The length of the axis to split must be divisible by
/// this value. /// this value.
/// \param[in] axis The axis we split input node on. Default value is zero /// \param[in] axis The axis we split input node on. Default value is zero
...@@ -86,7 +86,7 @@ namespace ngraph ...@@ -86,7 +86,7 @@ namespace ngraph
/// \return The vector containing multiple nodes we split input node into. /// \return The vector containing multiple nodes we split input node into.
/// The vector is output of VariadicSplit:v1 op /// The vector is output of VariadicSplit:v1 op
/// ///
NodeVector split(const Output<Node>& value, size_t split_parts, int64_t axis = 0); NodeVector split(const Output<Node>& value, size_t num_splits, int64_t axis = 0);
} }
} // namespace builder } // namespace builder
} // namespace ngraph } // namespace ngraph
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment