Commit 759f79c0 authored by Adam Procter, committed by Scott Cyphers

Implement partial shape/type validation for TopK (#1912)

parent 7246875e
@@ -39,36 +39,48 @@ op::TopK::TopK(const shared_ptr<Node>& arg,
void op::TopK::validate_and_infer_types()
{
-    if (validate_punt_if_dynamic())
-    {
-        return;
-    }
-
-    auto& input = get_inputs().at(0);
-    auto rank = input.get_shape().size();
-
-    NODE_VALIDATION_ASSERT(this, rank > 0) << "Input Tensor's rank must be greater than 0";
-    NODE_VALIDATION_ASSERT(this, m_top_k_axis < rank) << "TopK axis must be less than rank";
+    const PartialShape& input_shape = get_input_partial_shape(0);
+    Rank input_rank = input_shape.rank();
+    element::Type input_element_type = get_input_element_type(0);
+
+    NODE_VALIDATION_ASSERT(this, !m_index_element_type.is_dynamic())
+        << "Argument element type must not be dynamic.";
+
    NODE_VALIDATION_ASSERT(
        this, m_index_element_type == element::i32 || m_index_element_type == element::i64)
-        << "Index element type must be i64 or i32";
-    NODE_VALIDATION_ASSERT(this, m_k <= input.get_shape()[m_top_k_axis])
-        << "K should not exceed TopK axis length";
-
-    Shape input_shape = input.get_shape();
-    Shape output_shape(input_shape);
-
-    if (m_k != 0)
-    {
-        output_shape[m_top_k_axis] = m_k;
-    }
-    else
-    {
-        m_k = input_shape[m_top_k_axis];
-    }
+        << "Argument element type must be i64 or i32 (got " << m_index_element_type << ").";
+
+    NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || static_cast<size_t>(input_rank) > 0)
+        << "Argument rank must be greater than 0.";
+
+    NODE_VALIDATION_ASSERT(
+        this, input_rank.is_dynamic() || m_top_k_axis < static_cast<size_t>(input_rank))
+        << "TopK axis (" << m_top_k_axis << ") is out of bounds.";
+
+    NODE_VALIDATION_ASSERT(this,
+                           input_rank.is_dynamic() || input_shape[m_top_k_axis].is_dynamic() ||
+                               m_k <= static_cast<size_t>(input_shape[m_top_k_axis]))
+        << "K (" << m_k << ") exceeds the dimension ("
+        << (input_rank.is_static() ? input_shape[m_top_k_axis] : 0) << ") of the TopK axis (axis "
+        << m_top_k_axis << ").";
+
+    PartialShape output_shape{input_shape};
+
+    if (input_rank.is_static())
+    {
+        if (m_k != 0)
+        {
+            output_shape[m_top_k_axis] = m_k;
+        }
+        else if (input_shape[m_top_k_axis].is_static())
+        {
+            m_k = static_cast<size_t>(input_shape[m_top_k_axis]);
+        }
+    }

    set_output_size(2);
    set_output_type(0, m_index_element_type, output_shape);
-    set_output_type(1, input.get_element_type(), output_shape);
+    set_output_type(1, input_element_type, output_shape);
}

shared_ptr<Node> op::TopK::copy_with_new_args(const NodeVector& new_args) const
...
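The updated type_prop tests below exercise these new paths. As a minimal standalone sketch of the behavior (the ngraph/ngraph.hpp umbrella include and the main() wrapper are assumptions for illustration; the API calls themselves mirror the tests in this diff):

#include <iostream>
#include <memory>

#include <ngraph/ngraph.hpp> // assumed umbrella header; adjust to the local include layout

using namespace ngraph;

int main()
{
    // Rank-static input whose dimensions are all dynamic.
    auto data = std::make_shared<op::Parameter>(
        element::f32,
        PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});

    // Top 5 along axis 1, i64 indices, compute_max = true. The K-vs-dimension check
    // is skipped here because the axis-1 dimension is dynamic.
    auto topk = std::make_shared<op::TopK>(data, 1, element::i64, 5, true);

    // Output 0 holds indices (i64), output 1 holds values (f32); under the new
    // inference rules both outputs have partial shape {?, 5, ?}.
    std::cout << std::boolalpha
              << topk->get_output_partial_shape(1).same_scheme(
                     PartialShape{Dimension::dynamic(), 5, Dimension::dynamic()})
              << std::endl;
    return 0;
}

The negative paths (dynamic or non-integral index element type, out-of-bounds axis, oversized k) are covered by the new tests in the hunk below.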
@@ -10289,7 +10289,7 @@ TEST(type_prop, topk_invalid_rank)
    }
    catch (const NodeValidationError& error)
    {
-        EXPECT_HAS_SUBSTRING(error.what(), "Input Tensor's rank must be greater than 0");
+        EXPECT_HAS_SUBSTRING(error.what(), "Argument rank must be greater than 0");
    }
    catch (...)
    {
@@ -10308,7 +10308,7 @@ TEST(type_prop, topk_invalid_top_k)
    }
    catch (const NodeValidationError& error)
    {
-        EXPECT_HAS_SUBSTRING(error.what(), "TopK axis must be less than rank");
+        EXPECT_HAS_SUBSTRING(error.what(), "TopK axis (2) is out of bounds");
    }
    catch (...)
    {
@@ -10327,7 +10327,9 @@ TEST(type_prop, topk_invalid_index_type)
    }
    catch (const NodeValidationError& error)
    {
-        EXPECT_HAS_SUBSTRING(error.what(), "Index element type must be i64 or i32");
+        EXPECT_HAS_SUBSTRING(
+            error.what(),
+            "Argument element type must be i64 or i32 (got element::Type{32, 1, 1, 0, \"float\"})");
    }
    catch (...)
    {
@@ -10346,7 +10348,8 @@ TEST(type_prop, topk_invalid_k)
    }
    catch (const NodeValidationError& error)
    {
-        EXPECT_HAS_SUBSTRING(error.what(), "K should not exceed TopK axis length");
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             "K (3) exceeds the dimension (2) of the TopK axis (axis 0)");
    }
    catch (...)
    {
@@ -10354,6 +10357,247 @@ TEST(type_prop, topk_invalid_k)
    }
}

TEST(type_prop, topk_rank_dynamic_ok)
{
element::Type arg_et{element::f32};
PartialShape arg_shape{PartialShape::dynamic()};
size_t top_k_axis = 22;
size_t k = 900;
element::Type result_et{element::i32};
bool compute_max = true;
auto param = make_shared<op::Parameter>(arg_et, arg_shape);
auto topk = make_shared<op::TopK>(param, top_k_axis, result_et, k, compute_max);
ASSERT_TRUE(topk->get_output_element_type(0) == element::i32);
ASSERT_TRUE(topk->get_output_element_type(1) == element::f32);
ASSERT_TRUE(topk->get_output_partial_shape(0).rank().is_dynamic());
ASSERT_TRUE(topk->get_output_partial_shape(1).rank().is_dynamic());
}
TEST(type_prop, topk_rank_dynamic_result_et_dynamic)
{
element::Type arg_et{element::f32};
PartialShape arg_shape{PartialShape::dynamic()};
size_t top_k_axis = 22;
size_t k = 900;
element::Type result_et{element::dynamic};
bool compute_max = true;
auto param = make_shared<op::Parameter>(arg_et, arg_shape);
try
{
auto topk = make_shared<op::TopK>(param, top_k_axis, result_et, k, compute_max);
FAIL() << "Dynamic result element type not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument element type must not be dynamic");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, topk_rank_dynamic_result_et_invalid)
{
element::Type arg_et{element::f32};
PartialShape arg_shape{PartialShape::dynamic()};
size_t top_k_axis = 22;
size_t k = 900;
element::Type result_et{element::f32};
bool compute_max = true;
auto param = make_shared<op::Parameter>(arg_et, arg_shape);
try
{
auto topk = make_shared<op::TopK>(param, top_k_axis, result_et, k, compute_max);
FAIL() << "Invalid result element type not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Argument element type must be i64 or i32 (got element::Type{32, 1, 1, 0, \"float\"})");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, topk_rank_static_dynamic_k_known_topk_dim_dynamic_ok)
{
element::Type arg_et{element::f32};
PartialShape arg_shape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
size_t top_k_axis = 1;
size_t k = 999;
element::Type result_et{element::i32};
bool compute_max = true;
auto param = make_shared<op::Parameter>(arg_et, arg_shape);
auto topk = make_shared<op::TopK>(param, top_k_axis, result_et, k, compute_max);
ASSERT_TRUE(topk->get_output_element_type(0) == element::i32);
ASSERT_TRUE(topk->get_output_element_type(1) == element::f32);
ASSERT_TRUE(topk->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), 999, Dimension::dynamic()}));
ASSERT_TRUE(topk->get_output_partial_shape(1).same_scheme(
PartialShape{Dimension::dynamic(), 999, Dimension::dynamic()}));
}
TEST(type_prop, topk_rank_static_dynamic_k_unknown_topk_dim_dynamic_ok)
{
element::Type arg_et{element::f32};
PartialShape arg_shape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
size_t top_k_axis = 1;
size_t k = 0;
element::Type result_et{element::i32};
bool compute_max = true;
auto param = make_shared<op::Parameter>(arg_et, arg_shape);
auto topk = make_shared<op::TopK>(param, top_k_axis, result_et, k, compute_max);
ASSERT_TRUE(topk->get_output_element_type(0) == element::i32);
ASSERT_TRUE(topk->get_output_element_type(1) == element::f32);
ASSERT_TRUE(topk->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
ASSERT_TRUE(topk->get_output_partial_shape(1).same_scheme(
PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop, topk_rank_static_dynamic_axis_oob)
{
element::Type arg_et{element::f32};
PartialShape arg_shape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
size_t top_k_axis = 22;
size_t k = 900;
element::Type result_et{element::f32};
bool compute_max = true;
auto param = make_shared<op::Parameter>(arg_et, arg_shape);
try
{
auto topk = make_shared<op::TopK>(param, top_k_axis, result_et, k, compute_max);
FAIL() << "TopK axis out-of-bounds not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Argument element type must be i64 or i32 (got element::Type{32, 1, 1, 0, \"float\"})");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, topk_rank_static_dynamic_k_unknown_axis_oob)
{
element::Type arg_et{element::f32};
PartialShape arg_shape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
size_t top_k_axis = 22;
size_t k = 0;
element::Type result_et{element::f32};
bool compute_max = true;
auto param = make_shared<op::Parameter>(arg_et, arg_shape);
try
{
auto topk = make_shared<op::TopK>(param, top_k_axis, result_et, k, compute_max);
FAIL() << "TopK axis out-of-bounds not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Argument element type must be i64 or i32 (got element::Type{32, 1, 1, 0, \"float\"})");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, topk_rank_static_dynamic_k_known_too_big)
{
element::Type arg_et{element::f32};
PartialShape arg_shape{Dimension::dynamic(), 3, Dimension::dynamic()};
size_t top_k_axis = 1;
size_t k = 4;
element::Type result_et{element::f32};
bool compute_max = true;
auto param = make_shared<op::Parameter>(arg_et, arg_shape);
try
{
auto topk = make_shared<op::TopK>(param, top_k_axis, result_et, k, compute_max);
FAIL() << "Oversize K not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Argument element type must be i64 or i32 (got element::Type{32, 1, 1, 0, \"float\"})");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, topk_rank_static_dynamic_k_unknown_ok)
{
element::Type arg_et{element::f32};
PartialShape arg_shape{Dimension::dynamic(), 3, Dimension::dynamic()};
size_t top_k_axis = 1;
size_t k = 0;
element::Type result_et{element::i32};
bool compute_max = true;
auto param = make_shared<op::Parameter>(arg_et, arg_shape);
auto topk = make_shared<op::TopK>(param, top_k_axis, result_et, k, compute_max);
ASSERT_TRUE(topk->get_output_element_type(0) == element::i32);
ASSERT_TRUE(topk->get_output_element_type(1) == element::f32);
ASSERT_TRUE(topk->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()}));
ASSERT_TRUE(topk->get_output_partial_shape(1).same_scheme(
PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()}));
}
TEST(type_prop, topk_rank_static_dynamic_k_known_ok)
{
element::Type arg_et{element::f32};
PartialShape arg_shape{Dimension::dynamic(), 3, Dimension::dynamic()};
size_t top_k_axis = 1;
size_t k = 2;
element::Type result_et{element::i32};
bool compute_max = true;
auto param = make_shared<op::Parameter>(arg_et, arg_shape);
auto topk = make_shared<op::TopK>(param, top_k_axis, result_et, k, compute_max);
ASSERT_TRUE(topk->get_output_element_type(0) == element::i32);
ASSERT_TRUE(topk->get_output_element_type(1) == element::f32);
ASSERT_TRUE(topk->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), 2, Dimension::dynamic()}));
ASSERT_TRUE(topk->get_output_partial_shape(1).same_scheme(
PartialShape{Dimension::dynamic(), 2, Dimension::dynamic()}));
}
TEST(type_prop, param_partial_rank_dynamic)
{
    auto a = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
...