Commit 68698d51 authored by mbencer's avatar mbencer

Code review remarks introduced

parent 768a8e13
......@@ -101,22 +101,22 @@ NodeVector builder::split(const Output<Node>& value, size_t split_parts, int axi
}
NodeVector builder::opset1::split(const Output<Node>& value,
const std::vector<size_t>& length_parts,
const std::vector<size_t>& split_lengths,
int64_t axis)
{
const auto axis_node = op::Constant::create(element::u64, Shape{}, {axis});
const auto length_parts_node =
op::Constant::create(element::u64, Shape{length_parts.size()}, length_parts);
const auto split_lengths_node =
ngraph::opset1::Constant::create(element::u64, Shape{split_lengths.size()}, split_lengths);
const auto variadic_split =
std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, length_parts_node);
std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, split_lengths_node);
return get_outputs(variadic_split);
}
NodeVector builder::opset1::split(const Output<Node>& value, size_t split_parts, int64_t axis)
NodeVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis)
{
const auto axis_node = op::Constant::create(element::u64, Shape{}, {axis});
const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, split_parts);
const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits);
return get_outputs(split);
}
......@@ -56,7 +56,7 @@ namespace ngraph
/// \brief Split value on specified axis into multiple parts.
///
/// \param[in] value The value to be split.
/// \param[in] length_parts The vector defining the lengths of each split part.
/// \param[in] split_lengths The vector defining the lengths of each split part.
/// \param[in] axis The axis we split input node on. Default value is zero
/// axis.
/// \note This implementation supports negative `axis` values (similar to NumPy
......@@ -67,13 +67,13 @@ namespace ngraph
/// The vector is output of Split:v1 op
///
NodeVector split(const Output<Node>& value,
const std::vector<size_t>& length_parts,
const std::vector<size_t>& split_lengths,
int64_t axis = 0);
/// \brief Split node on specified axis into multiple parts.
///
/// \param[in] value The value to split.
/// \param[in] split_parts The number of parts we want to split output at given
/// \param[in] num_splits The number of parts we want to split output at given
/// axis. The length of the axis to split must be divisible by
/// this value.
/// \param[in] axis The axis we split input node on. Default value is zero
......@@ -86,7 +86,7 @@ namespace ngraph
/// \return The vector containing multiple nodes we split input node into.
/// The vector is output of VariadicSplit:v1 op
///
NodeVector split(const Output<Node>& value, size_t split_parts, int64_t axis = 0);
NodeVector split(const Output<Node>& value, size_t num_splits, int64_t axis = 0);
}
} // namespace builder
} // namespace ngraph
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment