Commit d3c2d772 authored by Adam Procter, committed by Scott Cyphers

Make axis for TopK dynamic (#3526)

* Make axis for TopK dynamic

* set_input_is_relevant_to_shape

* Proper handling of output shape with dynamic top_k_axis

* fix ut

* remove unused line

* check is_constant

* add v0 version

* update provenance

* use set_argument

* fix clang
parent 7cc2a41f
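In effect, the commit turns the TopK axis from a constructor attribute into a third node input, so the axis can come from any i64-producing node instead of being fixed when the graph is built. A minimal before/after sketch (hypothetical variable names, assuming the usual nGraph headers; not code from this commit):

    #include "ngraph/ngraph.hpp"
    using namespace ngraph;

    void topk_axis_example()
    {
        auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
        auto k = op::Constant::create(element::i64, Shape{1}, {2});

        // Pre-existing form: the axis is a size_t fixed when the node is built.
        auto topk_static = std::make_shared<op::v0::TopK>(data, k, /*top_k_axis*/ 1, element::i32);

        // New with this commit: the axis is input 2 and may be any i64 node,
        // for example a Parameter whose value is only known at runtime.
        auto axis = std::make_shared<op::Parameter>(element::i64, Shape{1});
        auto topk_dynamic = std::make_shared<op::v0::TopK>(data, k, axis, element::i32);
    }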
@@ -24,62 +24,110 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::TopK::type_info;
constexpr NodeTypeInfo op::v0::TopK::type_info;
op::TopK::TopK(const Output<Node>& arg,
op::v0::TopK::TopK(const Output<Node>& arg,
size_t top_k_axis,
const element::Type& index_element_type,
size_t k,
bool compute_max,
SortType sort)
: Op({arg, op::Constant::create(element::i64, Shape{1}, {k})->output(0)})
, m_top_k_axis(top_k_axis)
: Op({arg})
, m_index_element_type(index_element_type)
, m_compute_max(compute_max)
, m_sort(sort)
{
set_argument(1, op::Constant::create(element::i64, Shape{1}, {k})->output(0));
set_argument(2, op::Constant::create(element::i64, Shape{1}, {top_k_axis})->output(0));
add_provenance_group_member(input_value(1).get_node_shared_ptr());
add_provenance_group_member(input_value(2).get_node_shared_ptr());
constructor_validate_and_infer_types();
}
op::TopK::TopK(const Output<Node>& arg,
op::v0::TopK::TopK(const Output<Node>& arg,
const Output<Node>& k,
size_t top_k_axis,
const element::Type& index_element_type,
bool compute_max,
SortType sort)
: Op({arg, k})
, m_top_k_axis(top_k_axis)
, m_index_element_type(index_element_type)
, m_compute_max(compute_max)
, m_sort(sort)
{
set_argument(2, op::Constant::create(element::i64, Shape{1}, {top_k_axis})->output(0));
add_provenance_group_member(input_value(2).get_node_shared_ptr());
constructor_validate_and_infer_types();
}
size_t op::TopK::get_k() const
op::v0::TopK::TopK(const Output<Node>& arg,
const Output<Node>& k,
const Output<Node>& top_k_axis,
const element::Type& index_element_type,
bool compute_max,
SortType sort)
: Op({arg, k, top_k_axis})
, m_index_element_type(index_element_type)
, m_compute_max(compute_max)
, m_sort(sort)
{
constructor_validate_and_infer_types();
}
size_t op::v0::TopK::get_k() const
{
size_t k = 0;
if (auto const_op = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
{
k = const_op->get_vector<int64_t>()[0];
}
if (k == 0 && get_input_partial_shape(0).is_static())
Dimension top_k_axis = get_top_k_axis_dynamic();
if (k == 0 && get_input_partial_shape(0).is_static() && top_k_axis.is_static())
{
k = get_input_partial_shape(0).to_shape()[m_top_k_axis];
k = get_input_partial_shape(0).to_shape()[static_cast<size_t>(top_k_axis)];
}
return k;
}
void op::TopK::set_k(size_t k)
void op::v0::TopK::set_k(size_t k)
{
shared_ptr<Node> current_const =
get_input_size() == 1 ? nullptr : input_value(1).get_node_shared_ptr();
auto replacement_const = op::Constant::create(element::i64, Shape{1}, {k})->output(0);
this->input(1).replace_source_output(replacement_const);
replace_provenance_group_member(current_const, replacement_const.get_node_shared_ptr());
}
size_t op::v0::TopK::get_top_k_axis() const
{
auto d = get_top_k_axis_dynamic();
NGRAPH_CHECK(d.is_static(),
"get_top_k_axis called on a TopK node whose 'top_k_axis' input is not constant");
return static_cast<size_t>(d);
}
Dimension op::v0::TopK::get_top_k_axis_dynamic() const
{
auto const_op = dynamic_pointer_cast<op::Constant>(input_value(2).get_node_shared_ptr());
if (const_op)
{
return const_op->get_vector<int64_t>()[0];
}
else
{
return Dimension::dynamic();
}
}
void op::v0::TopK::set_top_k_axis(size_t top_k_axis)
{
shared_ptr<Node> current_const = input_value(2).get_node_shared_ptr();
auto replacement_const = op::Constant::create(element::i64, Shape{1}, {top_k_axis})->output(0);
this->input(2).replace_source_output(replacement_const);
replace_provenance_group_member(current_const, replacement_const.get_node_shared_ptr());
}
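// Illustrative usage sketch (hypothetical nodes 'data', 'k' and 'axis'; not part of
// this file): with the axis carried as input 2, a caller has to allow for it being
// unknown until runtime.
//
//     auto topk = std::make_shared<op::v0::TopK>(data, k, axis, element::i32);
//     Dimension ax = topk->get_top_k_axis_dynamic(); // Dimension::dynamic() unless
//                                                    // 'axis' is an op::Constant
//     if (ax.is_static())
//     {
//         size_t a = topk->get_top_k_axis(); // only safe when the axis is static
//     }
//
// get_k() follows the same convention: a stored k of 0 means "all values along the
// axis", and it is only resolved to the axis extent when both the data shape and
// the axis are static.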
void op::TopK::validate_and_infer_types()
void op::v0::TopK::validate_and_infer_types()
{
const PartialShape& input_shape = get_input_partial_shape(0);
Rank input_rank = input_shape.rank();
@@ -100,47 +148,83 @@ void op::TopK::validate_and_infer_types()
"Argument rank must be greater than 0.");
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || m_top_k_axis < static_cast<size_t>(input_rank),
get_input_element_type(1).compatible(element::i64),
"Element type for 'k' must be i64");
NODE_VALIDATION_CHECK(this,
get_input_element_type(2).compatible(element::i64),
"Element type for 'top_k_axis' must be i64");
Dimension top_k_axis = get_top_k_axis_dynamic();
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || top_k_axis.is_dynamic() ||
static_cast<size_t>(top_k_axis) < static_cast<size_t>(input_rank),
"TopK axis (",
m_top_k_axis,
top_k_axis,
") is out of bounds.");
size_t k = get_k();
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || input_shape[m_top_k_axis].is_dynamic() ||
k <= static_cast<size_t>(input_shape[m_top_k_axis]),
input_rank.is_dynamic() || top_k_axis.is_dynamic() ||
input_shape[static_cast<size_t>(top_k_axis)].is_dynamic() ||
static_cast<size_t>(k) <=
static_cast<size_t>(input_shape[static_cast<size_t>(top_k_axis)]),
"K (",
k,
") exceeds the dimension (",
(input_rank.is_static() ? input_shape[m_top_k_axis] : 0),
input_shape[static_cast<size_t>(top_k_axis)],
") of the TopK axis (axis ",
m_top_k_axis,
top_k_axis,
").");
PartialShape output_shape{input_shape};
if (input_rank.is_static() && k != 0)
if (input_rank.is_static())
{
if (top_k_axis.is_static())
{
output_shape[m_top_k_axis] = k;
if (k != 0)
{
output_shape[static_cast<size_t>(top_k_axis)] = k;
}
else if (k == 0 && output_shape[static_cast<size_t>(top_k_axis)].is_static())
{
output_shape[static_cast<size_t>(top_k_axis)] =
input_shape[static_cast<size_t>(top_k_axis)];
}
}
else
{
// If top_k_axis is not static and k is not 0, then we could be changing any
// dimension. So we have to change all dimensions to dynamic.
output_shape = PartialShape::dynamic(input_rank);
}
}
set_input_is_relevant_to_shape(2);
set_output_size(2);
set_output_type(0, m_index_element_type, output_shape);
set_output_type(1, input_element_type, output_shape);
}
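// Worked examples of the inference above (illustrative values, data shape {2, 3, 4}):
//
//     k = 2, top_k_axis = Constant(1)   -> both outputs have shape {2, 2, 4}
//     k = 0, top_k_axis = Constant(1)   -> both outputs have shape {2, 3, 4}
//     k = 2, top_k_axis = Parameter     -> both outputs have shape {?, ?, ?}
//
// When the axis is not a constant, any dimension might be the one clipped to k, so
// every output dimension has to be reported as dynamic.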
shared_ptr<Node> op::TopK::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::TopK::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<TopK>(
new_args.at(0), new_args.at(1), m_top_k_axis, m_index_element_type, m_compute_max, m_sort);
return make_shared<TopK>(new_args.at(0),
new_args.at(1),
new_args.at(2),
m_index_element_type,
m_compute_max,
m_sort);
}
void op::TopK::generate_adjoints(autodiff::Adjoints& /* adjoints */, const NodeVector& /* deltas */)
void op::v0::TopK::generate_adjoints(autodiff::Adjoints& /* adjoints */,
const NodeVector& /* deltas */)
{
throw ngraph_error("Forward-propagation-only operation");
}
// v1 version starts
constexpr NodeTypeInfo op::v1::TopK::type_info;
op::v1::TopK::TopK(const Output<Node>& data,
@@ -24,6 +24,8 @@
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Computes indices of top k maximum/minimum index along a specified axis for a
/// given tensor
@@ -76,6 +78,22 @@ namespace ngraph
bool compute_max = true,
SortType sort = SortType::SORT_VALUES);
/// \brief Constructs a TopK operation.
///
/// \param arg The input tensor
/// \param k Number of top indices to compute. Compute all indices if k = 0
/// \param top_k_axis The axis along which to compute top k indices
/// \param index_element_type Element type of the produced indices. Currently, only
/// int64 or int32 are supported
/// \param compute_max Compute top k max or top k min?
/// \param sort SortType for sorting results, default - NONE
TopK(const Output<Node>& arg,
const Output<Node>& k,
const Output<Node>& top_k_axis,
const element::Type& index_element_type,
bool compute_max = true,
SortType sort = SortType::NONE);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
@@ -84,18 +102,21 @@ namespace ngraph
size_t get_k() const;
void set_k(size_t k);
size_t get_top_k_axis() const { return m_top_k_axis; }
size_t get_top_k_axis() const;
Dimension get_top_k_axis_dynamic() const;
void set_top_k_axis(size_t k);
element::Type get_index_element_type() const { return m_index_element_type; }
bool get_compute_max() const { return m_compute_max; }
SortType get_sort() const { return m_sort; }
protected:
size_t m_top_k_axis{0};
element::Type m_index_element_type;
bool m_compute_max{false};
SortType m_sort{SortType::NONE};
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v0
namespace v1
{
@@ -191,6 +212,8 @@ namespace ngraph
template <typename T>
size_t validate_and_get_k(const std::shared_ptr<op::Constant>& k_constant) const;
};
}
}
}
} // namespace v1
using v0::TopK;
} // op
} // ngraph
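The "using v0::TopK;" alias keeps existing code that refers to op::TopK compiling unchanged; it now simply names the v0 class. A tiny illustrative check (not from the repository):

    #include "ngraph/op/topk.hpp"
    #include <type_traits>

    // op::TopK and op::v0::TopK are the same type after this change.
    static_assert(std::is_same<ngraph::op::TopK, ngraph::op::v0::TopK>::value,
                  "op::TopK is an alias for op::v0::TopK");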
@@ -423,6 +423,12 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
case OP_TYPEID::TopK:
{
const auto topk_v0 = dynamic_cast<const op::TopK*>(node.get());
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(),
"parameter k is expected to be a static constant");
NGRAPH_CHECK(node->input_value(2).get_node_shared_ptr()->is_constant(),
"parameter top_k_axis is expected to be a static constant");
const auto k = topk_v0->get_k();
const auto axis = topk_v0->get_top_k_axis();
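These checks exist because the upgrade pass goes through get_k() and get_top_k_axis(), which only yield values when inputs 1 and 2 are op::Constant nodes. A fragment that would trip the second check (hypothetical names, reusing the headers from the earlier sketch):

    // Valid as a v0::TopK, but not upgradable to v1: the axis is a runtime Parameter,
    // so Opset1Upgrade fails with "parameter top_k_axis is expected to be a static constant".
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
    auto k = op::Constant::create(element::i64, Shape{1}, {2});
    auto axis = std::make_shared<op::Parameter>(element::i64, Shape{1});
    auto topk = std::make_shared<op::v0::TopK>(data, k, axis, element::i32);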
@@ -32,6 +32,7 @@
#include "ngraph/op/product.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
@@ -97,7 +98,6 @@ namespace ngraph
class Atan;
class ArgMin;
class ArgMax;
class TopK;
class GatherND;
class ScatterAdd;
class ScatterNDAdd;
@@ -2163,11 +2163,28 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
{
if (op_version == 0)
{
auto top_k_axis = node_js.at("top_k_axis").get<size_t>();
auto k = node_js.at("k").get<size_t>();
auto compute_max = node_js.at("compute_max").get<bool>();
auto target_type = read_element_type(node_js.at("index_element_type"));
node = make_shared<op::TopK>(args[0], top_k_axis, target_type, k, compute_max);
if (has_key(node_js, "top_k_axis"))
{
auto top_k_axis = node_js.at("top_k_axis").get<size_t>();
if (has_key(node_js, "k"))
{
auto k = node_js.at("k").get<size_t>();
node =
make_shared<op::TopK>(args[0], top_k_axis, target_type, k, compute_max);
}
else
{
node = make_shared<op::TopK>(
args[0], args[1], top_k_axis, target_type, compute_max);
}
}
else
{
node =
make_shared<op::TopK>(args[0], args[1], args[2], target_type, compute_max);
}
}
else if (op_version == 1)
{
@@ -3366,9 +3383,7 @@ json JSONSerializer::serialize_node(const Node& n)
if (op_version == 0)
{
const auto tmp = static_cast<const op::TopK*>(&n);
node["top_k_axis"] = tmp->get_top_k_axis();
node["index_element_type"] = write_element_type(tmp->get_index_element_type());
node["k"] = tmp->get_k();
node["compute_max"] = tmp->get_compute_max();
}
else if (op_version == 1)
@@ -3379,7 +3394,6 @@ json JSONSerializer::serialize_node(const Node& n)
node["sort_type"] = tmp->get_sort_type();
node["index_element_type"] = write_element_type(tmp->get_index_element_type());
}
break;
}
case OP_TYPEID::Transpose: { break;
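For reference, the deserializer above now accepts three on-disk vintages of a v0 TopK node. Roughly, showing only the TopK-specific attributes evidenced in the code (the surrounding node layout is omitted and the example values are illustrative):

    1) Oldest form, k and top_k_axis stored as attributes, data as the only input:
         {"top_k_axis": 1, "k": 2, "index_element_type": ..., "compute_max": true}
    2) Intermediate form, k already an input, axis still an attribute:
         {"top_k_axis": 1, "index_element_type": ..., "compute_max": true}
    3) Current form (what serialize_node now writes), k and the axis as inputs 1 and 2:
         {"index_element_type": ..., "compute_max": true}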