Commit 516167f7 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Add type_prop tests for ReverseSequence (#978)

* type tests for reverse_sequence

* remove commented out code
parent ecce61f1
...@@ -38,6 +38,16 @@ op::ReverseSequence::ReverseSequence(const std::shared_ptr<Node> arg, ...@@ -38,6 +38,16 @@ op::ReverseSequence::ReverseSequence(const std::shared_ptr<Node> arg,
throw ngraph_error("indices should be a 1-dimensional array"); throw ngraph_error("indices should be a 1-dimensional array");
} }
if (batch_axis >= arg->get_shape().size())
{
throw ngraph_error("batch axis index is out of bounds");
}
if (seq_axis >= arg->get_shape().size())
{
throw ngraph_error("sequence axis index is out of bounds");
}
if (arg->get_shape().at(batch_axis) != seq_indices->get_shape().at(0)) if (arg->get_shape().at(batch_axis) != seq_indices->get_shape().at(0))
{ {
throw ngraph_error("Sequence length size should be equal to batch axis dimension"); throw ngraph_error("Sequence length size should be equal to batch axis dimension");
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/op/batch_norm.hpp" #include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include <memory> #include <memory>
using namespace std; using namespace std;
...@@ -4024,6 +4025,92 @@ TEST(type_prop, reverse_3d_deduce_oob) ...@@ -4024,6 +4025,92 @@ TEST(type_prop, reverse_3d_deduce_oob)
} }
} }
TEST(type_prop, reverse_sequence_1_dim)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{4, 3, 2});
auto seq_lenghts = make_shared<op::Parameter>(element::f32, Shape{4, 4});
try
{
size_t batch_axis = 0;
size_t seq_axis = 1;
auto bc = make_shared<op::ReverseSequence>(data, seq_lenghts, batch_axis, seq_axis);
FAIL() << "ReverseSequence c-tor should throw for seq_lenghts whose rank isn't equal to 1";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("indices should be a 1-dimensional array"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, reverse_sequence_batch_index_oob)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{4, 3, 2});
auto seq_lenghts = make_shared<op::Parameter>(element::f32, Shape{3});
try
{
size_t batch_axis = 3;
size_t seq_axis = 1;
auto bc = make_shared<op::ReverseSequence>(data, seq_lenghts, batch_axis, seq_axis);
FAIL() << "ReverseSequence c-tor should throw for out-of-bounds batch axis index";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("batch axis index is out of bounds"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, reverse_sequence_sequence_index_oob)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{4, 3, 2});
auto seq_lenghts = make_shared<op::Parameter>(element::f32, Shape{3});
try
{
size_t batch_axis = 1;
size_t seq_axis = 3;
auto bc = make_shared<op::ReverseSequence>(data, seq_lenghts, batch_axis, seq_axis);
FAIL() << "ReverseSequence c-tor should throw for out-of-bounds sequence axis index";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("sequence axis index is out of bounds"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, reverse_sequence_seq_len_size_equal_to_batch_dim)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{4, 3, 2});
auto seq_lenghts = make_shared<op::Parameter>(element::f32, Shape{3});
try
{
size_t batch_axis = 0;
size_t seq_axis = 1;
auto bc = make_shared<op::ReverseSequence>(data, seq_lenghts, batch_axis, seq_axis);
FAIL() << "ReverseSequence c-tor should throw when sequence length size isn't equal to "
"batch dimension";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(),
std::string("Sequence length size should be equal to batch axis dimension"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, reduce_window_deduce_1d) TEST(type_prop, reduce_window_deduce_1d)
{ {
auto param_0 = make_shared<op::Parameter>(element::f32, Shape{16}); auto param_0 = make_shared<op::Parameter>(element::f32, Shape{16});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment