Commit 14b9bab2 authored by tsocha's avatar tsocha Committed by Scott Cyphers

[ONNX] Add verification of duplicate in unsqueeze axes attribute (#2668)

parent 36cf8fe0
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
//***************************************************************************** //*****************************************************************************
#include <numeric> #include <numeric>
#include <set>
#include <vector>
#include "exceptions.hpp" #include "exceptions.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
...@@ -34,9 +36,14 @@ namespace ngraph ...@@ -34,9 +36,14 @@ namespace ngraph
NodeVector inputs{node.get_ng_inputs()}; NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0); auto data = inputs.at(0);
auto data_shape = data->get_shape(); auto data_shape = data->get_shape();
auto axes = node.get_attribute_value<std::vector<int64_t>>("axes"); auto axes = node.get_attribute_value<std::vector<std::int64_t>>("axes");
ASSERT_VALID_ARGUMENT(node, !axes.empty()) << "'axes' attribute is mandatory."; ASSERT_VALID_ARGUMENT(node, !axes.empty()) << "'axes' attribute is mandatory.";
ASSERT_VALID_ARGUMENT(
node,
axes.size() ==
std::set<std::int64_t>(std::begin(axes), std::end(axes)).size())
<< "'axes' has a duplicate axis.";
std::sort(std::begin(axes), std::end(axes), std::less<int64_t>()); std::sort(std::begin(axes), std::end(axes), std::less<int64_t>());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment