Commit 69486262 authored by tsocha's avatar tsocha Committed by Scott Cyphers

[ONNX] Add verification of split attribute in split op (#2669)

* [ONNX] Add verification of split attribute in split op

* Style

* Update split.cpp

* Add verification of split length
parent 9248588c
...@@ -14,6 +14,10 @@ ...@@ -14,6 +14,10 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <cstdint>
#include <vector>
#include "exceptions.hpp"
#include "op/split.hpp" #include "op/split.hpp"
#include "utils/reshape.hpp" #include "utils/reshape.hpp"
...@@ -36,7 +40,7 @@ namespace ngraph ...@@ -36,7 +40,7 @@ namespace ngraph
{ {
} }
}; };
} } // namespace detail
struct OutOfRange : detail::Error struct OutOfRange : detail::Error
{ {
...@@ -84,18 +88,19 @@ namespace ngraph ...@@ -84,18 +88,19 @@ namespace ngraph
NodeVector split(const Node& node) NodeVector split(const Node& node)
{ {
std::shared_ptr<ngraph::Node> input = node.get_ng_inputs().at(0); std::shared_ptr<ngraph::Node> input = node.get_ng_inputs().at(0);
auto input_shape = input->get_shape();
std::size_t count_outputs{node.get_output_names().size()}; std::size_t count_outputs{node.get_output_names().size()};
int64_t axis{node.get_attribute_value<int64_t>("axis", 0)}; int64_t axis{node.get_attribute_value<int64_t>("axis", 0)};
std::size_t axis_to_split{static_cast<std::size_t>(axis)}; std::size_t axis_to_split{static_cast<std::size_t>(axis)};
if (axis < 0) if (axis < 0)
{ {
axis_to_split = input->get_shape().size() + axis; axis_to_split = input_shape.size() + axis;
} }
else if (axis_to_split >= input->get_shape().size()) else if (axis_to_split >= input_shape.size())
{ {
throw error::op::split::OutOfRange{node.get_name()}; throw error::op::split::OutOfRange{node.get_name()};
} }
std::size_t length_axis_to_split{input->get_shape().at(axis_to_split)}; std::size_t length_axis_to_split{input_shape.at(axis_to_split)};
std::vector<std::size_t> length_parts; std::vector<std::size_t> length_parts;
try try
{ {
...@@ -111,6 +116,15 @@ namespace ngraph ...@@ -111,6 +116,15 @@ namespace ngraph
length_parts.assign(count_outputs, length_axis_to_split / count_outputs); length_parts.assign(count_outputs, length_axis_to_split / count_outputs);
} }
std::size_t total_parts_length = 0;
for (auto length : length_parts)
{
ASSERT_VALID_ARGUMENT(node, length > 0)
<< "Invalid value in 'split' attribute";
total_parts_length += length;
}
ASSERT_VALID_ARGUMENT(node, total_parts_length == input_shape.at(axis_to_split))
<< "Cannot split using values in 'split' attribute";
return reshape::split(input, length_parts, axis_to_split); return reshape::split(input, length_parts, axis_to_split);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment