Commit 7617d385 authored by Michał Karzyński's avatar Michał Karzyński Committed by Scott Cyphers

[ONNX] Add support for ONNX 1.6 TopK (#3771)

* Minor cleanup

* Add support for ONNX 1.5 version of TopK

* Add unit tests

* Style apply

* Exclude failing tests

* Exclude failing tests

* Add support for ONNX 1.6 TopK attribures: larges and sorted

* Support for ONNX 1.6 TopK

* If k_node is a Constant, recreate as constant with Shape{}

* Extend `interpret_as_scalar` function

* Extend `interpret_as_scalar` function

* Remove merge artifact

* Add doc string

* Exclude failing tests

* Exclude failing tests

* Refactor function

* Remove unnecessary template param

* Use get_k function in OpSet 10 TopK

* Style apply

* Remove merge artifact

* Add tests for `interpret_as_scalar`

* Revert "Add tests for `interpret_as_scalar`"

This reverts commit 8b85965acb39c75ff9e66b06ad8f64df16e1a9da.
parent e741f8f1
......@@ -14,7 +14,7 @@
# limitations under the License.
# ******************************************************************************
set(ONNX_OPSET_VERSION 10 CACHE INTERNAL "Supported version of ONNX operator set")
set(ONNX_OPSET_VERSION 11 CACHE INTERNAL "Supported version of ONNX operator set")
add_library(onnx_import_interface OBJECT
core/node.cpp
......
......@@ -20,13 +20,15 @@
#include "ngraph/node.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "topk.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"
/// \return Parse node attribute value for axis and adjust for negative value if needed.
static std::int64_t get_axis(const ngraph::onnx_import::Node& node)
{
// Parse node attribute value for axis (adjust for negative value if needed).
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
auto data = node.get_ng_inputs().at(0);
......@@ -34,11 +36,22 @@ static std::int64_t get_axis(const ngraph::onnx_import::Node& node)
return ngraph::onnx_import::common::validate_axis(node, axis, data_rank);
}
static ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& top_k)
/// \return Return the second input to the TopK node reshaped to a scalar.
static std::shared_ptr<ngraph::Node> get_k(const ngraph::onnx_import::Node& node)
{
std::shared_ptr<ngraph::Node> indices =
std::make_shared<ngraph::op::GetOutputElement>(top_k, 0);
std::shared_ptr<ngraph::Node> values = std::make_shared<ngraph::op::GetOutputElement>(top_k, 1);
auto k_node = node.get_ng_inputs().at(1);
NGRAPH_CHECK(shape_size(k_node->get_shape()) == 1,
"ONNX TopK operator: 'K' parameter must contain a single positive value.",
node);
return ngraph::onnx_import::reshape::interpret_as_scalar(k_node);
}
/// \return Return the outputs of the TopK node.
static ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node)
{
std::shared_ptr<ngraph::Node> indices = std::make_shared<ngraph::op::GetOutputElement>(node, 0);
std::shared_ptr<ngraph::Node> values = std::make_shared<ngraph::op::GetOutputElement>(node, 1);
return {values, indices};
}
......@@ -69,7 +82,7 @@ namespace ngraph
NodeVector topk(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
auto k = node.get_ng_inputs().at(1);
auto k = get_k(node);
auto axis = get_axis(node);
std::shared_ptr<ngraph::Node> top_k =
......@@ -79,6 +92,31 @@ namespace ngraph
}
}
namespace set_11
{
NodeVector topk(const Node& node)
{
// Process inputs
auto data = node.get_ng_inputs().at(0);
auto k = get_k(node);
// Process attributes
const auto axis = get_axis(node);
const auto largest = node.get_attribute_value<std::int64_t>("largest", 1);
const auto sorted = node.get_attribute_value<std::int64_t>("sorted", 1);
// Map attribute values to nGraph enums
const auto compute_max = static_cast<bool>(largest);
const auto sort_type = sorted ? ngraph::op::TopK::SortType::SORT_VALUES
: ngraph::op::TopK::SortType::NONE;
std::shared_ptr<ngraph::Node> top_k = std::make_shared<ngraph::op::TopK>(
data, k, axis, element::i64, compute_max, sort_type);
return get_outputs(top_k);
}
}
} // namespace op
} // namespace onnx_import
......
......@@ -44,6 +44,14 @@ namespace ngraph
NodeVector topk(const Node& node);
}
/// \brief Performs TopK operation from ONNX version 1.6
///
/// \details ONNX op set 11 added support for `largest` and `sorted` attributes.
namespace set_11
{
NodeVector topk(const Node& node);
}
} // namespace op
} // namespace onnx_import
......
......@@ -348,6 +348,7 @@ namespace ngraph
REGISTER_OPERATOR("ThresholdedRelu", 1, thresholded_relu);
REGISTER_OPERATOR("TopK", 1, topk);
REGISTER_OPERATOR("TopK", 10, topk);
REGISTER_OPERATOR("TopK", 11, topk);
REGISTER_OPERATOR("Transpose", 1, transpose);
REGISTER_OPERATOR("Unsqueeze", 1, unsqueeze);
REGISTER_OPERATOR("Where", 1, where);
......
......@@ -19,10 +19,9 @@
#include <iterator>
#include <numeric>
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/reshape.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
......@@ -100,6 +99,15 @@ namespace ngraph
"Scalar value can't be derived from a node with ",
node_shape);
// If node is a Constant, recreate as Constant with Shape{}
if (node->is_constant())
{
const auto value =
ngraph::as_type_ptr<ngraph::op::Constant>(node)->get_data_ptr();
return std::make_shared<ngraph::op::Constant>(
node->get_element_type(), ngraph::Shape{}, value);
}
return ngraph::builder::reshape(node, Shape{});
}
......
......@@ -23,7 +23,6 @@
#include "ngraph/axis_vector.hpp"
#include "ngraph/node.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
......@@ -33,7 +32,7 @@ namespace ngraph
{
/// \brief Infer `output_shape` dimension values.
///
/// \par Inferention rules
/// \par Inference rules
/// \li The input_shape may consist at most on -1 value. In this case the
/// value is inferred from the size of the tensor and the remaining
/// dimensions.
......@@ -44,7 +43,7 @@ namespace ngraph
/// \param[in] input_shape The input node shape.
/// \param[in] output_shape The requested output shape for the input node data.
///
/// \return A vector containig new, valid node shape.
/// \return A vector containing new, valid node shape.
///
std::vector<std::size_t> infer_dimensions(const std::string& node_name,
const std::vector<std::size_t>& input_shape,
......
......@@ -16,3 +16,4 @@ convert_bf16_float32
# ONNX TopK with dynamic K
top_k_opset_10
top_k_opset_11_const_k_smallest
......@@ -50,6 +50,7 @@ topk_2d_max_one_with_equal_values # No plans to implement TopK
model_top_k # No plans to implement TopK
top_k_opset_10 # No plans to implement TopK
top_k_opset_10_const_k # No plans to implement TopK
top_k_opset_11_const_k_smallest # No plans to implement TopK
# unsupported op: `Erf`
erf
......
ir_version: 5
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "k"
output: "values"
output: "indices"
op_type: "TopK"
attribute {
name: "axis"
i: 1
type: INT
}
attribute {
name: "largest"
i: 0
type: INT
}
attribute {
name: "sorted"
i: 1
type: INT
}
}
name: "test_top_k"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "k"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
}
}
}
}
initializer {
dims: 1
data_type: 7
int64_data: 3
name: "k"
}
output {
name: "values"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "indices"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 11
}
......@@ -1338,6 +1338,20 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, top_k_opset_10_const_k)
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, top_k_opset_11_const_k_smallest)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/top_k_opset_11_const_k_smallest.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>({0, 1, 2, 3, 4, 5, 6, 7, 11, 10, 9, 8});
test_case.add_expected_output<float>(Shape{3, 3}, {0, 1, 2, 4, 5, 6, 8, 9, 10}); // values
test_case.add_expected_output<std::int64_t>(Shape{3, 3},
{0, 1, 2, 0, 1, 2, 3, 2, 1}); // indices
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_sinh)
{
auto function =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment