Commit f3603647 authored by Mateusz Bencer, committed by Scott Cyphers

[SPEC] Add negative axes support for ReverseSequence (#3926)

* Added negative axes support for ReverseSequence

* Addressed code review remarks

* Disabled the reverse_sequence test for PlaidML

* Fixed styles

* Fixed axes assignment

* Fixed normalized axes assignment
parent f6bddf08
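
For context, a minimal usage sketch of the behaviour this commit enables, modelled directly on the reverse_sequence_negative_axes_support test added below (shapes, values, and namespaces are taken from that test file; this is an illustration, not additional commit content):

// Rank-5 input: batch_axis = -3 resolves to axis 2 and seq_axis = -2 to axis 3.
auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4, 5});
auto seq_lengths = make_shared<op::Parameter>(element::f32, PartialShape{3});
auto rs = make_shared<op::ReverseSequence>(data, seq_lengths, /*batch_axis*/ -3, /*seq_axis*/ -2);
// rs->get_batch_axis() == 2           (normalized)
// rs->get_sequence_axis() == 3        (normalized)
// rs->get_origin_batch_axis() == -3   (as passed in, used by the serializer)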
@@ -19,6 +19,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
@@ -27,11 +28,13 @@ constexpr NodeTypeInfo op::ReverseSequence::type_info;
op::ReverseSequence::ReverseSequence(const Output<Node>& arg,
const Output<Node>& seq_indices,
size_t batch_axis,
size_t seq_axis)
int64_t batch_axis,
int64_t seq_axis)
: Op({arg, seq_indices})
, m_batch_axis(batch_axis)
, m_seq_axis(seq_axis)
, m_normalized_batch_axis{0}
, m_normalized_seq_axis{0}
{
constructor_validate_and_infer_types();
}
@@ -41,21 +44,30 @@ void op::ReverseSequence::validate_and_infer_types()
auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank();
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || m_batch_axis < size_t(input_rank),
"Batch axis index (",
m_batch_axis,
") is out of bounds (argument shape: ",
input_shape,
").");
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || m_seq_axis < size_t(input_rank),
"Sequence axis index (",
m_seq_axis,
") is out of bounds (argument shape: ",
input_shape,
").");
if (m_batch_axis < 0 || m_seq_axis < 0)
{
NODE_VALIDATION_CHECK(this,
input_rank.is_static(),
"In order to handle negative axes input_rank must be static (",
"batch_axis=",
m_batch_axis,
", seq_axis=",
m_seq_axis,
")");
}
else
{
m_normalized_batch_axis = m_batch_axis;
m_normalized_seq_axis = m_seq_axis;
}
if (input_rank.is_static())
{
m_normalized_batch_axis =
ngraph::normalize_axis(this, m_batch_axis, static_cast<int64_t>(input_rank));
m_normalized_seq_axis =
ngraph::normalize_axis(this, m_seq_axis, static_cast<int64_t>(input_rank));
}
auto indices_shape = get_input_partial_shape(1);
auto indices_rank = indices_shape.rank();
@@ -73,20 +85,21 @@ void op::ReverseSequence::validate_and_infer_types()
{
Dimension merged_sequence_length;
NODE_VALIDATION_CHECK(
this,
Dimension::merge(merged_sequence_length, input_shape[m_batch_axis], indices_shape[0]),
"Sequence length (",
indices_shape[0],
") is not equal to batch axis ",
"dimension (",
input_shape[m_batch_axis],
") (argument shape: ",
input_shape,
", sequence indices shape: ",
indices_shape,
").");
output_shape[m_batch_axis] = merged_sequence_length;
NODE_VALIDATION_CHECK(this,
Dimension::merge(merged_sequence_length,
input_shape[m_normalized_batch_axis],
indices_shape[0]),
"Sequence length (",
indices_shape[0],
") is not equal to batch axis ",
"dimension (",
input_shape[m_normalized_batch_axis],
") (argument shape: ",
input_shape,
", sequence indices shape: ",
indices_shape,
").");
output_shape[m_normalized_batch_axis] = merged_sequence_length;
}
set_output_type(0, get_input_element_type(0), output_shape);
......
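
For reference, the axis resolution performed in the hunk above follows the usual wrap-around convention. The sketch below shows the equivalent arithmetic in isolation; it is an illustration of the rule, not the ngraph::normalize_axis implementation, which reports failures through the node's validation machinery rather than throwing directly:

#include <cstdint>
#include <stdexcept>

// Illustrative only: map a possibly negative axis into the range [0, rank).
int64_t resolve_axis(int64_t axis, int64_t rank)
{
    int64_t resolved = axis < 0 ? axis + rank : axis;
    if (resolved < 0 || resolved >= rank)
    {
        throw std::out_of_range("axis out of the tensor rank range");
    }
    return resolved;
}
// resolve_axis(-3, 5) == 2, resolve_axis(-2, 5) == 3, resolve_axis(5, 5) throws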
@@ -35,25 +35,29 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
ReverseSequence(const Output<Node>& arg,
const Output<Node>& seq_lengths,
size_t batch_axis,
size_t seq_axis);
int64_t batch_axis,
int64_t seq_axis);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
size_t get_batch_axis() const { return m_batch_axis; }
void set_batch_axis(size_t batch_axis) { m_batch_axis = batch_axis; }
size_t get_sequence_axis() const { return m_seq_axis; }
void set_sequence_axis(size_t sequence_axis) { m_seq_axis = sequence_axis; }
size_t get_batch_axis() const { return m_normalized_batch_axis; }
int64_t get_origin_batch_axis() const { return m_batch_axis; }
void set_batch_axis(int64_t batch_axis) { m_batch_axis = batch_axis; }
size_t get_sequence_axis() const { return m_normalized_seq_axis; }
int64_t get_origin_sequence_axis() const { return m_seq_axis; }
void set_sequence_axis(int64_t sequence_axis) { m_seq_axis = sequence_axis; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
size_t m_batch_axis{0};
size_t m_seq_axis{0};
int64_t m_batch_axis;
int64_t m_seq_axis;
size_t m_normalized_batch_axis;
size_t m_normalized_seq_axis;
};
}
using v0::ReverseSequence;
......
@@ -15,6 +15,7 @@ min_3d_eliminate_zero_dim # Out-of-range for PlaidML
reverse_sequence_n2c3h4w2 # No plans to implement ReverseSequence
reverse_sequence_n4c3h2w2 # No plans to implement ReverseSequence
reverse_sequence_n4d2c3h2w2 # No plans to implement ReverseSequence
reverse_sequence_negative_axes # No plans to implement ReverseSequence
topk_1d_max_all # No plans to implement TopK
topk_1d_max_partial # No plans to implement TopK
topk_1d_max_one # No plans to implement TopK
......
@@ -2370,8 +2370,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::ReverseSequence:
{
auto batch_axis = node_js.at("batch_axis").get<size_t>();
auto sequence_axis = node_js.at("sequence_axis").get<size_t>();
auto batch_axis = node_js.at("batch_axis").get<int64_t>();
auto sequence_axis = node_js.at("sequence_axis").get<int64_t>();
node = make_shared<op::ReverseSequence>(args[0], args[1], batch_axis, sequence_axis);
break;
}
@@ -4033,8 +4033,8 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::ReverseSequence:
{
auto tmp = static_cast<const op::ReverseSequence*>(&n);
node["batch_axis"] = tmp->get_batch_axis();
node["sequence_axis"] = tmp->get_sequence_axis();
node["batch_axis"] = tmp->get_origin_batch_axis();
node["sequence_axis"] = tmp->get_origin_sequence_axis();
break;
}
case OP_TYPEID::RNNCell:
......
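
The serializer above intentionally writes the origin (possibly negative) axes rather than the normalized ones, so the sign survives a save/load round trip. A brief annotated restatement, with values following the negative-axes tests elsewhere in this commit (rank-5 input assumed):

auto tmp = static_cast<const op::ReverseSequence*>(&n);
node["batch_axis"] = tmp->get_origin_batch_axis();        // writes -3, not the normalized 2
node["sequence_axis"] = tmp->get_origin_sequence_axis();  // writes -2, not the normalized 3
// On load, deserialize_node reads the axes back as int64_t and the
// ReverseSequence constructor re-normalizes them against the input rank.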
@@ -154,3 +154,44 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_sequence_n4d2c3h2w2)
handle->call_with_validate({result}, {a, b});
EXPECT_EQ(read_vector<int>(result), expected);
}
NGRAPH_TEST(${BACKEND_NAME}, reverse_sequence_negative_axes)
{
Shape shape{2, 3, 4, 2};
Shape seq_len_shape{4};
auto A = make_shared<op::Parameter>(element::i32, shape);
auto B = make_shared<op::Parameter>(element::i32, seq_len_shape);
int64_t batch_axis = -2;
int64_t sequence_axis = -3;
auto rs = std::make_shared<op::ReverseSequence>(A, B, batch_axis, sequence_axis);
auto f = make_shared<Function>(rs, ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::i32, shape);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::i32, seq_len_shape);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::i32, shape);
std::vector<int> input{
0, 0, 3, 0, 6, 0, 9, 0, 1, 0, 4, 0, 7, 0, 10, 0, 2, 0, 5, 0, 8, 0, 11, 0,
12, 0, 15, 0, 18, 0, 21, 0, 13, 0, 16, 0, 19, 0, 22, 0, 14, 0, 17, 0, 20, 0, 23, 0,
};
std::vector<int> seq_lengths{1, 2, 1, 2};
copy_data(b, seq_lengths);
std::vector<int> expected{
0, 0, 4, 0, 6, 0, 10, 0, 1, 0, 3, 0, 7, 0, 9, 0, 2, 0, 5, 0, 8, 0, 11, 0,
12, 0, 16, 0, 18, 0, 22, 0, 13, 0, 15, 0, 19, 0, 21, 0, 14, 0, 17, 0, 20, 0, 23, 0};
copy_data(a, input);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b});
EXPECT_EQ(read_vector<int>(result), expected);
}
@@ -54,9 +54,9 @@ TEST(type_prop, reverse_sequence_batch_index_oob)
auto bc = make_shared<op::ReverseSequence>(data, seq_lenghts, batch_axis, seq_axis);
FAIL() << "ReverseSequence c-tor should throw for out-of-bounds batch axis index";
}
catch (const NodeValidationFailure& error)
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch axis index (3) is out of bounds"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter axis 3 out of the tensor rank"));
}
catch (...)
{
@@ -75,9 +75,9 @@ TEST(type_prop, reverse_sequence_sequence_index_oob)
auto bc = make_shared<op::ReverseSequence>(data, seq_lengths, batch_axis, seq_axis);
FAIL() << "ReverseSequence c-tor should throw for out-of-bounds sequence axis index";
}
catch (const NodeValidationFailure& error)
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Sequence axis index (3) is out of bounds"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter axis 3 out of the tensor rank"));
}
catch (...)
{
@@ -179,9 +179,9 @@ TEST(type_prop, reverse_sequence_partial_both_rank_static_dynamic_batch_axis_oob
auto rs = make_shared<op::ReverseSequence>(data, seq_lengths, batch_axis, seq_axis);
FAIL() << "Batch axis out of bounds not detected (rank-static dynamic shape)";
}
catch (const NodeValidationFailure& error)
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch axis index (4) is out of bounds"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter axis 4 out of the tensor rank"));
}
catch (...)
{
@@ -204,9 +204,9 @@ TEST(type_prop, reverse_sequence_partial_both_rank_static_dynamic_sequence_axis_
auto rs = make_shared<op::ReverseSequence>(data, seq_lengths, batch_axis, seq_axis);
FAIL() << "Sequence axis out of bounds not detected (rank-static dynamic shape)";
}
catch (const NodeValidationFailure& error)
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Sequence axis index (4) is out of bounds"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter axis 4 out of the tensor rank"));
}
catch (...)
{
@@ -289,3 +289,39 @@ TEST(
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, reverse_sequence_negative_axis_dynamic_input_rank)
{
auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto seq_lengths = make_shared<op::Parameter>(element::f32, PartialShape{1});
int64_t batch_axis = 1;
int64_t seq_axis = -2;
try
{
auto rs = make_shared<op::ReverseSequence>(data, seq_lengths, batch_axis, seq_axis);
FAIL() << "Dynamic input rank for negative axis not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("In order to handle negative axes input_rank must be "
"static (batch_axis=1, seq_axis=-2)"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, reverse_sequence_negative_axes_support)
{
auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4, 5});
auto seq_lengths = make_shared<op::Parameter>(element::f32, PartialShape{3});
int64_t batch_axis = -3;
int64_t seq_axis = -2;
auto rs = make_shared<op::ReverseSequence>(data, seq_lengths, batch_axis, seq_axis);
EXPECT_EQ(rs->get_batch_axis(), 2);
EXPECT_EQ(rs->get_sequence_axis(), 3);
}