Unverified Commit 7018f9ca authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

Add axis normalization to Squeeze and Unsqueeze fused ops (#4389)

Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 01698d7a
......@@ -20,6 +20,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
......@@ -43,13 +44,15 @@ void op::Squeeze::pre_validate_and_infer_types()
return;
}
auto data_shape = data.get_shape();
// Get value of axes from Constant
auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = axes_constant->cast_vector<size_t>();
auto data_shape = data.get_shape();
std::vector<uint64_t> axes_to_squeeze(data_shape.size());
auto axes = normalize_axes(
this->description(), axes_constant->cast_vector<int64_t>(), data_shape.size());
// Prepare set of unique axes marked to be removed from input data.
std::vector<uint64_t> axes_to_squeeze(data_shape.size());
if (axes.empty())
{
// Default behaviour is to remove all single dimension axes.
......
......@@ -20,6 +20,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
......@@ -34,18 +35,21 @@ op::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes)
void op::Unsqueeze::pre_validate_and_infer_types()
{
auto data = input_value(0);
auto axes_node = input_value(1).get_node_shared_ptr();
const auto data = input_value(0);
const auto axes_node = input_value(1).get_node_shared_ptr();
const auto data_rank = data.get_partial_shape().rank();
if (data.get_partial_shape().rank().is_dynamic() || !axes_node->is_constant())
if (data_rank.is_dynamic() || !axes_node->is_constant())
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
return;
}
// Get value of axes from Constant
auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = axes_constant->cast_vector<size_t>();
const auto axes_constant = as_type_ptr<op::Constant>(axes_node);
const auto axes_values = axes_constant->cast_vector<int64_t>();
const auto expanded_rank = static_cast<size_t>(data_rank) + axes_values.size();
auto axes = normalize_axes(this->description(), axes_values, expanded_rank);
NODE_VALIDATION_CHECK(this, !axes.empty(), "'axes' input is mandatory.");
NODE_VALIDATION_CHECK(this,
......@@ -61,7 +65,6 @@ void op::Unsqueeze::pre_validate_and_infer_types()
}
auto data_shape = data.get_shape();
sort(begin(axes), end(axes), less<int64_t>());
AxisVector input_order{ngraph::get_default_order(data_shape.size())};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment