Unverified Commit 77941168 authored by Mateusz Bencer, committed by GitHub

[SPEC][ONNX] Handle negative axis for TopK:v1, add dynamic shape support for…

[SPEC][ONNX] Handle negative axis for TopK:v1, add dynamic shape support for ONNX Arg Min/Max ops (#4291)

* First version

* Added support to no_keep_dims

* Excluded tests for PlaidML

* Code review remarks introduced

* Reduced TopK axis restrictions

* Added assert to TopK get_axis

* Added missing EOF

* Style applied

* Code review suggestions introduced

* Disable tests for GPU

* Code review remarks introduced
Co-authored-by: Scott Cyphers <diyessi@users.noreply.github.com>
parent 6e4ab0b7
......@@ -16,7 +16,6 @@
#include "utils/arg_min_max_factory.hpp"
#include "default_opset.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "ngraph/validation_util.hpp"
......@@ -29,12 +28,9 @@ namespace ngraph
{
// Builds ArgMin/ArgMax state from an ONNX node: reads the "keepdims" and
// "axis" attributes (defaults 1 and 0) and captures the first nGraph input.
// NOTE(review): this span is a scraped diff with +/- markers stripped — the
// new initializer-list member `m_axis` and the OLD in-body normalization
// (`axis`/`data_rank`/`m_normalized_axis`) appear merged together; consult
// the original commit to see which lines were removed.
ArgMinMaxFactory::ArgMinMaxFactory(const Node& node)
: m_keep_dims{node.get_attribute_value<std::int64_t>("keepdims", 1)}
, m_axis{node.get_attribute_value<std::int64_t>("axis", 0)}
{
m_input_node = node.get_ng_inputs().at(0);
// Pre-change code: normalized the (possibly negative) axis eagerly using the
// input's partial-shape rank — removed by this commit in favor of deferring
// normalization to TopK itself.
const auto axis = node.get_attribute_value<std::int64_t>("axis", 0);
const auto data_rank = m_input_node->get_output_partial_shape(0).rank();
m_normalized_axis = ngraph::normalize_axis(node.get_description(), axis, data_rank);
}
std::shared_ptr<ngraph::Node> ArgMinMaxFactory::make_arg_max() const
......@@ -52,19 +48,18 @@ namespace ngraph
{
const auto k_node =
default_opset::Constant::create(ngraph::element::i64, Shape{}, {1});
const auto topk =
std::make_shared<default_opset::TopK>(m_input_node,
k_node,
m_normalized_axis,
mode,
default_opset::TopK::SortType::NONE);
const auto topk = std::make_shared<default_opset::TopK>(
m_input_node, k_node, m_axis, mode, default_opset::TopK::SortType::NONE);
const auto indices = std::make_shared<ngraph::opset0::GetOutputElement>(topk, 1);
if (m_keep_dims == 0)
{
const auto reshaped_indices = ngraph::builder::opset1::squeeze(
indices, {static_cast<std::size_t>(m_normalized_axis)});
const auto axis_to_remove =
default_opset::Constant::create(element::u64, Shape{}, {topk->get_axis()});
const auto reshaped_indices =
std::make_shared<default_opset::Squeeze>(indices, axis_to_remove);
return std::make_shared<default_opset::Convert>(reshaped_indices, element::i64);
}
return std::make_shared<default_opset::Convert>(indices, element::i64);
......
......@@ -49,7 +49,7 @@ namespace ngraph
const std::int64_t m_keep_dims;
std::shared_ptr<ngraph::Node> m_input_node;
std::int64_t m_normalized_axis;
std::int64_t m_axis;
};
} // namespace arg
......
......@@ -20,6 +20,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
......@@ -235,6 +236,7 @@ op::v1::TopK::TopK(const Output<Node>& data,
const element::Type& index_element_type)
: Op{{data, k}}
, m_axis{axis}
, m_normalized_axis{0}
, m_mode{mode_from_string(mode)}
, m_sort{sort_type_from_string(sort)}
, m_index_element_type{index_element_type}
......@@ -242,6 +244,8 @@ op::v1::TopK::TopK(const Output<Node>& data,
constructor_validate_and_infer_types();
}
// Sentinel stored in m_normalized_axis meaning "axis not yet normalized"
// (the input rank was dynamic when the op was constructed).
// constexpr guarantees compile-time initialization (no static-init order issues).
static constexpr std::uint64_t UNKNOWN_NORMALIZED_AXIS = std::numeric_limits<uint64_t>::max();
op::v1::TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
......@@ -250,6 +254,7 @@ op::v1::TopK::TopK(const Output<Node>& data,
const element::Type& index_element_type)
: Op{{data, k}}
, m_axis{axis}
, m_normalized_axis{UNKNOWN_NORMALIZED_AXIS}
, m_mode{mode}
, m_sort{sort}
, m_index_element_type{index_element_type}
......@@ -281,16 +286,10 @@ void op::v1::TopK::validate_and_infer_types()
if (output_shape.rank().is_static())
{
NODE_VALIDATION_CHECK(
this,
m_axis >= 0 && static_cast<size_t>(m_axis) < static_cast<size_t>(output_shape.rank()),
"TopK axis (",
m_axis,
") is out of bounds.");
m_normalized_axis = ngraph::normalize_axis(this, m_axis, output_shape.rank());
if (k != 0)
{
output_shape[m_axis] = k;
output_shape[m_normalized_axis] = k;
}
}
......@@ -299,6 +298,28 @@ void op::v1::TopK::validate_and_infer_types()
set_output_type(1, m_index_element_type, output_shape);
}
// Sets the TopK axis (may be negative, as allowed by the ONNX spec).
// When the input rank is already known the axis is normalized immediately;
// otherwise the sentinel is stored and normalization is deferred until the
// rank becomes static (see validate_and_infer_types / get_axis).
void op::v1::TopK::set_axis(const int64_t axis)
{
    const auto input_rank = get_input_partial_shape(0).rank();
    if (input_rank.is_static())
    {
        m_normalized_axis = ngraph::normalize_axis(this, axis, input_rank);
    }
    else
    {
        // Rank unknown: the axis cannot be normalized yet.
        m_normalized_axis = UNKNOWN_NORMALIZED_AXIS;
    }
    m_axis = axis;
}
// Returns the normalized (non-negative) axis.
// Fails validation if the axis has not been normalized yet, which happens
// when the input rank was dynamic (see set_axis).
uint64_t op::v1::TopK::get_axis() const
{
    const bool axis_is_known = (m_normalized_axis != UNKNOWN_NORMALIZED_AXIS);
    NODE_VALIDATION_CHECK(this, axis_is_known, "Normalized axis of TopK is unknown");
    return m_normalized_axis;
}
size_t op::v1::TopK::read_k_from_constant_node(const shared_ptr<Node>& node,
const element::Type& k_element_type) const
{
......@@ -403,7 +424,7 @@ size_t op::v1::TopK::get_k() const
if (k == 0 && get_input_partial_shape(0).is_static())
{
k = get_input_partial_shape(0).to_shape()[m_axis];
k = get_input_partial_shape(0).to_shape()[m_normalized_axis];
}
return k;
}
......
......@@ -161,8 +161,13 @@ namespace ngraph
copy_with_new_args(const NodeVector& new_args) const override;
virtual size_t get_version() const override { return 1; }
size_t get_axis() const { return m_axis; }
void set_axis(const size_t axis) { m_axis = axis; }
/// \brief Returns axis value after normalization
/// \note If input rank required to normalization is dynamic, the exception is
/// thrown
uint64_t get_axis() const;
/// \brief Returns axis value before normalization
int64_t get_provided_axis() const { return m_axis; }
void set_axis(const int64_t axis);
Mode get_mode() const { return m_mode; }
void set_mode(const Mode mode) { m_mode = mode; }
SortType get_sort_type() const { return m_sort; }
......@@ -182,6 +187,7 @@ namespace ngraph
protected:
int64_t m_axis;
uint64_t m_normalized_axis;
Mode m_mode;
SortType m_sort;
element::Type m_index_element_type;
......
......@@ -328,6 +328,8 @@ model_softplus_infinity
model_sum_opset8
model_argmax_int32
model_argmin_int32
arg_max_dyn_shape
arg_min_no_keep_dims_dyn_shape
model_top_k
top_k_opset_10
top_k_opset_10_const_k
......
......@@ -283,6 +283,8 @@ model_argmin_int32
model_lp_norm_default
model_instance_normalization
model_round
arg_max_dyn_shape # unsupported op `TopK`
arg_min_no_keep_dims_dyn_shape # unsupported op `TopK`
# passing locally, fails closeness checks in CI which may be too strict
elu
......
......@@ -4624,7 +4624,7 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::TopK_v1:
{
const auto tmp = static_cast<const op::v1::TopK*>(&n);
node["axis"] = tmp->get_axis();
node["axis"] = tmp->get_provided_axis();
node["mode"] = tmp->get_mode();
node["sort_type"] = tmp->get_sort_type();
node["index_element_type"] = write_element_type(tmp->get_index_element_type());
......
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "data"
output: "reduced"
name: "node1"
op_type: "ArgMax"
attribute {
name: "keepdims"
i: 1
type: INT
}
attribute {
name: "axis"
i: -2
type: INT
}
doc_string: "ArgMax"
domain: ""
}
name: "test"
input {
name: "data"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
}
}
}
}
output {
name: "reduced"
type {
tensor_type {
elem_type: 7
shape {
}
}
}
}
}
opset_import {
version: 7
}
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "data"
output: "reduced"
name: "node1"
op_type: "ArgMin"
attribute {
name: "keepdims"
i: 0
type: INT
}
attribute {
name: "axis"
i: 0
type: INT
}
doc_string: "ArgMin"
domain: ""
}
name: "test"
input {
name: "data"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
}
}
}
}
output {
name: "reduced"
type {
tensor_type {
elem_type: 7
shape {
}
}
}
}
}
opset_import {
version: 7
}
......@@ -363,6 +363,46 @@ NGRAPH_TEST(onnx_dyn_shapes_${BACKEND_NAME}, global_max_pool_dyn_shape)
test_case.run();
}
NGRAPH_TEST(onnx_dyn_shapes_${BACKEND_NAME}, arg_max_dyn_shape)
{
    // ONNX ArgMax over an input whose dimensions are all dynamic in the model;
    // the concrete {3, 2, 2} shape is supplied only at run time.
    const auto function = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/dynamic_shapes/argmax_dyn.prototxt"));

    auto test_case = NgraphTestCase(function, "${BACKEND_NAME}", BackendMode::DYNAMIC);

    const Shape input_shape{3, 2, 2};
    // Input is 1..12 in row-major order, so along the reduced axis the
    // maximum is always at index 1.
    std::vector<int32_t> input_data(shape_size(input_shape));
    std::iota(input_data.begin(), input_data.end(), 1);
    test_case.add_input<int32_t>(input_shape, input_data);

    const std::vector<int64_t> expected_indices{1, 1, 1, 1, 1, 1};
    test_case.add_expected_output<int64_t>(Shape{3, 1, 2}, expected_indices);

    test_case.run();
}
NGRAPH_TEST(onnx_dyn_shapes_${BACKEND_NAME}, arg_min_no_keep_dims_dyn_shape)
{
    // ONNX ArgMin with keepdims=0: the reduced axis is squeezed away, so the
    // {3, 2, 2} runtime input yields a {2, 2} output of indices.
    const auto function = onnx_import::import_onnx_model(file_util::path_join(
        SERIALIZED_ZOO, "onnx/dynamic_shapes/argmin_no_keep_dims_dyn.prototxt"));

    auto test_case = NgraphTestCase(function, "${BACKEND_NAME}", BackendMode::DYNAMIC);

    const Shape input_shape{3, 2, 2};
    // Input is 1..12 in row-major order; the minimum along the reduced axis
    // is always at index 0.
    std::vector<int32_t> input_data(shape_size(input_shape));
    std::iota(input_data.begin(), input_data.end(), 1);
    test_case.add_input<int32_t>(input_shape, input_data);

    const std::vector<int64_t> expected_indices{0, 0, 0, 0};
    test_case.add_expected_output<int64_t>(Shape{2, 2}, expected_indices);

    test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_constant_of_shape_float_zeros)
{
auto function = onnx_import::import_onnx_model(file_util::path_join(
......
......@@ -331,3 +331,38 @@ TEST(type_prop, topk_rank_static_dynamic_k_known_ok)
ASSERT_TRUE(topk->get_output_partial_shape(1).same_scheme(
PartialShape{Dimension::dynamic(), 2, Dimension::dynamic()}));
}
TEST(type_prop, topk_v1_negative_axis_support)
{
    // A negative axis must be reported verbatim by get_provided_axis() and as
    // its normalized equivalent (rank + axis == 2 here) by get_axis().
    const int64_t axis = -2;
    const Shape data_shape{1, 2, 3, 4};

    const auto data = make_shared<op::Parameter>(element::f32, data_shape);
    const auto k = op::Constant::create(element::i64, Shape{}, {2});
    const auto topk = make_shared<op::v1::TopK>(data, k, axis, "max", "value");

    ASSERT_EQ(topk->get_provided_axis(), axis);
    ASSERT_EQ(topk->get_axis(), data_shape.at(1));
}
TEST(type_prop, topk_v1_negative_axis_dynamic_rank)
{
    // With a fully dynamic input rank, a negative axis cannot be normalized
    // at construction time, so get_axis() is expected to throw.
    const auto data_shape = PartialShape::dynamic();
    const auto data = make_shared<op::Parameter>(element::f32, data_shape);
    const auto k = op::Constant::create(element::i64, Shape{}, {2});
    const int64_t axis = -2;
    const auto topk = make_shared<op::v1::TopK>(data, k, axis, "max", "value");
    try
    {
        topk->get_axis();
        // BUG FIX: without this FAIL() the test silently passed when no
        // exception was thrown at all.
        FAIL() << "Expected get_axis() to throw for dynamic input rank";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Normalized axis of TopK is unknown"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment