Commit fe474394 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Michał Karzyński

Changed expected shape of end_token (#3983)

parent 33a6a7d0
...@@ -68,8 +68,8 @@ void op::v1::GatherTree::validate_and_infer_types() ...@@ -68,8 +68,8 @@ void op::v1::GatherTree::validate_and_infer_types()
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
end_token_rank.rank().is_dynamic() || end_token_rank.rank().is_dynamic() ||
static_cast<size_t>(end_token_rank.rank()) == 3, static_cast<size_t>(end_token_rank.rank()) == 0,
"end_token input rank must equal to 3 (end_token rank: ", "end_token input rank must be scalar (end_token rank: ",
static_cast<size_t>(end_token_rank.rank()), static_cast<size_t>(end_token_rank.rank()),
")"); ")");
......
...@@ -26,7 +26,7 @@ TEST(type_prop, gather_tree_output_shape) ...@@ -26,7 +26,7 @@ TEST(type_prop, gather_tree_output_shape)
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3}); auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3}); auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1}); auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3}); auto end_token = make_shared<op::Parameter>(element::i64, Shape{});
auto gather_tree = auto gather_tree =
make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token); make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
...@@ -40,7 +40,7 @@ TEST(type_prop, gather_tree_pooling_step_ids_invalid_rank) ...@@ -40,7 +40,7 @@ TEST(type_prop, gather_tree_pooling_step_ids_invalid_rank)
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4}); auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3}); auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1}); auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3}); auto end_token = make_shared<op::Parameter>(element::i64, Shape{});
try try
{ {
auto gather_tree = auto gather_tree =
...@@ -64,7 +64,7 @@ TEST(type_prop, gather_tree_parent_idx_invalid_rank) ...@@ -64,7 +64,7 @@ TEST(type_prop, gather_tree_parent_idx_invalid_rank)
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3}); auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4}); auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1}); auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3}); auto end_token = make_shared<op::Parameter>(element::i64, Shape{});
try try
{ {
auto gather_tree = auto gather_tree =
...@@ -89,7 +89,7 @@ TEST(type_prop, gather_tree_max_seq_len_invalid_rank) ...@@ -89,7 +89,7 @@ TEST(type_prop, gather_tree_max_seq_len_invalid_rank)
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3}); auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3}); auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1, 2}); auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1, 2});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3}); auto end_token = make_shared<op::Parameter>(element::i64, Shape{});
try try
{ {
auto gather_tree = auto gather_tree =
...@@ -114,7 +114,7 @@ TEST(type_prop, gather_tree_end_token_invalid_rank) ...@@ -114,7 +114,7 @@ TEST(type_prop, gather_tree_end_token_invalid_rank)
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3}); auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3}); auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1}); auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4}); auto end_token = make_shared<op::Parameter>(element::i64, Shape{1});
try try
{ {
auto gather_tree = auto gather_tree =
...@@ -125,7 +125,7 @@ TEST(type_prop, gather_tree_end_token_invalid_rank) ...@@ -125,7 +125,7 @@ TEST(type_prop, gather_tree_end_token_invalid_rank)
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), std::string("end_token input rank must equal to 3 (end_token rank: 4)")); error.what(), std::string("end_token input rank must be scalar (end_token rank: 1)"));
} }
catch (...) catch (...)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment