Commit f3603647 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Scott Cyphers

[SPEC] Add negative axes support for ReverseSequence (#3926)

* Added negative axes support for ReverseRequence

* code review remarks introduced

* Disable reverse sequence for PlaidMl tests

* Fixed styles

* Fixed axes assignment

* Fixed normalized axes assignment
parent f6bddf08
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/reverse_sequence.hpp" #include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/validation_util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -27,11 +28,13 @@ constexpr NodeTypeInfo op::ReverseSequence::type_info; ...@@ -27,11 +28,13 @@ constexpr NodeTypeInfo op::ReverseSequence::type_info;
op::ReverseSequence::ReverseSequence(const Output<Node>& arg, op::ReverseSequence::ReverseSequence(const Output<Node>& arg,
const Output<Node>& seq_indices, const Output<Node>& seq_indices,
size_t batch_axis, int64_t batch_axis,
size_t seq_axis) int64_t seq_axis)
: Op({arg, seq_indices}) : Op({arg, seq_indices})
, m_batch_axis(batch_axis) , m_batch_axis(batch_axis)
, m_seq_axis(seq_axis) , m_seq_axis(seq_axis)
, m_normalized_batch_axis{0}
, m_normalized_seq_axis{0}
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -41,21 +44,30 @@ void op::ReverseSequence::validate_and_infer_types() ...@@ -41,21 +44,30 @@ void op::ReverseSequence::validate_and_infer_types()
auto input_shape = get_input_partial_shape(0); auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank(); auto input_rank = input_shape.rank();
NODE_VALIDATION_CHECK(this, if (m_batch_axis < 0 || m_seq_axis < 0)
input_rank.is_dynamic() || m_batch_axis < size_t(input_rank), {
"Batch axis index (", NODE_VALIDATION_CHECK(this,
m_batch_axis, input_rank.is_static(),
") is out of bounds (argument shape: ", "In order to handle negative axes input_rank must be static (",
input_shape, "batch_axis=",
")."); m_batch_axis,
", seq_axis=",
NODE_VALIDATION_CHECK(this, m_seq_axis,
input_rank.is_dynamic() || m_seq_axis < size_t(input_rank), ")");
"Sequence axis index (", }
m_seq_axis, else
") is out of bounds (argument shape: ", {
input_shape, m_normalized_batch_axis = m_batch_axis;
")."); m_normalized_seq_axis = m_seq_axis;
}
if (input_rank.is_static())
{
m_normalized_batch_axis =
ngraph::normalize_axis(this, m_batch_axis, static_cast<int64_t>(input_rank));
m_normalized_seq_axis =
ngraph::normalize_axis(this, m_seq_axis, static_cast<int64_t>(input_rank));
}
auto indices_shape = get_input_partial_shape(1); auto indices_shape = get_input_partial_shape(1);
auto indices_rank = indices_shape.rank(); auto indices_rank = indices_shape.rank();
...@@ -73,20 +85,21 @@ void op::ReverseSequence::validate_and_infer_types() ...@@ -73,20 +85,21 @@ void op::ReverseSequence::validate_and_infer_types()
{ {
Dimension merged_sequence_length; Dimension merged_sequence_length;
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(this,
this, Dimension::merge(merged_sequence_length,
Dimension::merge(merged_sequence_length, input_shape[m_batch_axis], indices_shape[0]), input_shape[m_normalized_batch_axis],
"Sequence length (", indices_shape[0]),
indices_shape[0], "Sequence length (",
") is not equal to batch axis ", indices_shape[0],
"dimension (", ") is not equal to batch axis ",
input_shape[m_batch_axis], "dimension (",
") (argument shape: ", input_shape[m_normalized_batch_axis],
input_shape, ") (argument shape: ",
", sequence indices shape: ", input_shape,
indices_shape, ", sequence indices shape: ",
")."); indices_shape,
output_shape[m_batch_axis] = merged_sequence_length; ").");
output_shape[m_normalized_batch_axis] = merged_sequence_length;
} }
set_output_type(0, get_input_element_type(0), output_shape); set_output_type(0, get_input_element_type(0), output_shape);
......
...@@ -35,25 +35,29 @@ namespace ngraph ...@@ -35,25 +35,29 @@ namespace ngraph
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
ReverseSequence(const Output<Node>& arg, ReverseSequence(const Output<Node>& arg,
const Output<Node>& seq_lengths, const Output<Node>& seq_lengths,
size_t batch_axis, int64_t batch_axis,
size_t seq_axis); int64_t seq_axis);
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
size_t get_batch_axis() const { return m_batch_axis; } size_t get_batch_axis() const { return m_normalized_batch_axis; }
void set_batch_axis(size_t batch_axis) { m_batch_axis = batch_axis; } int64_t get_origin_batch_axis() const { return m_batch_axis; }
size_t get_sequence_axis() const { return m_seq_axis; } void set_batch_axis(int64_t batch_axis) { m_batch_axis = batch_axis; }
void set_sequence_axis(size_t sequence_axis) { m_seq_axis = sequence_axis; } size_t get_sequence_axis() const { return m_normalized_seq_axis; }
int64_t get_origin_sequence_axis() const { return m_seq_axis; }
void set_sequence_axis(int64_t sequence_axis) { m_seq_axis = sequence_axis; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
private: private:
size_t m_batch_axis{0}; int64_t m_batch_axis;
size_t m_seq_axis{0}; int64_t m_seq_axis;
size_t m_normalized_batch_axis;
size_t m_normalized_seq_axis;
}; };
} }
using v0::ReverseSequence; using v0::ReverseSequence;
......
...@@ -15,6 +15,7 @@ min_3d_eliminate_zero_dim # Out-of-range for PlaidML ...@@ -15,6 +15,7 @@ min_3d_eliminate_zero_dim # Out-of-range for PlaidML
reverse_sequence_n2c3h4w2 # No plans to implement ReverseSequence reverse_sequence_n2c3h4w2 # No plans to implement ReverseSequence
reverse_sequence_n4c3h2w2 # No plans to implement ReverseSequence reverse_sequence_n4c3h2w2 # No plans to implement ReverseSequence
reverse_sequence_n4d2c3h2w2 # No plans to implement ReverseSequence reverse_sequence_n4d2c3h2w2 # No plans to implement ReverseSequence
reverse_sequence_negative_axes # No plans to implement ReverseSequence
topk_1d_max_all # No plans to implement TopK topk_1d_max_all # No plans to implement TopK
topk_1d_max_partial # No plans to implement TopK topk_1d_max_partial # No plans to implement TopK
topk_1d_max_one # No plans to implement TopK topk_1d_max_one # No plans to implement TopK
......
...@@ -2370,8 +2370,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2370,8 +2370,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
} }
case OP_TYPEID::ReverseSequence: case OP_TYPEID::ReverseSequence:
{ {
auto batch_axis = node_js.at("batch_axis").get<size_t>(); auto batch_axis = node_js.at("batch_axis").get<int64_t>();
auto sequence_axis = node_js.at("sequence_axis").get<size_t>(); auto sequence_axis = node_js.at("sequence_axis").get<int64_t>();
node = make_shared<op::ReverseSequence>(args[0], args[1], batch_axis, sequence_axis); node = make_shared<op::ReverseSequence>(args[0], args[1], batch_axis, sequence_axis);
break; break;
} }
...@@ -4033,8 +4033,8 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -4033,8 +4033,8 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::ReverseSequence: case OP_TYPEID::ReverseSequence:
{ {
auto tmp = static_cast<const op::ReverseSequence*>(&n); auto tmp = static_cast<const op::ReverseSequence*>(&n);
node["batch_axis"] = tmp->get_batch_axis(); node["batch_axis"] = tmp->get_origin_batch_axis();
node["sequence_axis"] = tmp->get_sequence_axis(); node["sequence_axis"] = tmp->get_origin_sequence_axis();
break; break;
} }
case OP_TYPEID::RNNCell: case OP_TYPEID::RNNCell:
......
...@@ -154,3 +154,44 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_sequence_n4d2c3h2w2) ...@@ -154,3 +154,44 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_sequence_n4d2c3h2w2)
handle->call_with_validate({result}, {a, b}); handle->call_with_validate({result}, {a, b});
EXPECT_EQ(read_vector<int>(result), expected); EXPECT_EQ(read_vector<int>(result), expected);
} }
NGRAPH_TEST(${BACKEND_NAME}, reverse_sequence_negative_axes)
{
Shape shape{2, 3, 4, 2};
Shape seq_len_shape{4};
auto A = make_shared<op::Parameter>(element::i32, shape);
auto B = make_shared<op::Parameter>(element::i32, seq_len_shape);
int64_t batch_axis = -2;
int64_t sequence_axis = -3;
auto rs = std::make_shared<op::ReverseSequence>(A, B, batch_axis, sequence_axis);
auto f = make_shared<Function>(rs, ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::i32, shape);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::i32, seq_len_shape);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::i32, shape);
std::vector<int> input{
0, 0, 3, 0, 6, 0, 9, 0, 1, 0, 4, 0, 7, 0, 10, 0, 2, 0, 5, 0, 8, 0, 11, 0,
12, 0, 15, 0, 18, 0, 21, 0, 13, 0, 16, 0, 19, 0, 22, 0, 14, 0, 17, 0, 20, 0, 23, 0,
};
std::vector<int> seq_lenghts{1, 2, 1, 2};
copy_data(b, seq_lenghts);
std::vector<int> expected{
0, 0, 4, 0, 6, 0, 10, 0, 1, 0, 3, 0, 7, 0, 9, 0, 2, 0, 5, 0, 8, 0, 11, 0,
12, 0, 16, 0, 18, 0, 22, 0, 13, 0, 15, 0, 19, 0, 21, 0, 14, 0, 17, 0, 20, 0, 23, 0};
copy_data(a, input);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b});
EXPECT_EQ(read_vector<int>(result), expected);
}
...@@ -54,9 +54,9 @@ TEST(type_prop, reverse_sequence_batch_index_oob) ...@@ -54,9 +54,9 @@ TEST(type_prop, reverse_sequence_batch_index_oob)
auto bc = make_shared<op::ReverseSequence>(data, seq_lenghts, batch_axis, seq_axis); auto bc = make_shared<op::ReverseSequence>(data, seq_lenghts, batch_axis, seq_axis);
FAIL() << "ReverseSequence c-tor should throw for out-of-bounds batch axis index"; FAIL() << "ReverseSequence c-tor should throw for out-of-bounds batch axis index";
} }
catch (const NodeValidationFailure& error) catch (const ngraph_error& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch axis index (3) is out of bounds")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter axis 3 out of the tensor rank"));
} }
catch (...) catch (...)
{ {
...@@ -75,9 +75,9 @@ TEST(type_prop, reverse_sequence_sequence_index_oob) ...@@ -75,9 +75,9 @@ TEST(type_prop, reverse_sequence_sequence_index_oob)
auto bc = make_shared<op::ReverseSequence>(data, seq_lengths, batch_axis, seq_axis); auto bc = make_shared<op::ReverseSequence>(data, seq_lengths, batch_axis, seq_axis);
FAIL() << "ReverseSequence c-tor should throw for out-of-bounds sequence axis index"; FAIL() << "ReverseSequence c-tor should throw for out-of-bounds sequence axis index";
} }
catch (const NodeValidationFailure& error) catch (const ngraph_error& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Sequence axis index (3) is out of bounds")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter axis 3 out of the tensor rank"));
} }
catch (...) catch (...)
{ {
...@@ -179,9 +179,9 @@ TEST(type_prop, reverse_sequence_partial_both_rank_static_dynamic_batch_axis_oob ...@@ -179,9 +179,9 @@ TEST(type_prop, reverse_sequence_partial_both_rank_static_dynamic_batch_axis_oob
auto rs = make_shared<op::ReverseSequence>(data, seq_lengths, batch_axis, seq_axis); auto rs = make_shared<op::ReverseSequence>(data, seq_lengths, batch_axis, seq_axis);
FAIL() << "Batch axis out of bounds not detected (rank-static dynamic shape)"; FAIL() << "Batch axis out of bounds not detected (rank-static dynamic shape)";
} }
catch (const NodeValidationFailure& error) catch (const ngraph_error& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch axis index (4) is out of bounds")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter axis 4 out of the tensor rank"));
} }
catch (...) catch (...)
{ {
...@@ -204,9 +204,9 @@ TEST(type_prop, reverse_sequence_partial_both_rank_static_dynamic_sequence_axis_ ...@@ -204,9 +204,9 @@ TEST(type_prop, reverse_sequence_partial_both_rank_static_dynamic_sequence_axis_
auto rs = make_shared<op::ReverseSequence>(data, seq_lengths, batch_axis, seq_axis); auto rs = make_shared<op::ReverseSequence>(data, seq_lengths, batch_axis, seq_axis);
FAIL() << "Sequence axis out of bounds not detected (rank-static dynamic shape)"; FAIL() << "Sequence axis out of bounds not detected (rank-static dynamic shape)";
} }
catch (const NodeValidationFailure& error) catch (const ngraph_error& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Sequence axis index (4) is out of bounds")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter axis 4 out of the tensor rank"));
} }
catch (...) catch (...)
{ {
...@@ -289,3 +289,39 @@ TEST( ...@@ -289,3 +289,39 @@ TEST(
FAIL() << "Deduced type check failed for unexpected reason"; FAIL() << "Deduced type check failed for unexpected reason";
} }
} }
TEST(type_prop, reverse_sequence_negative_axis_dynamic_input_rank)
{
auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto seq_lengths = make_shared<op::Parameter>(element::f32, PartialShape{1});
int64_t batch_axis = 1;
int64_t seq_axis = -2;
try
{
auto rs = make_shared<op::ReverseSequence>(data, seq_lengths, batch_axis, seq_axis);
FAIL() << "Dynamic input rank for negative axis not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("In order to handle negative axes input_rank must be "
"static (batch_axis=1, seq_axis=-2)"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, reverse_sequence_negative_axes_support)
{
auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4, 5});
auto seq_lengths = make_shared<op::Parameter>(element::f32, PartialShape{3});
int64_t batch_axis = -3;
int64_t seq_axis = -2;
auto rs = make_shared<op::ReverseSequence>(data, seq_lengths, batch_axis, seq_axis);
EXPECT_EQ(rs->get_batch_axis(), 2);
EXPECT_EQ(rs->get_sequence_axis(), 3);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment