Commit 7ccb6cf1 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Scott Cyphers

[ONNX] Fix backward pass for bidirectional LSTM. (#3194)

* Fix used operator for reversing input sequences in LSTM.

* Fix backward pass for bidirectional LSTM.

* UT for LSTM with sequence_lens shorter than input sequence size.

* Skip LSTM UTs that use ReverseSequence, since that op is not yet supported on
PlaidML.
parent 13210138
......@@ -33,7 +33,7 @@
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
......@@ -279,7 +279,8 @@ namespace ngraph
if (reverse)
{
m_X = std::make_shared<ngraph::op::Reverse>(m_X, AxisSet{0});
m_X = std::make_shared<ngraph::op::ReverseSequence>(
m_X, m_seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
}
NodeVector in_seqs{};
......@@ -343,7 +344,8 @@ namespace ngraph
// Get back the original order of the output data.
if (reverse)
{
Y = std::make_shared<ngraph::op::Reverse>(Y, AxisSet{0});
Y = std::make_shared<ngraph::op::ReverseSequence>(
Y, m_seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
}
// Expand Y so that it has expected shape:
......@@ -485,7 +487,7 @@ namespace ngraph
attributes);
NodeVector fwd_results{lstm_fwd.run()};
NodeVector rev_results{lstm_fwd.run(true)};
NodeVector rev_results{lstm_reversed.run(true)};
// Stack together respective outputs from both forward and reverse passes.
std::shared_ptr<ngraph::Node> Y{std::make_shared<ngraph::op::Concat>(
......
......@@ -294,6 +294,10 @@ model_hardmax
quantized_convolution
quantized_conv_int32_output
# unsupported op: `ReverseSequence`
model_lstm_bdir_short_input_seq
model_lstm_mixed_seq_reverse
# node validation error: "Argument shapes are inconsistent."
model_lstm_fwd_with_clip
model_lstm_fwd_mixed_seq
......
ir_version: 5
graph {
node {
input: "X"
input: "W"
input: "R"
input: "B"
input: "sequence_lens"
input: "initial_h"
input: "initial_c"
input: "P"
output: "Y"
output: "Y_h"
output: ""
name: "node1"
op_type: "LSTM"
attribute {
name: "direction"
s: "bidirectional"
type: STRING
}
attribute {
name: "input_forget"
i: 0
type: INT
}
attribute {
name: "activations"
strings: "sigmoid"
strings: "tanh"
strings: "tanh"
strings: "sigmoid"
strings: "tanh"
strings: "tanh"
type: STRINGS
}
attribute {
name: "hidden_size"
i: 2
type: INT
}
attribute {
name: "clip"
f: 9999
type: FLOAT
}
doc_string: "LSTM"
domain: ""
}
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 8
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 8
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 16
}
}
}
}
}
input {
name: "sequence_lens"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
}
}
}
}
input {
name: "initial_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "initial_c"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "P"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 6
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
domain: ""
version: 7
}
ir_version: 5
graph {
node {
input: "X"
input: "W"
input: "R"
input: ""
input: "sequence_lens"
input: ""
input: ""
input: ""
output: "Y"
output: "Y_h"
output: "Y_c"
name: "node1"
op_type: "LSTM"
attribute {
name: "direction"
s: "reverse"
type: STRING
}
attribute {
name: "input_forget"
i: 0
type: INT
}
attribute {
name: "activations"
strings: "sigmoid"
strings: "tanh"
strings: "tanh"
type: STRINGS
}
attribute {
name: "hidden_size"
i: 3
type: INT
}
attribute {
name: "clip"
f: 9999
type: FLOAT
}
doc_string: "LSTM"
domain: ""
}
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 12
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 12
}
dim {
dim_value: 3
}
}
}
}
}
input {
name: "sequence_lens"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "Y_c"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
domain: ""
version: 7
}
......@@ -249,3 +249,110 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_large_batch_no_clip)
test_case.run();
}
// Regression test for the bidirectional-LSTM fix: sequence_lens ({1}) is
// shorter than the input sequence length (2), so the tail time steps of Y
// must come out zeroed and the reverse pass must reverse only the valid
// prefix of the sequence (via ReverseSequence rather than a plain Reverse).
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_bdir_short_input_seq)
{
    auto function = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/lstm_bdir_short_input_seq.prototxt"));
    auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
    // Model (see lstm_bdir_short_input_seq.prototxt): direction=bidirectional,
    // hidden_size=2, seq_length=2, batch_size=1, input_size=2.
    // Inputs are fed positionally, so the add_input order below must match the
    // ONNX LSTM input order: X, W, R, B, sequence_lens, initial_h, initial_c, P.
    // X: shape [seq_length=2, batch=1, input_size=2]
    test_case.add_input<float>({-0.455351f, -0.276391f, -0.185934f, -0.269585f});
    // W: shape [num_directions=2, 4*hidden_size=8, input_size=2];
    // both directions use identical weights (first 16 values repeated).
    test_case.add_input<float>(
        {-0.494659f, 0.0453352f, -0.487793f, 0.417264f,  -0.0175329f, 0.489074f,  -0.446013f,
         0.414029f,  -0.0091708f, -0.255364f, -0.106952f, -0.266717f, -0.0888852f, -0.428709f,
         -0.283349f, 0.208792f,  -0.494659f, 0.0453352f, -0.487793f, 0.417264f,  -0.0175329f,
         0.489074f,  -0.446013f, 0.414029f,  -0.0091708f, -0.255364f, -0.106952f, -0.266717f,
         -0.0888852f, -0.428709f, -0.283349f, 0.208792f});
    // R: shape [num_directions=2, 4*hidden_size=8, hidden_size=2];
    // likewise duplicated across directions.
    test_case.add_input<float>(
        {0.146626f,  -0.0620289f, -0.0815302f, 0.100482f,  -0.219535f, -0.306635f, -0.28515f,
         -0.314112f, -0.228172f,  0.405972f,   0.31576f,   0.281487f,  -0.394864f, 0.42111f,
         -0.386624f, -0.390225f,  0.146626f,   -0.0620289f, -0.0815302f, 0.100482f, -0.219535f,
         -0.306635f, -0.28515f,   -0.314112f,  -0.228172f, 0.405972f,  0.31576f,   0.281487f,
         -0.394864f, 0.42111f,    -0.386624f,  -0.390225f});
    // B: shape [num_directions=2, 8*hidden_size=16]; recurrence biases (Rb) are zero.
    test_case.add_input<float>(
        {0.381619f, 0.0323954f, -0.14449f, 0.420804f, -0.258721f, 0.45056f, -0.250755f, 0.0967895f,
         0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
         0.381619f, 0.0323954f, -0.14449f, 0.420804f, -0.258721f, 0.45056f, -0.250755f, 0.0967895f,
         0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
    // sequence_lens: shape [batch=1]; value 1 < seq_length 2 — the key input of this test.
    test_case.add_input<int>({1});
    // initial_h: shape [num_directions=2, batch=1, hidden_size=2]
    test_case.add_input<float>({0.0f, 0.0f, -0.0306872f, 0.028035f});
    // initial_c: shape [num_directions=2, batch=1, hidden_size=2]
    test_case.add_input<float>({0.0f, 0.0f, -0.07243599f, 0.0467052f});
    // P (peephole weights): shape [num_directions=2, 3*hidden_size=6]
    test_case.add_input<float>({0.2345f,
                                0.5235f,
                                0.4378f,
                                0.3475f,
                                0.8927f,
                                0.3456f,
                                0.2345f,
                                0.5235f,
                                0.4378f,
                                0.3475f,
                                0.8927f,
                                0.3456f});
    // Y: shape [seq_length=2, num_directions=2, batch=1, hidden_size=2];
    // step 2 (second time step) is all zeros because sequence_lens == 1.
    test_case.add_expected_output<float>(
        Shape{2, 2, 1, 2},
        {-0.0251062f, 0.0561262f, -0.0318928f, 0.0762679f, 0.0f, 0.0f, 0.0f, 0.0f});
    // Y_h: shape [num_directions=2, batch=1, hidden_size=2] — last valid hidden state
    // per direction.
    test_case.add_expected_output<float>(Shape{2, 1, 2},
                                         {-0.0251062f, 0.0561262f, -0.0318928f, 0.0762679f});
    // Loosened tolerance: results accumulate rounding across gates/directions.
    test_case.set_tolerance(DEFAULT_FLOAT_TOLERANCE_BITS + 3);
    test_case.run();
}
// Regression test for the reverse-direction LSTM fix: with per-batch
// sequence_lens {1, 2}, the reverse pass must flip each batch entry over its
// own valid length (ReverseSequence semantics), not over the full padded
// sequence; the padded tail of Y for the shorter entry comes out zeroed.
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_mixed_seq_reverse)
{
    auto function = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/lstm_mixed_seq_reverse.prototxt"));
    auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
    size_t hidden_size = 3;
    // Model (see lstm_mixed_seq_reverse.prototxt): direction=reverse,
    // hidden_size=3, seq_length=2, batch_size=2, input_size=1.
    // Positional input order: X, W, R, sequence_lens (B, initial_h, initial_c,
    // and P are omitted in the model).
    // X: shape [seq_length=2, batch=2, input_size=1]
    test_case.add_input<float>({1.f, 2.f, 10.f, 11.f});
    // W: shape [num_directions=1, 4*hidden_size=12, input_size=1]
    test_case.add_input<float>(
        {0.1f, 0.2f, 0.3f, 0.4f, 1.f, 2.f, 3.f, 4.f, 10.f, 11.f, 12.f, 13.f});
    // R: shape [num_directions=1, 4*hidden_size=12, hidden_size=3], all 0.1
    test_case.add_input<float>(std::vector<float>(4 * hidden_size * hidden_size, 0.1f));
    // sequence_lens: shape [batch=2]; mixed lengths — batch 0 has 1 valid step,
    // batch 1 has 2.
    test_case.add_input<int>({1, 2});
    // Y: shape [seq_length=2, num_directions=1, batch=2, hidden_size=3];
    // the second time step of batch 0 is zeroed (past its sequence length).
    test_case.add_expected_output<float>(Shape{2, 1, 2, 3},
                                         {0.28828844f,
                                          0.36581877f,
                                          0.45679423f,
                                          0.64046413f,
                                          0.82303363f,
                                          0.91610711f,
                                          0.f,
                                          0.f,
                                          0.f,
                                          0.62759886f,
                                          0.71640738f,
                                          0.74624585f});
    // Y_h: shape [num_directions=1, batch=2, hidden_size=3] — last valid hidden
    // state per batch entry.
    test_case.add_expected_output<float>(
        Shape{1, 2, 3},
        {0.28828844f, 0.36581877f, 0.45679423f, 0.64046413f, 0.82303363f, 0.91610711f});
    // Y_c: shape [num_directions=1, batch=2, hidden_size=3] — last valid cell state.
    test_case.add_expected_output<float>(
        Shape{1, 2, 3},
        {0.52497941f, 0.54983425f, 0.5744428f, 1.34960834f, 1.54772296f, 1.65633056f});
    // Slightly loosened tolerance for accumulated floating-point rounding.
    test_case.set_tolerance(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
    test_case.run();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment