Commit d3c2d772 authored by Adam Procter, committed by Scott Cyphers

Make axis for TopK dynamic (#3526)

* Make axis for TopK dynamic

* set_input_is_relevant_to_shape

* Proper handling of output shape with dynamic top_k_axis

* fix ut

* remove unused line

* check is_constant

* add v0 version

* update provenance

* use set_argument

* fix clang
parent 7cc2a41f
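In short: this commit makes v0::TopK's top_k_axis a graph input (input 2) alongside k (input 1), removing the m_top_k_axis attribute, so the axis may be produced by any node rather than fixed at construction. A minimal sketch (not from the commit; names are illustrative) of what the new three-input constructor enables, assuming the usual nGraph headers:

#include "ngraph/ngraph.hpp"

using namespace ngraph;

int main()
{
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
    auto k = op::Constant::create(element::i64, Shape{1}, {2});
    // The axis arrives as a graph input, so it is unknown at compile time.
    auto axis = std::make_shared<op::Parameter>(element::i64, Shape{1});

    auto topk = std::make_shared<op::v0::TopK>(data, k, axis, element::i64);

    // With a non-constant axis input, get_top_k_axis_dynamic() returns
    // Dimension::dynamic(), and the output shape collapses to a rank-3
    // PartialShape with every dimension dynamic (see validate_and_infer_types
    // in the diff below).
    return topk->get_top_k_axis_dynamic().is_static() ? 1 : 0;
}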
@@ -24,62 +24,110 @@
 using namespace std;
 using namespace ngraph;
 
-constexpr NodeTypeInfo op::TopK::type_info;
+constexpr NodeTypeInfo op::v0::TopK::type_info;
 
-op::TopK::TopK(const Output<Node>& arg,
+op::v0::TopK::TopK(const Output<Node>& arg,
                size_t top_k_axis,
                const element::Type& index_element_type,
                size_t k,
                bool compute_max,
                SortType sort)
-    : Op({arg, op::Constant::create(element::i64, Shape{1}, {k})->output(0)})
-    , m_top_k_axis(top_k_axis)
+    : Op({arg})
     , m_index_element_type(index_element_type)
     , m_compute_max(compute_max)
     , m_sort(sort)
 {
+    set_argument(1, op::Constant::create(element::i64, Shape{1}, {k})->output(0));
+    set_argument(2, op::Constant::create(element::i64, Shape{1}, {top_k_axis})->output(0));
     add_provenance_group_member(input_value(1).get_node_shared_ptr());
+    add_provenance_group_member(input_value(2).get_node_shared_ptr());
     constructor_validate_and_infer_types();
 }
 
-op::TopK::TopK(const Output<Node>& arg,
+op::v0::TopK::TopK(const Output<Node>& arg,
                const Output<Node>& k,
                size_t top_k_axis,
                const element::Type& index_element_type,
                bool compute_max,
                SortType sort)
     : Op({arg, k})
-    , m_top_k_axis(top_k_axis)
     , m_index_element_type(index_element_type)
     , m_compute_max(compute_max)
     , m_sort(sort)
 {
+    set_argument(2, op::Constant::create(element::i64, Shape{1}, {top_k_axis})->output(0));
+    add_provenance_group_member(input_value(2).get_node_shared_ptr());
     constructor_validate_and_infer_types();
 }
 
+op::v0::TopK::TopK(const Output<Node>& arg,
+                   const Output<Node>& k,
+                   const Output<Node>& top_k_axis,
+                   const element::Type& index_element_type,
+                   bool compute_max,
+                   SortType sort)
+    : Op({arg, k, top_k_axis})
+    , m_index_element_type(index_element_type)
+    , m_compute_max(compute_max)
+    , m_sort(sort)
+{
+    constructor_validate_and_infer_types();
+}
+
-size_t op::TopK::get_k() const
+size_t op::v0::TopK::get_k() const
 {
     size_t k = 0;
     if (auto const_op = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
     {
         k = const_op->get_vector<int64_t>()[0];
     }
-    if (k == 0 && get_input_partial_shape(0).is_static())
+    Dimension top_k_axis = get_top_k_axis_dynamic();
+    if (k == 0 && get_input_partial_shape(0).is_static() && top_k_axis.is_static())
     {
-        k = get_input_partial_shape(0).to_shape()[m_top_k_axis];
+        k = get_input_partial_shape(0).to_shape()[static_cast<size_t>(top_k_axis)];
     }
     return k;
 }
 
-void op::TopK::set_k(size_t k)
+void op::v0::TopK::set_k(size_t k)
 {
     shared_ptr<Node> current_const =
         get_input_size() == 1 ? nullptr : input_value(1).get_node_shared_ptr();
     auto replacement_const = op::Constant::create(element::i64, Shape{1}, {k})->output(0);
     this->input(1).replace_source_output(replacement_const);
     replace_provenance_group_member(current_const, replacement_const.get_node_shared_ptr());
 }
 
+size_t op::v0::TopK::get_top_k_axis() const
+{
+    auto d = get_top_k_axis_dynamic();
+    NGRAPH_CHECK(d.is_static(),
+                 "get_top_k_axis called on a TopK node whose 'top_k_axis' input is not constant");
+    return static_cast<size_t>(d);
+}
+
+Dimension op::v0::TopK::get_top_k_axis_dynamic() const
+{
+    auto const_op = dynamic_pointer_cast<op::Constant>(input_value(2).get_node_shared_ptr());
+    if (const_op)
+    {
+        return const_op->get_vector<int64_t>()[0];
+    }
+    else
+    {
+        return Dimension::dynamic();
+    }
+}
+
+void op::v0::TopK::set_top_k_axis(size_t top_k_axis)
+{
+    shared_ptr<Node> current_const = input_value(2).get_node_shared_ptr();
+    auto replacement_const = op::Constant::create(element::i64, Shape{1}, {top_k_axis})->output(0);
+    this->input(2).replace_source_output(replacement_const);
+    replace_provenance_group_member(current_const, replacement_const.get_node_shared_ptr());
+}
+
-void op::TopK::validate_and_infer_types()
+void op::v0::TopK::validate_and_infer_types()
 {
     const PartialShape& input_shape = get_input_partial_shape(0);
     Rank input_rank = input_shape.rank();
@@ -100,47 +148,83 @@ void op::TopK::validate_and_infer_types()
                           "Argument rank must be greater than 0.");
 
     NODE_VALIDATION_CHECK(this,
-                          input_rank.is_dynamic() || m_top_k_axis < static_cast<size_t>(input_rank),
+                          get_input_element_type(1).compatible(element::i64),
+                          "Element type for 'k' must be i64");
+
+    NODE_VALIDATION_CHECK(this,
+                          get_input_element_type(2).compatible(element::i64),
+                          "Element type for 'top_k_axis' must be i64");
+
+    Dimension top_k_axis = get_top_k_axis_dynamic();
+    NODE_VALIDATION_CHECK(this,
+                          input_rank.is_dynamic() || top_k_axis.is_dynamic() ||
+                              static_cast<size_t>(top_k_axis) < static_cast<size_t>(input_rank),
                           "TopK axis (",
-                          m_top_k_axis,
+                          top_k_axis,
                           ") is out of bounds.");
 
     size_t k = get_k();
 
     NODE_VALIDATION_CHECK(this,
-                          input_rank.is_dynamic() || input_shape[m_top_k_axis].is_dynamic() ||
-                              k <= static_cast<size_t>(input_shape[m_top_k_axis]),
+                          input_rank.is_dynamic() || top_k_axis.is_dynamic() ||
+                              input_shape[static_cast<size_t>(top_k_axis)].is_dynamic() ||
+                              static_cast<size_t>(k) <=
+                                  static_cast<size_t>(input_shape[static_cast<size_t>(top_k_axis)]),
                           "K (",
                           k,
                           ") exceeds the dimension (",
-                          (input_rank.is_static() ? input_shape[m_top_k_axis] : 0),
+                          input_shape[static_cast<size_t>(top_k_axis)],
                           ") of the TopK axis (axis ",
-                          m_top_k_axis,
+                          top_k_axis,
                           ").");
 
     PartialShape output_shape{input_shape};
 
-    if (input_rank.is_static() && k != 0)
+    if (input_rank.is_static())
     {
-        output_shape[m_top_k_axis] = k;
+        if (top_k_axis.is_static())
+        {
+            if (k != 0)
+            {
+                output_shape[static_cast<size_t>(top_k_axis)] = k;
+            }
+            else if (k == 0 && output_shape[static_cast<size_t>(top_k_axis)].is_static())
+            {
+                output_shape[static_cast<size_t>(top_k_axis)] =
+                    input_shape[static_cast<size_t>(top_k_axis)];
+            }
+        }
+        else
+        {
+            // If top_k_axis is not static and k is not 0, then we could be changing any
+            // dimension. So we have to change all dimensions to dynamic.
+            output_shape = PartialShape::dynamic(input_rank);
+        }
     }
 
+    set_input_is_relevant_to_shape(2);
+
     set_output_size(2);
     set_output_type(0, m_index_element_type, output_shape);
     set_output_type(1, input_element_type, output_shape);
 }
 
-shared_ptr<Node> op::TopK::copy_with_new_args(const NodeVector& new_args) const
+shared_ptr<Node> op::v0::TopK::copy_with_new_args(const NodeVector& new_args) const
 {
     check_new_args_count(this, new_args);
-    return make_shared<TopK>(
-        new_args.at(0), new_args.at(1), m_top_k_axis, m_index_element_type, m_compute_max, m_sort);
+    return make_shared<TopK>(new_args.at(0),
+                             new_args.at(1),
+                             new_args.at(2),
+                             m_index_element_type,
+                             m_compute_max,
+                             m_sort);
 }
 
-void op::TopK::generate_adjoints(autodiff::Adjoints& /* adjoints */, const NodeVector& /* deltas */)
+void op::v0::TopK::generate_adjoints(autodiff::Adjoints& /* adjoints */,
+                                     const NodeVector& /* deltas */)
 {
     throw ngraph_error("Forward-propagation-only operation");
 }
 
+// v1 version starts
 constexpr NodeTypeInfo op::v1::TopK::type_info;
 
 op::v1::TopK::TopK(const Output<Node>& data,
...
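The new shape-inference rule above is the heart of the change. Restated as a hedged, standalone sketch (an illustrative helper, not code from the commit; the real method also validates element types and bounds):

#include "ngraph/ngraph.hpp"

using namespace ngraph;

// Mirrors the output-shape logic of op::v0::TopK::validate_and_infer_types.
PartialShape topk_output_shape(const PartialShape& input_shape,
                               const Dimension& top_k_axis,
                               size_t k)
{
    PartialShape output_shape{input_shape};
    if (input_shape.rank().is_static())
    {
        if (top_k_axis.is_static())
        {
            // Known axis: only that one dimension can change, to k when k != 0.
            if (k != 0)
            {
                output_shape[static_cast<size_t>(top_k_axis)] = k;
            }
        }
        else
        {
            // Unknown axis: any dimension could be the one that changes, so
            // every dimension must be reported as dynamic.
            output_shape = PartialShape::dynamic(input_shape.rank());
        }
    }
    return output_shape;
}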
@@ -24,6 +24,8 @@
 namespace ngraph
 {
     namespace op
     {
+        namespace v0
+        {
             // \brief Computes indices of top k maximum/minimum index along a specified axis for a
             //        given tensor
@@ -76,6 +78,22 @@
                  bool compute_max = true,
                  SortType sort = SortType::SORT_VALUES);
 
+            /// \brief Constructs a TopK operation.
+            ///
+            /// \param arg The input tensor
+            /// \param k Number of top indices to compute. Compute all indices if k = 0
+            /// \param top_k_axis The axis along which to compute top k indices
+            /// \param index_element_type produce indices. Currently, only int64 or int32 are
+            ///                           supported
+            /// \param compute_max Compute top k max or top k min?
+            /// \param sort SortType for sorting results, default - NONE
+            TopK(const Output<Node>& arg,
+                 const Output<Node>& k,
+                 const Output<Node>& top_k_axis,
+                 const element::Type& index_element_type,
+                 bool compute_max = true,
+                 SortType sort = SortType::NONE);
+
             void validate_and_infer_types() override;
 
             virtual std::shared_ptr<Node>
@@ -84,18 +102,21 @@
             size_t get_k() const;
             void set_k(size_t k);
 
-            size_t get_top_k_axis() const { return m_top_k_axis; }
+            size_t get_top_k_axis() const;
+            Dimension get_top_k_axis_dynamic() const;
+            void set_top_k_axis(size_t k);
             element::Type get_index_element_type() const { return m_index_element_type; }
             bool get_compute_max() const { return m_compute_max; }
             SortType get_sort() const { return m_sort; }
         protected:
-            size_t m_top_k_axis{0};
             element::Type m_index_element_type;
             bool m_compute_max{false};
             SortType m_sort{SortType::NONE};
 
             virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                            const NodeVector& deltas) override;
         };
+        } // namespace v0
 
         namespace v1
         {
@@ -191,6 +212,8 @@
             template <typename T>
             size_t validate_and_get_k(const std::shared_ptr<op::Constant>& k_constant) const;
         };
-        }
-    }
-}
+        } // namespace v1
+
+        using v0::TopK;
+    } // op
+} // ngraph
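A hedged sketch of how the two accessors declared above are meant to be paired (an assumed usage pattern, not code from this commit): get_top_k_axis_dynamic() is always safe to call, while get_top_k_axis() asserts that the 'top_k_axis' input is a constant.

#include <iostream>

#include "ngraph/ngraph.hpp"

using namespace ngraph;

void report_axis(const std::shared_ptr<op::v0::TopK>& topk)
{
    Dimension axis = topk->get_top_k_axis_dynamic();
    if (axis.is_static())
    {
        // Safe: the 'top_k_axis' input is a Constant, so this cannot throw.
        std::cout << "TopK axis: " << topk->get_top_k_axis() << std::endl;
    }
    else
    {
        // Calling get_top_k_axis() here would trip its NGRAPH_CHECK.
        std::cout << "TopK axis is dynamic" << std::endl;
    }
}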
@@ -423,6 +423,12 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
     case OP_TYPEID::TopK:
     {
         const auto topk_v0 = dynamic_cast<const op::TopK*>(node.get());
+        NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(),
+                     "parameter k is expected to be a static constant");
+        NGRAPH_CHECK(node->input_value(2).get_node_shared_ptr()->is_constant(),
+                     "parameter top_k_axis is expected to be a static constant");
         const auto k = topk_v0->get_k();
         const auto axis = topk_v0->get_top_k_axis();
...
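The two NGRAPH_CHECKs exist because the upgraded v1 op needs concrete values for k and the axis, so the v0-to-v1 conversion can only proceed when both inputs have already folded to constants. A hedged helper one could use to probe that precondition (illustrative, not part of the pass):

#include "ngraph/ngraph.hpp"

using namespace ngraph;

// True when a v0::TopK node can be upgraded: both the 'k' input (index 1)
// and the 'top_k_axis' input (index 2) must already be Constant nodes.
bool topk_v0_is_upgradable(const std::shared_ptr<Node>& node)
{
    return node->input_value(1).get_node_shared_ptr()->is_constant() &&
           node->input_value(2).get_node_shared_ptr()->is_constant();
}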
@@ -32,6 +32,7 @@
 #include "ngraph/op/product.hpp"
 #include "ngraph/op/reverse.hpp"
 #include "ngraph/op/sum.hpp"
+#include "ngraph/op/topk.hpp"
 #include "ngraph/runtime/cpu/cpu_external_function.hpp"
 #include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
@@ -97,7 +98,6 @@
             class Atan;
             class ArgMin;
             class ArgMax;
-            class TopK;
             class GatherND;
             class ScatterAdd;
             class ScatterNDAdd;
...
@@ -2163,11 +2163,28 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
     {
         if (op_version == 0)
         {
-            auto top_k_axis = node_js.at("top_k_axis").get<size_t>();
-            auto k = node_js.at("k").get<size_t>();
             auto compute_max = node_js.at("compute_max").get<bool>();
             auto target_type = read_element_type(node_js.at("index_element_type"));
-            node = make_shared<op::TopK>(args[0], top_k_axis, target_type, k, compute_max);
+            if (has_key(node_js, "top_k_axis"))
+            {
+                auto top_k_axis = node_js.at("top_k_axis").get<size_t>();
+                if (has_key(node_js, "k"))
+                {
+                    auto k = node_js.at("k").get<size_t>();
+                    node =
+                        make_shared<op::TopK>(args[0], top_k_axis, target_type, k, compute_max);
+                }
+                else
+                {
+                    node = make_shared<op::TopK>(
+                        args[0], args[1], top_k_axis, target_type, compute_max);
+                }
+            }
+            else
+            {
+                node =
+                    make_shared<op::TopK>(args[0], args[1], args[2], target_type, compute_max);
+            }
         }
         else if (op_version == 1)
         {
@@ -3366,9 +3383,7 @@ json JSONSerializer::serialize_node(const Node& n)
         if (op_version == 0)
         {
             const auto tmp = static_cast<const op::TopK*>(&n);
-            node["top_k_axis"] = tmp->get_top_k_axis();
             node["index_element_type"] = write_element_type(tmp->get_index_element_type());
-            node["k"] = tmp->get_k();
             node["compute_max"] = tmp->get_compute_max();
         }
         else if (op_version == 1)
@@ -3379,7 +3394,6 @@ json JSONSerializer::serialize_node(const Node& n)
             node["sort_type"] = tmp->get_sort_type();
             node["index_element_type"] = write_element_type(tmp->get_index_element_type());
         }
         break;
     }
     case OP_TYPEID::Transpose: { break;
...
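Since k and top_k_axis now travel as Constant inputs, the serializer stops writing them as node attributes, and the deserializer's has_key branches exist only to read older JSON. A hedged round-trip sketch (assumes the v0-era serialize/deserialize API and GetOutputElement for selecting one of TopK's two outputs; not code from the commit):

#include "ngraph/ngraph.hpp"
#include "ngraph/serializer.hpp"

using namespace ngraph;

void round_trip()
{
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 5});
    auto k = op::Constant::create(element::i64, Shape{1}, {3});
    auto axis = op::Constant::create(element::i64, Shape{1}, {1});
    auto topk = std::make_shared<op::v0::TopK>(data, k, axis, element::i64);

    // TopK has two outputs (indices, values); wrap one as the Function result.
    auto values = std::make_shared<op::GetOutputElement>(topk, 1);
    auto f = std::make_shared<Function>(NodeVector{values}, ParameterVector{data});

    // The emitted JSON carries no "k"/"top_k_axis" attributes; both are
    // recovered from args[1]/args[2] when reading the graph back.
    std::string js = serialize(f);
    auto f2 = deserialize(js);
}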