Unverified Commit 56976f0c authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

Add support for ONNX 1.5 version of TopK (#3684)

parent d4d169f3
......@@ -17,7 +17,6 @@
#include <cstdint>
#include <memory>
#include "exceptions.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/topk.hpp"
......@@ -25,6 +24,25 @@
#include "topk.hpp"
#include "utils/common.hpp"
static std::int64_t get_axis(const ngraph::onnx_import::Node& node)
{
// Parse node attribute value for axis (adjust for negative value if needed).
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
auto data = node.get_ng_inputs().at(0);
auto data_rank = data->get_shape().size();
return ngraph::onnx_import::common::validate_axis(node, axis, data_rank);
}
static ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& top_k)
{
std::shared_ptr<ngraph::Node> indices =
std::make_shared<ngraph::op::GetOutputElement>(top_k, 0);
std::shared_ptr<ngraph::Node> values = std::make_shared<ngraph::op::GetOutputElement>(top_k, 1);
return {values, indices};
}
namespace ngraph
{
namespace onnx_import
......@@ -37,23 +55,29 @@ namespace ngraph
{
auto data = node.get_ng_inputs().at(0);
std::int64_t k{node.get_attribute_value<std::int64_t>("k")};
auto num_dimensions = data->get_shape().size();
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
std::int64_t valid_axis = common::validate_axis(node, axis, num_dimensions);
auto axis = get_axis(node);
std::shared_ptr<ngraph::Node> top_k =
std::make_shared<ngraph::op::TopK>(data, valid_axis, element::i64, k);
std::shared_ptr<ngraph::Node> indices =
std::make_shared<ngraph::op::GetOutputElement>(top_k, 0);
std::shared_ptr<ngraph::Node> values =
std::make_shared<ngraph::op::GetOutputElement>(top_k, 1);
std::make_shared<ngraph::op::TopK>(data, axis, element::i64, k);
return {values, indices};
return get_outputs(top_k);
}
}
} // namespace set_1
namespace set_10
{
NodeVector topk(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
auto k = node.get_ng_inputs().at(1);
auto axis = get_axis(node);
std::shared_ptr<ngraph::Node> top_k =
std::make_shared<ngraph::op::TopK>(data, k, axis, element::i64);
return get_outputs(top_k);
}
}
} // namespace op
......
......@@ -31,10 +31,18 @@ namespace ngraph
///
/// \param node The ONNX node object representing this operation.
/// \return The vector containing Ngraph nodes producing output of ONNX TopK
/// operation(both values and indices).
/// operation (both values and indices).
NodeVector topk(const Node& node);
}
} // namespace set_1
/// \brief Performs TopK operation from ONNX version 1.5
///
/// \details ONNX op set 10 added support for K as a dynamic input, not a static
/// attribute.
namespace set_10
{
NodeVector topk(const Node& node);
}
} // namespace op
......
......@@ -347,6 +347,7 @@ namespace ngraph
REGISTER_OPERATOR("Tanh", 1, tanh);
REGISTER_OPERATOR("ThresholdedRelu", 1, thresholded_relu);
REGISTER_OPERATOR("TopK", 1, topk);
REGISTER_OPERATOR("TopK", 10, topk);
REGISTER_OPERATOR("Transpose", 1, transpose);
REGISTER_OPERATOR("Unsqueeze", 1, unsqueeze);
REGISTER_OPERATOR("Where", 1, where);
......
......@@ -21,3 +21,6 @@ lrn_across_all_dims
lrn_across_nw
lrn_across_empty
lrn_6D_across_2_axes
# ONNX TopK with dynamic K
top_k_opset_10
......@@ -13,3 +13,6 @@ fake_quantize_with_clip_across_channels
# casting not supported on interpreter
convert_float32_bf16
convert_bf16_float32
# ONNX TopK with dynamic K
top_k_opset_10
......@@ -48,6 +48,8 @@ topk_max_sort_index # No plans to implement TopK
topk_min_sort_index # No plans to implement TopK
topk_2d_max_one_with_equal_values # No plans to implement TopK
model_top_k # No plans to implement TopK
top_k_opset_10 # No plans to implement TopK
top_k_opset_10_const_k # No plans to implement TopK
# unsupported op: `Erf`
erf
......
ir_version: 4
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "k"
output: "values"
output: "indices"
op_type: "TopK"
}
name: "test_top_k"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "k"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "values"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "indices"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 10
}
ir_version: 4
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "k"
output: "values"
output: "indices"
op_type: "TopK"
}
name: "test_top_k"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "k"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
}
}
}
}
initializer {
dims: 1
data_type: 7
int64_data: 3
name: "k"
}
output {
name: "values"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "indices"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 10
}
......@@ -1309,6 +1309,35 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_top_k)
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, top_k_opset_10)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/top_k_opset_10.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
test_case.add_input<int64_t>({3});
test_case.add_expected_output<float>(Shape{3, 3}, {3, 2, 1, 7, 6, 5, 11, 10, 9}); // values
test_case.add_expected_output<std::int64_t>(Shape{3, 3},
{3, 2, 1, 3, 2, 1, 3, 2, 1}); // indices
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, top_k_opset_10_const_k)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/top_k_opset_10_const_k.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
test_case.add_expected_output<float>(Shape{3, 3}, {3, 2, 1, 7, 6, 5, 11, 10, 9}); // values
test_case.add_expected_output<std::int64_t>(Shape{3, 3},
{3, 2, 1, 3, 2, 1, 3, 2, 1}); // indices
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_sinh)
{
auto function =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment