Commit 5136e1b5 authored by Adam Rogowiec's avatar Adam Rogowiec

Unit tests for GRUCell.

parent da4e9a0e
This diff is collapsed.
......@@ -14775,3 +14775,86 @@ TEST(type_prop, rnn_cell_invalid_input)
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor B must have shape"));
}
}
TEST(type_prop, gru_cell)
{
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
const size_t gates_count = 3;
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
const auto W = make_shared<op::Parameter>(element::f32,
Shape{gates_count * hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32,
Shape{gates_count * hidden_size, hidden_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto gru_cell = make_shared<op::GRUCell>(X, W, R, H_t, hidden_size);
EXPECT_EQ(gru_cell->output(0).get_element_type(), element::f32);
EXPECT_EQ(gru_cell->output(0).get_shape(), (Shape{batch_size, hidden_size}));
}
TEST(type_prop, gru_cell_invalid_input)
{
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
const size_t gates_count = 3;
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
auto R = make_shared<op::Parameter>(element::f32,
Shape{gates_count * hidden_size, hidden_size});
auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
// Invalid W tensor shape.
auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
try
{
const auto gru_cell = make_shared<op::GRUCell>(X, W, R, H_t, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor W must have shape"));
}
// Invalid R tensor shape.
W = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, 1});
try
{
const auto gru_cell = make_shared<op::GRUCell>(X, W, R, H_t, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor R must have shape"));
}
// Invalid H_t tensor shape.
R = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
try
{
const auto gru_cell = make_shared<op::GRUCell>(X, W, R, H_t, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor H_t must have shape"));
}
// Invalid B tensor shape.
H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
try
{
const auto gru_cell = make_shared<op::GRUCell>(X, W, R, H_t, hidden_size, B);
FAIL() << "GRUCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor B must have shape"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment