Commit cab1b7b0 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Scott Cyphers

[SPEC] Implement TopK:v1 downgrade (#3789)

* implemented TopK:v1 downgrade

* Using v1 in onnx_importer; styles applied

* Fixed onnx_importer

* code review remarks introduced; fix clang errors

* Fixed problem with incorrect order of TopK outputs

* Code review remarks introduced

* Fixed output_order style
parent 53a6af8d
......@@ -18,6 +18,7 @@
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/shape.hpp"
......@@ -50,8 +51,8 @@ static std::shared_ptr<ngraph::Node> get_k(const ngraph::onnx_import::Node& node
/// \return Return the outputs of the TopK node.
static ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node)
{
std::shared_ptr<ngraph::Node> indices = std::make_shared<ngraph::op::GetOutputElement>(node, 0);
std::shared_ptr<ngraph::Node> values = std::make_shared<ngraph::op::GetOutputElement>(node, 1);
std::shared_ptr<ngraph::Node> values = std::make_shared<ngraph::op::GetOutputElement>(node, 0);
std::shared_ptr<ngraph::Node> indices = std::make_shared<ngraph::op::GetOutputElement>(node, 1);
return {values, indices};
}
......@@ -68,10 +69,16 @@ namespace ngraph
{
auto data = node.get_ng_inputs().at(0);
std::int64_t k{node.get_attribute_value<std::int64_t>("k")};
auto k_node = ngraph::op::Constant::create(element::i64, Shape{}, {k});
auto axis = get_axis(node);
std::shared_ptr<ngraph::Node> top_k =
std::make_shared<ngraph::op::TopK>(data, axis, element::i64, k);
std::shared_ptr<ngraph::Node> top_k = std::make_shared<ngraph::op::v1::TopK>(
data,
k_node,
axis,
ngraph::op::v1::TopK::Mode::MAX,
ngraph::op::v1::TopK::SortType::SORT_VALUES,
element::i64);
return get_outputs(top_k);
}
......@@ -85,8 +92,13 @@ namespace ngraph
auto k = get_k(node);
auto axis = get_axis(node);
std::shared_ptr<ngraph::Node> top_k =
std::make_shared<ngraph::op::TopK>(data, k, axis, element::i64);
std::shared_ptr<ngraph::Node> top_k = std::make_shared<ngraph::op::v1::TopK>(
data,
k,
axis,
ngraph::op::v1::TopK::Mode::MAX,
ngraph::op::v1::TopK::SortType::SORT_VALUES,
element::i64);
return get_outputs(top_k);
}
......@@ -106,12 +118,15 @@ namespace ngraph
const auto sorted = node.get_attribute_value<std::int64_t>("sorted", 1);
// Map attribute values to nGraph enums
const auto sort_type = sorted ? ngraph::op::v1::TopK::SortType::SORT_VALUES
: ngraph::op::v1::TopK::SortType::NONE;
const auto compute_max = static_cast<bool>(largest);
const auto sort_type = sorted ? ngraph::op::TopK::SortType::SORT_VALUES
: ngraph::op::TopK::SortType::NONE;
const auto mode = compute_max ? ngraph::op::v1::TopK::Mode::MAX
: ngraph::op::v1::TopK::Mode::MIN;
std::shared_ptr<ngraph::Node> top_k = std::make_shared<ngraph::op::TopK>(
data, k, axis, element::i64, compute_max, sort_type);
std::shared_ptr<ngraph::Node> top_k = std::make_shared<ngraph::op::v1::TopK>(
data, k, axis, mode, sort_type, element::i64);
return get_outputs(top_k);
}
......
......@@ -14,6 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include <numeric>
#include <unordered_map>
#include <unordered_set>
#include <vector>
......@@ -131,13 +132,21 @@ NodeVector ngraph::find_common_args(std::shared_ptr<Node> node1, std::shared_ptr
return common_args;
}
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
void ngraph::replace_node(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement,
const std::vector<int64_t>& output_order)
{
if (target->is_output())
{
throw ngraph_error("Result nodes cannot be replaced.");
}
NGRAPH_CHECK(target->get_output_size() == output_order.size(),
"Target output size: ",
target->get_output_size(),
" must be equal output_order size: ",
output_order.size());
NGRAPH_CHECK(!target->get_users().empty(),
"Attempted to replace unreachable node '",
*target,
......@@ -178,7 +187,7 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
{
for (auto& input : target->output(i).get_target_inputs())
{
input.replace_source_output(replacement->output(i));
input.replace_source_output(replacement->output(output_order[i]));
}
}
......@@ -186,6 +195,13 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
target->clear_control_dependents();
}
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
{
auto default_output_order = vector<int64_t>(target->get_output_size());
std::iota(default_output_order.begin(), default_output_order.end(), 0);
replace_node(target, replacement, default_output_order);
}
void ngraph::replace_nodes(
const std::shared_ptr<Function>& f,
const unordered_map<shared_ptr<op::Parameter>, shared_ptr<op::Parameter>>&
......
......@@ -81,6 +81,7 @@ namespace ngraph
///
/// \param target Node to be replaced.
/// \param replacement Node to replace `target` with.
/// \param output_order Vector determines order of replacement node's outputs.
///
/// This is primarily used in graph-rewriting passes. For example, we
/// might "fuse" two Concat operations as follows:
......@@ -209,6 +210,9 @@ namespace ngraph
/// auto new_N = N->copy_with_new_args(N->get_arguments());
/// shared_ptr<Node> M = make_shared<SomeUnaryOp>(new_N);
/// replace_node(N, M);
void replace_node(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement,
const std::vector<int64_t>& output_order);
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
/// \brief Replace multiple nodes in a function.
......
......@@ -32,15 +32,7 @@ namespace ngraph
class TopK : public Op
{
public:
enum class SortType
{
// Returned values are not sorted
NONE,
// Sort result based on element indices
SORT_INDICES,
// Sort result based on element values
SORT_VALUES,
};
using SortType = TopKSortType;
NGRAPH_API
static constexpr NodeTypeInfo type_info{"TopK", 0};
......@@ -125,12 +117,7 @@ namespace ngraph
class TopK : public Op
{
public:
enum class SortType
{
NONE,
SORT_INDICES,
SORT_VALUES,
};
using SortType = TopKSortType;
enum class Mode
{
......
......@@ -116,6 +116,16 @@ namespace ngraph
MAX
};
enum class TopKSortType
{
// Returned values are not sorted
NONE,
// Sort result based on element indices
SORT_INDICES,
// Sort result based on element values
SORT_VALUES,
};
/// \brief Implicit broadcast specification
struct AutoBroadcastSpec
{
......
......@@ -50,6 +50,7 @@
#include "ngraph/op/slice.hpp"
#include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/slice_plan.hpp"
......@@ -642,6 +643,33 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::TopK:
{
const auto tmp = as_type_ptr<op::v1::TopK>(node);
const auto axis = tmp->get_axis();
const auto sort_type = tmp->get_sort_type();
const auto index_elem_type = tmp->get_index_element_type();
bool comnpute_max;
switch (tmp->get_mode())
{
case op::v1::TopK::Mode::MAX: comnpute_max = true; break;
case op::v1::TopK::Mode::MIN: comnpute_max = false; break;
default: break;
}
const auto arg_node = node->input_value(0);
const auto k_node = node->input_value(1);
auto replacement_node = make_shared<op::v0::TopK>(
arg_node, k_node, axis, index_elem_type, comnpute_max, sort_type);
// values output will be 0, indices 1
vector<int64_t> output_order{1, 0};
replace_node(node, replacement_node, output_order);
modified = true;
break;
}
default: break;
}
#if defined(__clang__)
......
......@@ -603,7 +603,9 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
auto replacement_node =
make_shared<op::v1::TopK>(node->input_value(0), k_constant, axis, mode, sort);
replace_node(node, replacement_node);
// indices output will be 0, values 1
vector<int64_t> output_order{1, 0};
replace_node(node, replacement_node, output_order);
modified = true;
break;
}
......
......@@ -19,6 +19,7 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
......@@ -30,7 +31,7 @@ TEST(opset_transform, opset1_topk_upgrade_pass)
const size_t axis = 2;
const size_t k = 10;
const auto data = make_shared<op::Parameter>(element::i32, Shape{5, 10, 15});
const auto topk_v0 = make_shared<op::TopK>(data, axis, element::i32, k);
const auto topk_v0 = make_shared<op::v0::TopK>(data, axis, element::i32, k);
const auto result = make_shared<op::Result>(topk_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
......@@ -51,3 +52,34 @@ TEST(opset_transform, opset1_topk_upgrade_pass)
const auto values_out_element_type = topk_v1->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, data->get_element_type());
}
TEST(opset_transform, opset1_topk_downgrade_pass)
{
const auto data = make_shared<op::Parameter>(element::i32, Shape{5, 10, 15});
const int32_t k = 10;
const auto k_node = op::Constant::create(element::i64, Shape{}, {k});
const size_t axis = 2;
const auto mode = op::v1::TopK::Mode::MAX;
const auto sort = op::v1::TopK::SortType::SORT_INDICES;
const auto elem_type = element::i64;
const auto topk_v1 = make_shared<op::v1::TopK>(data, k_node, axis, mode, sort, elem_type);
const auto result = make_shared<op::Result>(topk_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto topk_v0 = as_type_ptr<op::v0::TopK>(pass_replacement_node);
EXPECT_EQ(topk_v0->description(), "TopK");
EXPECT_EQ(topk_v0->get_version(), 0);
EXPECT_EQ(topk_v0->get_k(), k);
EXPECT_EQ(topk_v0->get_top_k_axis(), axis);
EXPECT_EQ(topk_v0->get_compute_max(), true);
EXPECT_EQ(topk_v0->get_sort(), op::v0::TopK::SortType::SORT_INDICES);
EXPECT_EQ(topk_v0->get_index_element_type(), elem_type);
}
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "gtest/gtest.h"
#include "util/type_prop.hpp"
#include "ngraph/ngraph.hpp"
......@@ -119,3 +120,59 @@ TEST(replace_node, replace_nodes)
ASSERT_EQ(z_replacement->input(0).get_source_output().get_node_shared_ptr(), x_replacement);
ASSERT_EQ(z_replacement->input(1).get_source_output().get_node_shared_ptr(), mul);
}
TEST(replace_node, replace_nodes_output_order)
{
auto data = make_shared<op::Parameter>(element::f16, Shape{4, 3});
auto topk_v0 = make_shared<op::v0::TopK>(data, 0, element::i32, 2, true);
auto topk_v1 = make_shared<op::v1::TopK>(data,
op::Constant::create(element::i32, Shape{}, {2}),
0,
op::v1::TopK::Mode::MAX,
op::v1::TopK::SortType::SORT_VALUES,
element::i32);
auto values = make_shared<op::GetOutputElement>(topk_v1, 0);
auto indices = make_shared<op::GetOutputElement>(topk_v1, 1);
ASSERT_EQ(values->input(0).get_source_output().get_element_type(), element::f16);
ASSERT_EQ(indices->input(0).get_source_output().get_element_type(), element::i32);
std::vector<int64_t> output_order{1, 0};
replace_node(topk_v1, topk_v0, output_order);
ASSERT_EQ(values->input(0).get_source_output().get_element_type(), element::f16);
ASSERT_EQ(indices->input(0).get_source_output().get_element_type(), element::i32);
}
TEST(replace_node, replace_nodes_output_order_incorrect_size)
{
auto data = make_shared<op::Parameter>(element::f16, Shape{4, 3});
auto topk_v0 = make_shared<op::v0::TopK>(data, 0, element::i32, 2, true);
auto topk_v1 = make_shared<op::v1::TopK>(data,
op::Constant::create(element::i32, Shape{}, {2}),
0,
op::v1::TopK::Mode::MAX,
op::v1::TopK::SortType::SORT_VALUES,
element::i32);
auto values = make_shared<op::GetOutputElement>(topk_v1, 0);
auto indices = make_shared<op::GetOutputElement>(topk_v1, 1);
std::vector<int64_t> output_order{2, 1, 0};
try
{
replace_node(topk_v1, topk_v0, output_order);
FAIL() << "Incorrect output order size exception not detected";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Target output size: "));
}
catch (...)
{
FAIL() << "Incorrect output order size exception not thrown for unexpected reason";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment