Commit fe474394 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Michał Karzyński

Changed expected shape of end_token (#3983)

parent 33a6a7d0
......@@ -68,8 +68,8 @@ void op::v1::GatherTree::validate_and_infer_types()
NODE_VALIDATION_CHECK(this,
end_token_rank.rank().is_dynamic() ||
static_cast<size_t>(end_token_rank.rank()) == 3,
"end_token input rank must equal to 3 (end_token rank: ",
static_cast<size_t>(end_token_rank.rank()) == 0,
"end_token input rank must be scalar (end_token rank: ",
static_cast<size_t>(end_token_rank.rank()),
")");
......
......@@ -26,7 +26,7 @@ TEST(type_prop, gather_tree_output_shape)
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{});
auto gather_tree =
make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
......@@ -40,7 +40,7 @@ TEST(type_prop, gather_tree_pooling_step_ids_invalid_rank)
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{});
try
{
auto gather_tree =
......@@ -64,7 +64,7 @@ TEST(type_prop, gather_tree_parent_idx_invalid_rank)
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{});
try
{
auto gather_tree =
......@@ -89,7 +89,7 @@ TEST(type_prop, gather_tree_max_seq_len_invalid_rank)
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1, 2});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{});
try
{
auto gather_tree =
......@@ -114,7 +114,7 @@ TEST(type_prop, gather_tree_end_token_invalid_rank)
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1});
try
{
auto gather_tree =
......@@ -125,7 +125,7 @@ TEST(type_prop, gather_tree_end_token_invalid_rank)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(), std::string("end_token input rank must equal to 3 (end_token rank: 4)"));
error.what(), std::string("end_token input rank must be scalar (end_token rank: 1)"));
}
catch (...)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment