Commit ad30a5bd authored by Adam Rogowiec's avatar Adam Rogowiec

Unit tests.

parent 3308f7b2
@@ -1246,3 +1246,178 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_activaction_functions)
        {0.94656503f, 0.9527454f, 0.9706756f, 0.84206575f, 0.91898793f, 0.9127192f});
    test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, rnn_cell_no_bias)
{
    const size_t batch_size = 2;
    const size_t input_size = 3;
    const size_t hidden_size = 3;

    const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
    const auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
    const auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
    const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});

    const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
    auto function = make_shared<Function>(rnn_cell, ParameterVector{X, W, R, H_t});

    auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
    // X
    test_case.add_input<float>(
        {0.3432185f, 0.612268f, 0.20272376f, 0.9513413f, 0.30585995f, 0.7265472f});
    // W
    test_case.add_input<float>({0.41930267f,
                                0.7872176f,
                                0.89940447f,
                                0.23659843f,
                                0.24676207f,
                                0.17101714f,
                                0.3147149f,
                                0.6555601f,
                                0.4559603f});
    // R
    test_case.add_input<float>({0.8374871f,
                                0.86660194f,
                                0.82114047f,
                                0.71549815f,
                                0.18775631f,
                                0.3182116f,
                                0.25392973f,
                                0.38301638f,
                                0.85531586f});
    // Ht
    test_case.add_input<float>(
        {0.12444675f, 0.52055854f, 0.46489045f, 0.4983964f, 0.7730452f, 0.28439692f});

    test_case.add_expected_output<float>(
        Shape{batch_size, hidden_size},
        {0.9408395f, 0.53823817f, 0.84270686f, 0.98932856f, 0.768665f, 0.90461975f});
    test_case.run();
}
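
// A minimal reference sketch for the test data above, assuming the RNNCell op
// follows the ONNX-style recurrence H_t = f(X * W^T + H_{t-1} * R^T) with
// f = tanh when no activation is specified. rnn_cell_reference is a hypothetical
// standalone helper, not nGraph API; for the inputs above its first entry comes
// out as tanh(0.80823f + 0.93708f) ~= 0.94084f, which agrees with 0.9408395f.
#include <cmath>  // likely already provided by the test file's existing includes
#include <vector>

static std::vector<float> rnn_cell_reference(const std::vector<float>& X,
                                             const std::vector<float>& W,
                                             const std::vector<float>& R,
                                             const std::vector<float>& H_t,
                                             size_t batch_size,
                                             size_t input_size,
                                             size_t hidden_size)
{
    std::vector<float> out(batch_size * hidden_size, 0.f);
    for (size_t b = 0; b < batch_size; ++b)
    {
        for (size_t h = 0; h < hidden_size; ++h)
        {
            float acc = 0.f;
            // X * W^T : row b of X against row h of W
            for (size_t i = 0; i < input_size; ++i)
            {
                acc += X[b * input_size + i] * W[h * input_size + i];
            }
            // H_{t-1} * R^T : row b of H_t against row h of R
            for (size_t k = 0; k < hidden_size; ++k)
            {
                acc += H_t[b * hidden_size + k] * R[h * hidden_size + k];
            }
            out[b * hidden_size + h] = std::tanh(acc); // assumed default activation
        }
    }
    return out;
}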

NGRAPH_TEST(${BACKEND_NAME}, rnn_cell_bias_clip)
{
    const size_t batch_size = 2;
    const size_t input_size = 3;
    const size_t hidden_size = 3;
    float clip = 2.88f;

    const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
    const auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
    const auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
    const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
    const auto B = make_shared<op::Parameter>(element::f32, Shape{2 * hidden_size});

    const auto rnn_cell = make_shared<op::RNNCell>(X,
                                                   W,
                                                   R,
                                                   H_t,
                                                   hidden_size,
                                                   B,
                                                   vector<string>{"tanh"},
                                                   vector<float>{},
                                                   vector<float>{},
                                                   clip);
    auto function = make_shared<Function>(rnn_cell, ParameterVector{X, W, R, H_t, B});

    auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
    // X
    test_case.add_input<float>(
        {0.3432185f, 0.612268f, 0.20272376f, 0.9513413f, 0.30585995f, 0.7265472f});
    // W
    test_case.add_input<float>({0.41930267f,
                                0.7872176f,
                                0.89940447f,
                                0.23659843f,
                                0.24676207f,
                                0.17101714f,
                                0.3147149f,
                                0.6555601f,
                                0.4559603f});
    // R
    test_case.add_input<float>({0.8374871f,
                                0.86660194f,
                                0.82114047f,
                                0.71549815f,
                                0.18775631f,
                                0.3182116f,
                                0.25392973f,
                                0.38301638f,
                                0.85531586f});
    // Ht
    test_case.add_input<float>(
        {0.12444675f, 0.52055854f, 0.46489045f, 0.4983964f, 0.7730452f, 0.28439692f});
    // B
    test_case.add_input<float>(
        {0.45513555f, 0.96227735f, 0.24737759f, 0.57380486f, 0.67398053f, 0.18968852f});

    test_case.add_expected_output<float>(
        Shape{batch_size, hidden_size},
        {0.9922437f, 0.97749525f, 0.9312212f, 0.9937176f, 0.9901317f, 0.95906746f});
    test_case.run();
}
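
// Hedged note on the test above: B (Shape{2 * hidden_size}) appears to pack both
// bias vectors, W_b followed by R_b, and the pre-activation sum seems to be
// clamped to [-clip, clip] before the activation is applied. A sketch of that
// step, reusing acc from the hypothetical rnn_cell_reference helper above:
//     acc += B[h] + B[hidden_size + h];            // W_b + R_b
//     acc = std::max(-clip, std::min(clip, acc));  // apply clipping
//     out[b * hidden_size + h] = std::tanh(acc);
// With these inputs, element (1, 0) sums to ~3.643f, is clipped to 2.88f, and
// tanh(2.88f) ~= 0.99372f, matching the expected 0.9937176f.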

NGRAPH_TEST(${BACKEND_NAME}, rnn_cell_activation_function)
{
    const size_t batch_size = 2;
    const size_t input_size = 3;
    const size_t hidden_size = 3;
    float clip = 2.88f;

    const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
    const auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
    const auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
    const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
    const auto B = make_shared<op::Parameter>(element::f32, Shape{2 * hidden_size});

    const auto rnn_cell = make_shared<op::RNNCell>(X,
                                                   W,
                                                   R,
                                                   H_t,
                                                   hidden_size,
                                                   B,
                                                   vector<string>{"sigmoid"},
                                                   vector<float>{},
                                                   vector<float>{},
                                                   clip);
    auto function = make_shared<Function>(rnn_cell, ParameterVector{X, W, R, H_t, B});

    auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
    // X
    test_case.add_input<float>(
        {0.3432185f, 0.612268f, 0.20272376f, 0.9513413f, 0.30585995f, 0.7265472f});
    // W
    test_case.add_input<float>({0.41930267f,
                                0.7872176f,
                                0.89940447f,
                                0.23659843f,
                                0.24676207f,
                                0.17101714f,
                                0.3147149f,
                                0.6555601f,
                                0.4559603f});
    // R
    test_case.add_input<float>({0.8374871f,
                                0.86660194f,
                                0.82114047f,
                                0.71549815f,
                                0.18775631f,
                                0.3182116f,
                                0.25392973f,
                                0.38301638f,
                                0.85531586f});
    // Ht
    test_case.add_input<float>(
        {0.12444675f, 0.52055854f, 0.46489045f, 0.4983964f, 0.7730452f, 0.28439692f});
    // B
    test_case.add_input<float>(
        {0.45513555f, 0.96227735f, 0.24737759f, 0.57380486f, 0.67398053f, 0.18968852f});

    test_case.add_expected_output<float>(
        Shape{batch_size, hidden_size},
        {0.94126844f, 0.9036043f, 0.841243f, 0.9468489f, 0.934215f, 0.873708f});
    test_case.run();
}
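
// Same pattern as the previous note, but with f = sigmoid, i.e.
// out = 1.f / (1.f + std::exp(-acc)), presumably after the same bias addition
// and clipping. For element (0, 0) this gives 1 / (1 + exp(-2.77425f)) ~= 0.94127f,
// consistent with the expected 0.94126844f above.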

@@ -14697,3 +14697,81 @@ TEST(type_prop, lstm_cell_invalid_input)
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor P must have shape"));
    }
}

TEST(type_prop, rnn_cell)
{
    const size_t batch_size = 2;
    const size_t input_size = 3;
    const size_t hidden_size = 3;

    const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
    const auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
    const auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
    const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});

    const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
    EXPECT_EQ(rnn_cell->output(0).get_element_type(), element::f32);
    EXPECT_EQ(rnn_cell->output(0).get_shape(), (Shape{batch_size, hidden_size}));
}

TEST(type_prop, rnn_cell_invalid_input)
{
    const size_t batch_size = 2;
    const size_t input_size = 3;
    const size_t hidden_size = 3;

    auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
    auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
    auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});

    // Invalid W tensor shape.
    auto W = make_shared<op::Parameter>(element::f32, Shape{2 * hidden_size, input_size});
    try
    {
        const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
        FAIL() << "RNNCell node was created with invalid data.";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor W must have shape"));
    }

    // Invalid R tensor shape.
    W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
    R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, 1});
    try
    {
        const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
        FAIL() << "RNNCell node was created with invalid data.";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor R must have shape"));
    }

    // Invalid H_t tensor shape.
    R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
    H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
    try
    {
        const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
        FAIL() << "RNNCell node was created with invalid data.";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor H_t must have shape"));
    }

    // Invalid B tensor shape.
    H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
    auto B = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
    try
    {
        const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size, B);
        FAIL() << "RNNCell node was created with invalid data.";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor B must have shape"));
    }
}