Commit 8492e5c1 authored by tsocha's avatar tsocha Committed by Michał Karzyński

[ONNX] TopK operator (#2359)

parent 9a075c46
......@@ -145,6 +145,8 @@ add_library(onnx_import STATIC
op/tanh.hpp
op/thresholded_relu.cpp
op/thresholded_relu.hpp
op/topk.cpp
op/topk.hpp
op/transpose.cpp
op/transpose.hpp
op/unsqueeze.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstdint>
#include <memory>
#include <vector>
#include "exceptions.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/type/element_type.hpp"
#include "topk.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector topk(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
std::int64_t k{node.get_attribute_value<std::int64_t>("k")};
auto num_dimensions = data->get_shape().size();
if (axis < 0)
{
axis += num_dimensions;
}
ASSERT_VALID_ARGUMENT(node, axis < num_dimensions)
<< "`axis` parameter is out of range: " << axis;
std::shared_ptr<ngraph::Node> top_k =
std::make_shared<ngraph::op::TopK>(data, axis, element::i64, k);
std::shared_ptr<ngraph::Node> indices =
std::make_shared<ngraph::op::GetOutputElement>(top_k, 0);
std::shared_ptr<ngraph::Node> values =
std::make_shared<ngraph::op::GetOutputElement>(top_k, 1);
return {values, indices};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Performs ONNX TopK operation.
///
/// \param node The ONNX node object representing this operation.
/// \return The vector containing Ngraph nodes producing output of ONNX TopK
/// operation(both values and indices).
NodeVector topk(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -94,6 +94,7 @@
#include "op/tan.hpp"
#include "op/tanh.hpp"
#include "op/thresholded_relu.hpp"
#include "op/topk.hpp"
#include "op/transpose.hpp"
#include "op/unsqueeze.hpp"
#include "op/xor.hpp"
......@@ -277,6 +278,7 @@ namespace ngraph
REGISTER_OPERATOR("Tan", 1, tan);
REGISTER_OPERATOR("Tanh", 1, tanh);
REGISTER_OPERATOR("ThresholdedRelu", 1, thresholded_relu);
REGISTER_OPERATOR("TopK", 1, topk);
REGISTER_OPERATOR("Transpose", 1, transpose);
REGISTER_OPERATOR("Unsqueeze", 1, unsqueeze);
REGISTER_OPERATOR("Xor", 1, logical_xor);
......
 backend-test:‰
1
xvaluesindices"TopK*
k *
axis 
test_top_kZ
x


b
values


b
indices


B
\ No newline at end of file
......@@ -1819,3 +1819,24 @@ TEST(onnx_${BACKEND_NAME}, model_space_to_depth_no_blocksize)
file_util::path_join(SERIALIZED_ZOO, "onnx/space_to_depth_no_blocksize.onnx")),
std::runtime_error);
}
TEST(onnx_${BACKEND_NAME}, model_top_k)
{
auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/top_k.onnx"));
Inputs inputs;
inputs.emplace_back(std::vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
std::vector<float> expected_values_output{3, 2, 1, 7, 6, 5, 11, 10, 9};
std::vector<std::int64_t> expected_indices_output{3, 2, 1, 3, 2, 1, 3, 2, 1};
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> result_tensors =
prepare_and_run(function, inputs, "${BACKEND_NAME}");
std::vector<float> values_output = read_vector<float>(result_tensors.at(0));
std::vector<std::int64_t> indices_output = read_vector<std::int64_t>(result_tensors.at(1));
EXPECT_TRUE(test::all_close_f(expected_values_output, values_output));
EXPECT_TRUE(test::all_close(expected_indices_output, indices_output));
}
\ No newline at end of file
......@@ -127,8 +127,9 @@ void init_real_tv(ngraph::runtime::Tensor* tv, std::default_random_engine& engin
void random_init(ngraph::runtime::Tensor* tv, std::default_random_engine& engine);
template <typename T, typename T1 = T>
std::vector<std::vector<T1>> execute(const std::shared_ptr<ngraph::Function>& function,
template <typename T>
std::vector<std::shared_ptr<ngraph::runtime::Tensor>>
prepare_and_run(const std::shared_ptr<ngraph::Function>& function,
std::vector<std::vector<T>> args,
const std::string& backend_id)
{
......@@ -160,6 +161,16 @@ std::vector<std::vector<T1>> execute(const std::shared_ptr<ngraph::Function>& fu
auto handle = backend->compile(function);
backend->call_with_validate(handle, result_tensors, arg_tensors);
return result_tensors;
}
template <typename T, typename T1 = T>
std::vector<std::vector<T1>> execute(const std::shared_ptr<ngraph::Function>& function,
std::vector<std::vector<T>> args,
const std::string& backend_id)
{
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> result_tensors =
prepare_and_run(function, args, backend_id);
std::vector<std::vector<T1>> result_vectors;
for (auto rt : result_tensors)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment