Commit a2abc8ee authored by Tomasz Dołbniak, committed by Michał Karzyński

[SPEC] TopK::v1 implementation (#3588)

parent 8e8c18ac
......@@ -137,3 +137,191 @@ void op::TopK::generate_adjoints(autodiff::Adjoints& /* adjoints */, const NodeV
{
throw ngraph_error("Forward-propagation-only operation");
}
constexpr NodeTypeInfo op::v1::TopK::type_info;
op::v1::TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type)
: Op{{data, k}}
, m_axis{axis}
, m_mode{mode_from_string(mode)}
, m_sort{sort_type_from_string(sort)}
, m_index_element_type{index_element_type}
{
constructor_validate_and_infer_types();
}
op::v1::TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const Mode mode,
const SortType sort,
const element::Type& index_element_type)
: Op{{data, k}}
, m_axis{axis}
, m_mode{mode}
, m_sort{sort}
, m_index_element_type{index_element_type}
{
constructor_validate_and_infer_types();
}
void op::v1::TopK::validate_and_infer_types()
{
const auto& input_partial_shape = get_input_partial_shape(0);
const auto input_rank = input_partial_shape.rank();
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || static_cast<size_t>(input_rank) > 0,
"Input rank must be greater than 0.");
const auto& k_partial_shape = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(
this, k_partial_shape.rank().compatible(0), "The 'K' input must be a scalar.");
size_t k = 0;
if (input_value(1).get_node_shared_ptr()->is_constant())
{
k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(),
get_input_element_type(1));
}
PartialShape output_shape{input_partial_shape};
if (output_shape.rank().is_static())
{
NODE_VALIDATION_CHECK(
this,
m_axis >= 0 && static_cast<size_t>(m_axis) < static_cast<size_t>(output_shape.rank()),
"TopK axis (",
m_axis,
") is out of bounds.");
if (k != 0)
{
output_shape[m_axis] = k;
}
}
set_output_size(2);
set_output_type(0, get_input_element_type(0), output_shape);
set_output_type(1, m_index_element_type, output_shape);
}
size_t op::v1::TopK::read_k_from_constant_node(const shared_ptr<Node>& node,
const element::Type& k_element_type) const
{
NODE_VALIDATION_CHECK(this,
k_element_type == element::i8 || k_element_type == element::i32 ||
k_element_type == element::i64,
"K input element type must be i8, i32 or i64 (got ",
k_element_type,
").");
const auto k_constant = dynamic_pointer_cast<op::Constant>(node);
size_t k = 0;
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wswitch-enum"
#endif
switch (static_cast<element::Type_t>(k_element_type))
{
case element::Type_t::i8: k = validate_and_get_k<int8_t>(k_constant); break;
case element::Type_t::i32: k = validate_and_get_k<int32_t>(k_constant); break;
case element::Type_t::i64: k = validate_and_get_k<int64_t>(k_constant); break;
default: break;
}
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
return k;
}
template <typename T>
size_t op::v1::TopK::validate_and_get_k(const shared_ptr<op::Constant>& k_constant) const
{
const auto k_const_contents = k_constant->get_vector<T>();
NODE_VALIDATION_CHECK(this,
k_const_contents.size() == 1,
"Only one value (scalar) should be provided as the 'K' input to TopK",
" (got ",
k_const_contents.size(),
" elements).");
NODE_VALIDATION_CHECK(this,
k_const_contents[0] > 0,
"The value of 'K' must be a positive number.",
" (got ",
k_const_contents[0],
").");
return static_cast<size_t>(k_const_contents[0]);
}
void op::v1::TopK::generate_adjoints(autodiff::Adjoints& /*adjoints*/, const NodeVector& /*deltas*/)
{
throw ngraph_error("Forward-propagation-only operation");
}
shared_ptr<Node> op::v1::TopK::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
auto new_v1_topk =
make_shared<v1::TopK>(new_args.at(0), new_args.at(1), m_axis, m_mode, m_sort);
new_v1_topk->set_index_element_type(m_index_element_type);
return new_v1_topk;
}
op::v1::TopK::Mode op::v1::TopK::mode_from_string(const std::string& mode) const
{
static const std::map<std::string, Mode> allowed_values = {{"max", Mode::MAX},
{"min", Mode::MIN}};
NODE_VALIDATION_CHECK(this, allowed_values.count(mode) > 0, "Invalid 'mode' value passed in.");
return allowed_values.at(mode);
}
op::v1::TopK::SortType op::v1::TopK::sort_type_from_string(const std::string& sort) const
{
static const std::map<std::string, SortType> allowed_values = {
{"none", SortType::NONE},
{"index", SortType::SORT_INDICES},
{"value", SortType::SORT_VALUES}};
NODE_VALIDATION_CHECK(this, allowed_values.count(sort) > 0, "Invalid 'sort' value passed in.");
return allowed_values.at(sort);
}
size_t op::v1::TopK::get_k() const
{
size_t k = 0;
if (input_value(1).get_node_shared_ptr()->is_constant())
{
k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(),
get_input_element_type(1));
}
if (k == 0 && get_input_partial_shape(0).is_static())
{
k = get_input_partial_shape(0).to_shape()[m_axis];
}
return k;
}
void op::v1::TopK::set_k(size_t k)
{
this->input(1).replace_source_output(
op::Constant::create(element::i64, Shape{}, {k})->output(0));
}
......@@ -96,5 +96,101 @@ namespace ngraph
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
namespace v1
{
/// \brief Computes indices and values of the k maximum/minimum values
/// for each slice along a specified axis.
class TopK : public Op
{
public:
enum class SortType
{
NONE,
SORT_INDICES,
SORT_VALUES,
};
enum class Mode
{
MAX,
MIN
};
NGRAPH_API
static constexpr NodeTypeInfo type_info{"TopK", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a TopK operation
TopK() = default;
/// \brief Constructs a TopK operation with two outputs: values and indices.
/// By default the indices output uses the i32 data type.
///
/// \param data The input tensor
/// \param k Specifies how many maximum/minimum elements should be computed
/// (provided as a scalar input tensor)
/// \param axis The axis along which to compute the top k indices
/// \param mode Specifies whether the maximum or minimum elements are selected
/// \param sort Specifies the order of output elements and/or indices
/// Accepted values: none, index, value
/// \param index_element_type Specifies the type of the produced indices
TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type = element::i32);
TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const Mode mode,
const SortType sort,
const element::Type& index_element_type = element::i32);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual size_t get_version() const override { return 1; }
size_t get_axis() const { return m_axis; }
void set_axis(const size_t axis) { m_axis = axis; }
Mode get_mode() const { return m_mode; }
void set_mode(const Mode mode) { m_mode = mode; }
SortType get_sort_type() const { return m_sort; }
void set_sort_type(const SortType sort) { m_sort = sort; }
element::Type get_index_element_type() const { return m_index_element_type; }
void set_index_element_type(const element::Type& index_element_type)
{
m_index_element_type = index_element_type;
}
/// \brief Returns the value of K, if available
///
/// \note If the second input to this op is a constant, the value is retrieved
/// and returned. If the input is not constant (i.e. dynamic), this method returns 0.
size_t get_k() const;
void set_k(size_t k);
protected:
int64_t m_axis;
Mode m_mode;
SortType m_sort;
element::Type m_index_element_type;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
size_t read_k_from_constant_node(const std::shared_ptr<Node>& node,
const element::Type& k_element_type) const;
Mode mode_from_string(const std::string& mode) const;
SortType sort_type_from_string(const std::string& sort) const;
template <typename T>
size_t validate_and_get_k(const std::shared_ptr<op::Constant>& k_constant) const;
};
}
}
}
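For context, a minimal usage sketch of the class declared above, modeled on the unit tests later in this diff; the helper function name, include paths and namespace usage are illustrative assumptions and are not part of the commit.
// Hedged usage sketch (not part of the commit): constructing op::v1::TopK with a
// constant K, mirroring the constructor and accessors declared in the header above.
#include <memory>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/topk.hpp"
using namespace ngraph;
std::shared_ptr<op::v1::TopK> make_example_topk()
{
    // 2x3x4 input; select the top 2 values along axis 1, sorted by value.
    const auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
    const auto k = op::Constant::create(element::i64, Shape{}, {2});
    auto topk = std::make_shared<op::v1::TopK>(
        data, k, 1, op::v1::TopK::Mode::MAX, op::v1::TopK::SortType::SORT_VALUES);
    // Output 0 carries the selected values, output 1 the indices (i32 by default);
    // because K is a constant here, topk->get_k() returns 2 (it returns 0 for a dynamic K).
    return topk;
}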
......@@ -30,6 +30,9 @@
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/topk.hpp"
#include <limits>
using namespace std;
using namespace ngraph;
......@@ -386,6 +389,38 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::TopK:
{
const auto topk_v0 = dynamic_cast<const op::TopK*>(node.get());
const auto k = topk_v0->get_k();
const auto axis = topk_v0->get_top_k_axis();
std::string sort;
switch (topk_v0->get_sort())
{
case op::TopK::SortType::SORT_INDICES: sort = "index"; break;
case op::TopK::SortType::SORT_VALUES: sort = "value"; break;
default: sort = "none"; break;
}
std::string mode;
if (topk_v0->get_compute_max())
{
mode = "max";
}
else
{
mode = "min";
}
const auto k_constant = op::Constant::create(element::i64, Shape{}, {k});
auto replacement_node =
make_shared<op::v1::TopK>(node->input_value(0), k_constant, axis, mode, sort);
replace_node(node, replacement_node);
modified = true;
break;
}
default: break;
}
......
......@@ -2146,11 +2146,24 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::TopK:
{
auto top_k_axis = node_js.at("top_k_axis").get<size_t>();
auto k = node_js.at("k").get<size_t>();
auto compute_max = node_js.at("compute_max").get<bool>();
auto target_type = read_element_type(node_js.at("index_element_type"));
node = make_shared<op::TopK>(args[0], top_k_axis, target_type, k, compute_max);
if (op_version == 0)
{
auto top_k_axis = node_js.at("top_k_axis").get<size_t>();
auto k = node_js.at("k").get<size_t>();
auto compute_max = node_js.at("compute_max").get<bool>();
auto target_type = read_element_type(node_js.at("index_element_type"));
node = make_shared<op::TopK>(args[0], top_k_axis, target_type, k, compute_max);
}
else if (op_version == 1)
{
const auto axis = node_js.at("axis").get<size_t>();
const auto mode = node_js.at("mode").get<op::v1::TopK::Mode>();
const auto sort_type = node_js.at("sort_type").get<op::v1::TopK::SortType>();
const auto index_element_type = read_element_type(node_js.at("index_element_type"));
auto topk = make_shared<op::v1::TopK>(args[0], args[1], axis, mode, sort_type);
topk->set_index_element_type(index_element_type);
node = move(topk);
}
break;
}
case OP_TYPEID::Transpose:
......@@ -3325,11 +3338,23 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::TopK:
{
auto tmp = dynamic_cast<const op::TopK*>(&n);
node["top_k_axis"] = tmp->get_top_k_axis();
node["index_element_type"] = write_element_type(tmp->get_index_element_type());
node["k"] = tmp->get_k();
node["compute_max"] = tmp->get_compute_max();
if (op_version == 0)
{
const auto tmp = dynamic_cast<const op::TopK*>(&n);
node["top_k_axis"] = tmp->get_top_k_axis();
node["index_element_type"] = write_element_type(tmp->get_index_element_type());
node["k"] = tmp->get_k();
node["compute_max"] = tmp->get_compute_max();
}
else if (op_version == 1)
{
const auto tmp = dynamic_cast<const op::v1::TopK*>(&n);
node["axis"] = tmp->get_axis();
node["mode"] = tmp->get_mode();
node["sort_type"] = tmp->get_sort_type();
node["index_element_type"] = write_element_type(tmp->get_index_element_type());
}
break;
}
case OP_TYPEID::Transpose: { break;
......
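To illustrate the op_version == 1 serializer and deserializer branches above, a hedged round-trip sketch; it assumes the serialize()/deserialize() helpers from ngraph/serializer.hpp used elsewhere in the test suite, and is not part of the commit.
// Hedged sketch (not part of the commit): JSON round-trip of a graph containing
// op::v1::TopK, assuming serialize()/deserialize() from ngraph/serializer.hpp.
#include <memory>
#include <string>
#include "ngraph/ngraph.hpp"
#include "ngraph/serializer.hpp"
using namespace ngraph;
std::shared_ptr<Function> roundtrip_topk_v1()
{
    const auto data = std::make_shared<op::Parameter>(element::f32, Shape{5, 10, 15});
    const auto k = op::Constant::create(element::i64, Shape{}, {3});
    const auto topk = std::make_shared<op::v1::TopK>(
        data, k, 1, op::v1::TopK::Mode::MAX, op::v1::TopK::SortType::SORT_VALUES);
    const auto values = std::make_shared<op::Result>(topk->output(0));
    const auto indices = std::make_shared<op::Result>(topk->output(1));
    auto f = std::make_shared<Function>(ResultVector{values, indices}, ParameterVector{data});
    // The v1 branch of serialize_node() writes "axis", "mode", "sort_type" and
    // "index_element_type"; deserialize_node() reads them back into a new v1 TopK.
    const std::string js = serialize(f);
    return deserialize(js);
}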
......@@ -77,6 +77,7 @@ set(SRC
opset_pass/reverse_opset_pass.cpp
opset_pass/softmax_opset_pass.cpp
opset_pass/sum_opset_pass.cpp
opset_pass/topk_opset_pass.cpp
partial_shape.cpp
pass.cpp
pass_liveness.cpp
......
......@@ -23,6 +23,7 @@
#include <string>
#include "gtest/gtest.h"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
......@@ -1140,3 +1141,29 @@ NGRAPH_TEST(${BACKEND_NAME}, topk_3d_single_output)
h0->call_with_validate({result0}, {a});
EXPECT_EQ((vector<int32_t>{2, 0, 1, 2, 1, 0, 0, 1}), read_vector<int32_t>(result0));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_v1_invalid_strings)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto k = op::Constant::create(element::i64, Shape{}, {1});
EXPECT_THROW(op::v1::TopK(data, k, 0, "invalid_mode", "max"), ngraph::NodeValidationFailure);
EXPECT_THROW(op::v1::TopK(data, k, 0, "index", "invalid_sort"), ngraph::NodeValidationFailure);
}
NGRAPH_TEST(${BACKEND_NAME}, topk_v1_invalid_k)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
// K must be a scalar
const auto k_non_scalar = op::Constant::create(element::i64, Shape{2}, {1, 2});
EXPECT_THROW(op::v1::TopK(data, k_non_scalar, 0, "max", "index"),
ngraph::NodeValidationFailure);
// K can only be i8, i32 or i64
const auto k_float = op::Constant::create(element::f32, Shape{}, {1.0f});
EXPECT_THROW(op::v1::TopK(data, k_float, 0, "max", "index"), ngraph::NodeValidationFailure);
// the value of K must be positive
const auto k_negative = op::Constant::create(element::i8, Shape{}, {-1});
EXPECT_THROW(op::v1::TopK(data, k_negative, 0, "max", "index"), ngraph::NodeValidationFailure);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(serialize, opset1_topk_pass)
{
const size_t axis = 2;
const size_t k = 10;
const auto data = make_shared<op::Parameter>(element::i32, Shape{5, 10, 15});
const auto topk_v0 = make_shared<op::TopK>(data, axis, element::i32, k);
const auto result = make_shared<op::Result>(topk_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto topk_v1 = static_pointer_cast<op::v1::TopK>(pass_replacement_node);
EXPECT_EQ(topk_v1->get_axis(), axis);
EXPECT_EQ(topk_v1->description(), "TopK");
EXPECT_EQ(topk_v1->get_version(), 1);
EXPECT_EQ(topk_v1->get_mode(), op::v1::TopK::Mode::MAX);
EXPECT_EQ(topk_v1->get_sort_type(), op::v1::TopK::SortType::SORT_VALUES);
const auto values_out_element_type = topk_v1->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, data->get_element_type());
}