Commit cab1b7b0, authored by Mateusz Bencer, committed by Scott Cyphers

[SPEC] Implement TopK:v1 downgrade (#3789)

* implemented TopK:v1 downgrade

* Using v1 in onnx_importer; styles applied

* Fixed onnx_importer

* code review remarks introduced; fix clang errors

* Fixed problem with incorrect order of TopK outputs

* Code review remarks introduced

* Fixed output_order style
parent 53a6af8d
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/topk.hpp" #include "ngraph/op/topk.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
...@@ -50,8 +51,8 @@ static std::shared_ptr<ngraph::Node> get_k(const ngraph::onnx_import::Node& node ...@@ -50,8 +51,8 @@ static std::shared_ptr<ngraph::Node> get_k(const ngraph::onnx_import::Node& node
/// \return Return the outputs of the TopK node. /// \return Return the outputs of the TopK node.
static ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node) static ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node)
{ {
std::shared_ptr<ngraph::Node> indices = std::make_shared<ngraph::op::GetOutputElement>(node, 0); std::shared_ptr<ngraph::Node> values = std::make_shared<ngraph::op::GetOutputElement>(node, 0);
std::shared_ptr<ngraph::Node> values = std::make_shared<ngraph::op::GetOutputElement>(node, 1); std::shared_ptr<ngraph::Node> indices = std::make_shared<ngraph::op::GetOutputElement>(node, 1);
return {values, indices}; return {values, indices};
} }
...@@ -68,10 +69,16 @@ namespace ngraph ...@@ -68,10 +69,16 @@ namespace ngraph
{ {
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
std::int64_t k{node.get_attribute_value<std::int64_t>("k")}; std::int64_t k{node.get_attribute_value<std::int64_t>("k")};
auto k_node = ngraph::op::Constant::create(element::i64, Shape{}, {k});
auto axis = get_axis(node); auto axis = get_axis(node);
std::shared_ptr<ngraph::Node> top_k = std::shared_ptr<ngraph::Node> top_k = std::make_shared<ngraph::op::v1::TopK>(
std::make_shared<ngraph::op::TopK>(data, axis, element::i64, k); data,
k_node,
axis,
ngraph::op::v1::TopK::Mode::MAX,
ngraph::op::v1::TopK::SortType::SORT_VALUES,
element::i64);
return get_outputs(top_k); return get_outputs(top_k);
} }
...@@ -85,8 +92,13 @@ namespace ngraph ...@@ -85,8 +92,13 @@ namespace ngraph
auto k = get_k(node); auto k = get_k(node);
auto axis = get_axis(node); auto axis = get_axis(node);
std::shared_ptr<ngraph::Node> top_k = std::shared_ptr<ngraph::Node> top_k = std::make_shared<ngraph::op::v1::TopK>(
std::make_shared<ngraph::op::TopK>(data, k, axis, element::i64); data,
k,
axis,
ngraph::op::v1::TopK::Mode::MAX,
ngraph::op::v1::TopK::SortType::SORT_VALUES,
element::i64);
return get_outputs(top_k); return get_outputs(top_k);
} }
...@@ -106,12 +118,15 @@ namespace ngraph ...@@ -106,12 +118,15 @@ namespace ngraph
const auto sorted = node.get_attribute_value<std::int64_t>("sorted", 1); const auto sorted = node.get_attribute_value<std::int64_t>("sorted", 1);
// Map attribute values to nGraph enums // Map attribute values to nGraph enums
const auto sort_type = sorted ? ngraph::op::v1::TopK::SortType::SORT_VALUES
: ngraph::op::v1::TopK::SortType::NONE;
const auto compute_max = static_cast<bool>(largest); const auto compute_max = static_cast<bool>(largest);
const auto sort_type = sorted ? ngraph::op::TopK::SortType::SORT_VALUES const auto mode = compute_max ? ngraph::op::v1::TopK::Mode::MAX
: ngraph::op::TopK::SortType::NONE; : ngraph::op::v1::TopK::Mode::MIN;
std::shared_ptr<ngraph::Node> top_k = std::make_shared<ngraph::op::TopK>( std::shared_ptr<ngraph::Node> top_k = std::make_shared<ngraph::op::v1::TopK>(
data, k, axis, element::i64, compute_max, sort_type); data, k, axis, mode, sort_type, element::i64);
return get_outputs(top_k); return get_outputs(top_k);
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <numeric>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
...@@ -131,13 +132,21 @@ NodeVector ngraph::find_common_args(std::shared_ptr<Node> node1, std::shared_ptr ...@@ -131,13 +132,21 @@ NodeVector ngraph::find_common_args(std::shared_ptr<Node> node1, std::shared_ptr
return common_args; return common_args;
} }
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement) void ngraph::replace_node(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement,
const std::vector<int64_t>& output_order)
{ {
if (target->is_output()) if (target->is_output())
{ {
throw ngraph_error("Result nodes cannot be replaced."); throw ngraph_error("Result nodes cannot be replaced.");
} }
NGRAPH_CHECK(target->get_output_size() == output_order.size(),
"Target output size: ",
target->get_output_size(),
" must be equal output_order size: ",
output_order.size());
NGRAPH_CHECK(!target->get_users().empty(), NGRAPH_CHECK(!target->get_users().empty(),
"Attempted to replace unreachable node '", "Attempted to replace unreachable node '",
*target, *target,
...@@ -178,7 +187,7 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re ...@@ -178,7 +187,7 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
{ {
for (auto& input : target->output(i).get_target_inputs()) for (auto& input : target->output(i).get_target_inputs())
{ {
input.replace_source_output(replacement->output(i)); input.replace_source_output(replacement->output(output_order[i]));
} }
} }
...@@ -186,6 +195,13 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re ...@@ -186,6 +195,13 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
target->clear_control_dependents(); target->clear_control_dependents();
} }
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
{
auto default_output_order = vector<int64_t>(target->get_output_size());
std::iota(default_output_order.begin(), default_output_order.end(), 0);
replace_node(target, replacement, default_output_order);
}
void ngraph::replace_nodes( void ngraph::replace_nodes(
const std::shared_ptr<Function>& f, const std::shared_ptr<Function>& f,
const unordered_map<shared_ptr<op::Parameter>, shared_ptr<op::Parameter>>& const unordered_map<shared_ptr<op::Parameter>, shared_ptr<op::Parameter>>&
......
...@@ -81,6 +81,7 @@ namespace ngraph ...@@ -81,6 +81,7 @@ namespace ngraph
/// ///
/// \param target Node to be replaced. /// \param target Node to be replaced.
/// \param replacement Node to replace `target` with. /// \param replacement Node to replace `target` with.
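/// \param output_order Maps each output index of `target` to the output index of
///                     `replacement` that takes its place: users of output `i` of
///                     `target` are reconnected to output `output_order[i]` of
///                     `replacement`. Its size must equal the number of outputs
///                     of `target`.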
/// ///
/// This is primarily used in graph-rewriting passes. For example, we /// This is primarily used in graph-rewriting passes. For example, we
/// might "fuse" two Concat operations as follows: /// might "fuse" two Concat operations as follows:
...@@ -209,6 +210,9 @@ namespace ngraph ...@@ -209,6 +210,9 @@ namespace ngraph
/// auto new_N = N->copy_with_new_args(N->get_arguments()); /// auto new_N = N->copy_with_new_args(N->get_arguments());
/// shared_ptr<Node> M = make_shared<SomeUnaryOp>(new_N); /// shared_ptr<Node> M = make_shared<SomeUnaryOp>(new_N);
/// replace_node(N, M); /// replace_node(N, M);
void replace_node(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement,
const std::vector<int64_t>& output_order);
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement); void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
/// \brief Replace multiple nodes in a function. /// \brief Replace multiple nodes in a function.
......
...@@ -32,15 +32,7 @@ namespace ngraph ...@@ -32,15 +32,7 @@ namespace ngraph
class TopK : public Op class TopK : public Op
{ {
public: public:
enum class SortType using SortType = TopKSortType;
{
// Returned values are not sorted
NONE,
// Sort result based on element indices
SORT_INDICES,
// Sort result based on element values
SORT_VALUES,
};
NGRAPH_API NGRAPH_API
static constexpr NodeTypeInfo type_info{"TopK", 0}; static constexpr NodeTypeInfo type_info{"TopK", 0};
...@@ -125,12 +117,7 @@ namespace ngraph ...@@ -125,12 +117,7 @@ namespace ngraph
class TopK : public Op class TopK : public Op
{ {
public: public:
enum class SortType using SortType = TopKSortType;
{
NONE,
SORT_INDICES,
SORT_VALUES,
};
enum class Mode enum class Mode
{ {
......
...@@ -116,6 +116,16 @@ namespace ngraph ...@@ -116,6 +116,16 @@ namespace ngraph
MAX MAX
}; };
/// \brief Ordering applied to the elements returned by a TopK operation.
///
/// Aliased as `SortType` inside the TopK op classes.
enum class TopKSortType
{
// Returned values are not sorted
NONE,
// Sort result based on element indices
SORT_INDICES,
// Sort result based on element values
SORT_VALUES,
};
/// \brief Implicit broadcast specification /// \brief Implicit broadcast specification
struct AutoBroadcastSpec struct AutoBroadcastSpec
{ {
......
...@@ -50,6 +50,7 @@ ...@@ -50,6 +50,7 @@
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/strided_slice.hpp" #include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/xor.hpp" #include "ngraph/op/xor.hpp"
#include "ngraph/pass/opset0_downgrade.hpp" #include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/slice_plan.hpp" #include "ngraph/slice_plan.hpp"
...@@ -642,6 +643,33 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node) ...@@ -642,6 +643,33 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::TopK:
{
    // Downgrade op::v1::TopK to op::v0::TopK, carrying over the axis, sort
    // type, index element type and mode of the original node.
    const auto tmp = as_type_ptr<op::v1::TopK>(node);
    const auto axis = tmp->get_axis();
    const auto sort_type = tmp->get_sort_type();
    const auto index_elem_type = tmp->get_index_element_type();

    // v0 expresses the mode as a boolean flag. Deriving it from a comparison
    // (rather than a switch with an empty default branch) guarantees the flag
    // is always initialized before it is passed to the v0 constructor.
    const bool compute_max = (tmp->get_mode() == op::v1::TopK::Mode::MAX);

    const auto arg_node = node->input_value(0);
    const auto k_node = node->input_value(1);

    auto replacement_node = make_shared<op::v0::TopK>(
        arg_node, k_node, axis, index_elem_type, compute_max, sort_type);

    // v1 produces (values, indices) while v0 produces (indices, values), so
    // the outputs have to be swapped when rewiring the users of the v1 node.
    vector<int64_t> output_order{1, 0};
    replace_node(node, replacement_node, output_order);
    modified = true;
    break;
}
default: break; default: break;
} }
#if defined(__clang__) #if defined(__clang__)
......
...@@ -603,7 +603,9 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node) ...@@ -603,7 +603,9 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
auto replacement_node = auto replacement_node =
make_shared<op::v1::TopK>(node->input_value(0), k_constant, axis, mode, sort); make_shared<op::v1::TopK>(node->input_value(0), k_constant, axis, mode, sort);
replace_node(node, replacement_node); // indices output will be 0, values 1
vector<int64_t> output_order{1, 0};
replace_node(node, replacement_node, output_order);
modified = true; modified = true;
break; break;
} }
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp" #include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp" #include "util/type_prop.hpp"
...@@ -30,7 +31,7 @@ TEST(opset_transform, opset1_topk_upgrade_pass) ...@@ -30,7 +31,7 @@ TEST(opset_transform, opset1_topk_upgrade_pass)
const size_t axis = 2; const size_t axis = 2;
const size_t k = 10; const size_t k = 10;
const auto data = make_shared<op::Parameter>(element::i32, Shape{5, 10, 15}); const auto data = make_shared<op::Parameter>(element::i32, Shape{5, 10, 15});
const auto topk_v0 = make_shared<op::TopK>(data, axis, element::i32, k); const auto topk_v0 = make_shared<op::v0::TopK>(data, axis, element::i32, k);
const auto result = make_shared<op::Result>(topk_v0); const auto result = make_shared<op::Result>(topk_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data}); auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
...@@ -51,3 +52,34 @@ TEST(opset_transform, opset1_topk_upgrade_pass) ...@@ -51,3 +52,34 @@ TEST(opset_transform, opset1_topk_upgrade_pass)
const auto values_out_element_type = topk_v1->output(0).get_element_type(); const auto values_out_element_type = topk_v1->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, data->get_element_type()); EXPECT_EQ(values_out_element_type, data->get_element_type());
} }
// Checks that the Opset0Downgrade pass replaces an op::v1::TopK node with an
// op::v0::TopK node that carries the same k, axis, mode, sort type and index
// element type.
TEST(opset_transform, opset1_topk_downgrade_pass)
{
    const auto data = make_shared<op::Parameter>(element::i32, Shape{5, 10, 15});
    const int32_t k = 10;
    const auto k_node = op::Constant::create(element::i64, Shape{}, {k});
    const size_t axis = 2;
    const auto mode = op::v1::TopK::Mode::MAX;
    const auto sort = op::v1::TopK::SortType::SORT_INDICES;
    const auto elem_type = element::i64;

    const auto topk_v1 = make_shared<op::v1::TopK>(data, k_node, axis, mode, sort, elem_type);
    const auto result = make_shared<op::Result>(topk_v1);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset0Downgrade>();
    pass_manager.run_passes(f);

    const auto pass_replacement_node =
        f->get_result()->input(0).get_source_output().get_node_shared_ptr();
    const auto topk_v0 = as_type_ptr<op::v0::TopK>(pass_replacement_node);

    // Stop here if the pass did not produce a v0 TopK; otherwise the checks
    // below would dereference a null pointer and crash the test binary
    // instead of reporting a failure.
    ASSERT_TRUE(topk_v0 != nullptr);

    EXPECT_EQ(topk_v0->description(), "TopK");
    EXPECT_EQ(topk_v0->get_version(), 0);
    EXPECT_EQ(topk_v0->get_k(), k);
    EXPECT_EQ(topk_v0->get_top_k_axis(), axis);
    EXPECT_EQ(topk_v0->get_compute_max(), true);
    EXPECT_EQ(topk_v0->get_sort(), op::v0::TopK::SortType::SORT_INDICES);
    EXPECT_EQ(topk_v0->get_index_element_type(), elem_type);
}
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "util/type_prop.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
...@@ -119,3 +120,59 @@ TEST(replace_node, replace_nodes) ...@@ -119,3 +120,59 @@ TEST(replace_node, replace_nodes)
ASSERT_EQ(z_replacement->input(0).get_source_output().get_node_shared_ptr(), x_replacement); ASSERT_EQ(z_replacement->input(0).get_source_output().get_node_shared_ptr(), x_replacement);
ASSERT_EQ(z_replacement->input(1).get_source_output().get_node_shared_ptr(), mul); ASSERT_EQ(z_replacement->input(1).get_source_output().get_node_shared_ptr(), mul);
} }
// Checks that replace_node() honours an explicit output order: a v1 TopK
// (values at output 0, indices at output 1) is replaced by a v0 TopK whose
// outputs come in the opposite order, and every consumer must still receive
// data of the element type it was reading before.
TEST(replace_node, replace_nodes_output_order)
{
    auto data = make_shared<op::Parameter>(element::f16, Shape{4, 3});
    auto topk_v0 = make_shared<op::v0::TopK>(data, 0, element::i32, 2, true);

    auto topk_v1 = make_shared<op::v1::TopK>(data,
                                             op::Constant::create(element::i32, Shape{}, {2}),
                                             0,
                                             op::v1::TopK::Mode::MAX,
                                             op::v1::TopK::SortType::SORT_VALUES,
                                             element::i32);

    auto values = make_shared<op::GetOutputElement>(topk_v1, 0);
    auto indices = make_shared<op::GetOutputElement>(topk_v1, 1);

    // Before the replacement both consumers read from the v1 node.
    ASSERT_EQ(values->input(0).get_source_output().get_node_shared_ptr(), topk_v1);
    ASSERT_EQ(indices->input(0).get_source_output().get_node_shared_ptr(), topk_v1);
    ASSERT_EQ(values->input(0).get_source_output().get_element_type(), element::f16);
    ASSERT_EQ(indices->input(0).get_source_output().get_element_type(), element::i32);

    std::vector<int64_t> output_order{1, 0};
    replace_node(topk_v1, topk_v0, output_order);

    // The consumers must now read from the replacement node; without these
    // checks the test would also pass if replace_node() did nothing at all.
    ASSERT_EQ(values->input(0).get_source_output().get_node_shared_ptr(), topk_v0);
    ASSERT_EQ(indices->input(0).get_source_output().get_node_shared_ptr(), topk_v0);

    // The swapped order keeps each consumer attached to an output of the
    // same element type as before.
    ASSERT_EQ(values->input(0).get_source_output().get_element_type(), element::f16);
    ASSERT_EQ(indices->input(0).get_source_output().get_element_type(), element::i32);
}
// Checks that replace_node() rejects an output order whose size differs from
// the number of outputs of the node being replaced.
TEST(replace_node, replace_nodes_output_order_incorrect_size)
{
    auto data = make_shared<op::Parameter>(element::f16, Shape{4, 3});
    auto topk_v0 = make_shared<op::v0::TopK>(data, 0, element::i32, 2, true);

    auto topk_v1 = make_shared<op::v1::TopK>(data,
                                             op::Constant::create(element::i32, Shape{}, {2}),
                                             0,
                                             op::v1::TopK::Mode::MAX,
                                             op::v1::TopK::SortType::SORT_VALUES,
                                             element::i32);

    // Attach consumers to both outputs of the node that is to be replaced.
    auto values = make_shared<op::GetOutputElement>(topk_v1, 0);
    auto indices = make_shared<op::GetOutputElement>(topk_v1, 1);

    // Three entries for a node with two outputs: must be rejected.
    std::vector<int64_t> output_order{2, 1, 0};
    try
    {
        replace_node(topk_v1, topk_v0, output_order);
        FAIL() << "Incorrect output order size exception not detected";
    }
    catch (const ngraph_error& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Target output size: "));
    }
    catch (...)
    {
        FAIL() << "Incorrect output order size check failed for unexpected reason";
    }
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment