Commit b8266cab authored by Tomasz Socha's avatar Tomasz Socha Committed by Scott Cyphers

Hotfix for negative axes in unsqueeze op (#3705)

* Hotfix for negative axes in unsqueeze op

* Review fix I
parent 519963e4
......@@ -17,6 +17,7 @@
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/constant.hpp"
#include "squeeze.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -30,8 +31,11 @@ namespace ngraph
{
auto data = node.get_ng_inputs().at(0);
auto axes = node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
const auto expanded_rank = data->get_shape().size() + axes.size();
std::vector<std::size_t> valid_axes =
common::validate_axes(node, axes, expanded_rank);
auto axes_node = std::make_shared<ngraph::op::Constant>(
element::i64, Shape{axes.size()}, axes);
element::i64, Shape{valid_axes.size()}, valid_axes);
return {std::make_shared<ngraph::op::Unsqueeze>(data, axes_node)};
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment