Unverified Commit 7018f9ca authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

Add axis normalization to Squeeze and Unsqueeze fused ops (#4389)

Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 01698d7a
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/squeeze.hpp" #include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/validation_util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -43,13 +44,15 @@ void op::Squeeze::pre_validate_and_infer_types() ...@@ -43,13 +44,15 @@ void op::Squeeze::pre_validate_and_infer_types()
return; return;
} }
auto data_shape = data.get_shape();
// Get value of axes from Constant // Get value of axes from Constant
auto axes_constant = as_type_ptr<op::Constant>(axes_node); auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = axes_constant->cast_vector<size_t>(); auto axes = normalize_axes(
auto data_shape = data.get_shape(); this->description(), axes_constant->cast_vector<int64_t>(), data_shape.size());
std::vector<uint64_t> axes_to_squeeze(data_shape.size());
// Prepare set of unique axes marked to be removed from input data. // Prepare set of unique axes marked to be removed from input data.
std::vector<uint64_t> axes_to_squeeze(data_shape.size());
if (axes.empty()) if (axes.empty())
{ {
// Default behaviour is to remove all single dimension axes. // Default behaviour is to remove all single dimension axes.
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/unsqueeze.hpp" #include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/validation_util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -34,18 +35,21 @@ op::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes) ...@@ -34,18 +35,21 @@ op::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes)
void op::Unsqueeze::pre_validate_and_infer_types() void op::Unsqueeze::pre_validate_and_infer_types()
{ {
auto data = input_value(0); const auto data = input_value(0);
auto axes_node = input_value(1).get_node_shared_ptr(); const auto axes_node = input_value(1).get_node_shared_ptr();
const auto data_rank = data.get_partial_shape().rank();
if (data.get_partial_shape().rank().is_dynamic() || !axes_node->is_constant()) if (data_rank.is_dynamic() || !axes_node->is_constant())
{ {
set_output_type(0, get_input_element_type(0), PartialShape::dynamic()); set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
return; return;
} }
// Get value of axes from Constant // Get value of axes from Constant
auto axes_constant = as_type_ptr<op::Constant>(axes_node); const auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = axes_constant->cast_vector<size_t>(); const auto axes_values = axes_constant->cast_vector<int64_t>();
const auto expanded_rank = static_cast<size_t>(data_rank) + axes_values.size();
auto axes = normalize_axes(this->description(), axes_values, expanded_rank);
NODE_VALIDATION_CHECK(this, !axes.empty(), "'axes' input is mandatory."); NODE_VALIDATION_CHECK(this, !axes.empty(), "'axes' input is mandatory.");
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
...@@ -61,7 +65,6 @@ void op::Unsqueeze::pre_validate_and_infer_types() ...@@ -61,7 +65,6 @@ void op::Unsqueeze::pre_validate_and_infer_types()
} }
auto data_shape = data.get_shape(); auto data_shape = data.get_shape();
sort(begin(axes), end(axes), less<int64_t>()); sort(begin(axes), end(axes), less<int64_t>());
AxisVector input_order{ngraph::get_default_order(data_shape.size())}; AxisVector input_order{ngraph::get_default_order(data_shape.size())};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment