Commit 48b14943 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Scott Cyphers

[ONNX] LSTM mixed length sequences. (#2606)

* Handle mixed length sequences.

* UT for mixed sequence length LSTM.

* Style apply.

* Fix typos.

* Add std:: prefix to preserve unified style within file.
parent 75fadde5
...@@ -35,11 +35,13 @@ ...@@ -35,11 +35,13 @@
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/maximum.hpp" #include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp" #include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/sigmoid.hpp" #include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/tanh.hpp" #include "ngraph/op/tanh.hpp"
...@@ -380,6 +382,7 @@ namespace ngraph ...@@ -380,6 +382,7 @@ namespace ngraph
in_x = reshape::squeeze(in_x); in_x = reshape::squeeze(in_x);
} }
std::int32_t time_step{1};
for (const auto& in_x : in_seqs) for (const auto& in_x : in_seqs)
{ {
// (.) - Denotes element-wise multiplication. // (.) - Denotes element-wise multiplication.
...@@ -423,21 +426,26 @@ namespace ngraph ...@@ -423,21 +426,26 @@ namespace ngraph
o = m_activation_f(clip(add(o, mul(p_o, C)), m_clip_threshold)); o = m_activation_f(clip(add(o, mul(p_o, C)), m_clip_threshold));
// ot (.) h(Ct) // ot (.) h(Ct)
auto H = mul(o, m_activation_h(C)); auto H = mul(o, m_activation_h(C));
h_list.push_back(H);
H_t = H; // Expand tensors with empty outermost dim, so we can later concatenate
C_t = C; // them.
// Mask hidden state tensor in order to handle mixed sequence lengths.
// This results in zeroing out values in batches with sequence shorter
// than current time_step.
h_list.push_back(
get_masked_node(reshape::expand_dims(H), time_step, 1));
// Reference implementation in ONNX Runtime doesn't mask values of Y_h
// and Y_c outputs, thus here we make sure that only appropriate batches
// (with respect to their sequence lengths) are updated. Those batches which
// have shorter sequences preserve the last value.
H_t = get_masked_node(H, time_step, 0, H_t);
C_t = get_masked_node(C, time_step, 0, C_t);
time_step++;
} }
// The tensor that concats all the intermediate output values of the hidden. // The tensor that concats all the intermediate output values of the hidden.
// It has shape [seq_length, batch_size, hidden_size] // It has shape [seq_length, batch_size, hidden_size]
NodeVector exp_h_list;
for (const auto& ht : h_list)
{
// Expand tensors with empty outermost dim, so we can later concatenate them.
exp_h_list.push_back(reshape::expand_dims(ht));
}
std::shared_ptr<ngraph::Node> Y{ std::shared_ptr<ngraph::Node> Y{
std::make_shared<ngraph::op::Concat>(exp_h_list, 0)}; std::make_shared<ngraph::op::Concat>(h_list, 0)};
// Get back the original order of the output data. // Get back the original order of the output data.
if (reverse) if (reverse)
...@@ -449,13 +457,68 @@ namespace ngraph ...@@ -449,13 +457,68 @@ namespace ngraph
// [seq_length, num_directions, batch_size, hidden_size] // [seq_length, num_directions, batch_size, hidden_size]
Y = reshape::expand_dims(Y, 1); Y = reshape::expand_dims(Y, 1);
// expand C_t so that it has expected shape: // expand H_t and C_t so that it has expected shape:
// [num_directions, batch_size, hidden_size] // [num_directions, batch_size, hidden_size]
auto Y_h = reshape::expand_dims(H_t);
auto Y_c = reshape::expand_dims(C_t); auto Y_c = reshape::expand_dims(C_t);
return {Y, exp_h_list.back(), Y_c}; return {Y, Y_h, Y_c};
} }
private: private:
///
/// \brief      Builds a node in which data belonging to batches whose sequence is
///             shorter than the current time step is replaced by a mask value.
///
/// \note       An element is masked when the current time step is greater than the
///             sequence length broadcast onto it. Masked elements are taken from
///             default_value when it is provided and are zero otherwise.
///
/// \param[in]  data           The input node.
/// \param[in]  time_step      The current time step, compared against sequence lengths.
/// \param[in]  batch_axis     The batch axis index of data tensor.
/// \param[in]  default_value  The node providing values for masked elements; when
///                            null, a zero constant shaped like data is used.
///
/// \return     The masked node.
///
std::shared_ptr<ngraph::Node> get_masked_node(
    const std::shared_ptr<ngraph::Node>& data,
    std::int32_t time_step,
    std::size_t batch_axis = 0,
    const std::shared_ptr<ngraph::Node>& default_value = {nullptr})
{
    const auto& data_shape = data->get_shape();
    const auto element_count = shape_size(data_shape);

    // Fall back to an all-zero constant when no replacement values were given.
    std::shared_ptr<ngraph::Node> replacement = default_value;
    if (replacement == nullptr)
    {
        replacement = ngraph::op::Constant::create(
            data->get_element_type(), data_shape, std::vector<float>(element_count, 0.f));
    }

    // Constant shaped like data, filled with the current time step value.
    const std::vector<std::int32_t> step_values(element_count, time_step);
    std::shared_ptr<ngraph::Node> step_node =
        ngraph::op::Constant::create(element::i32, data_shape, step_values);

    // Broadcast the sequence lengths against the time step node along batch_axis.
    // The second element of the result is used as the broadcast sequence lengths.
    const auto broadcast_pair =
        legacy_style_broadcast_for_binary_operation(step_node, m_seq_lengths, batch_axis);
    std::shared_ptr<ngraph::Node> seq_lengths_node = broadcast_pair.at(1);

    // Predicate: true where the time step is past the end of the sequence.
    std::shared_ptr<ngraph::Node> is_past_sequence_end =
        std::make_shared<ngraph::op::Greater>(step_node, seq_lengths_node);

    // Select(<condition>, <true_value>, <false_value>)
    return std::make_shared<ngraph::op::Select>(is_past_sequence_end, replacement, data);
}
std::shared_ptr<ngraph::Node> m_X; std::shared_ptr<ngraph::Node> m_X;
std::shared_ptr<ngraph::Node> m_W; std::shared_ptr<ngraph::Node> m_W;
std::shared_ptr<ngraph::Node> m_R; std::shared_ptr<ngraph::Node> m_R;
......
ir_version: 4
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
input: "B"
input: "sequence_lens"
input: ""
input: ""
input: ""
output: "Y"
output: "Y_h"
output: "Y_c"
op_type: "LSTM"
attribute {
name: "clip"
f: 9999.0
type: FLOAT
}
attribute {
name: "direction"
s: "forward"
type: STRING
}
attribute {
name: "hidden_size"
i: 3
type: INT
}
attribute {
name: "input_forget"
i: 0
type: INT
}
}
name: "compute_graph"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 12
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 12
}
dim {
dim_value: 3
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 24
}
}
}
}
}
input {
name: "sequence_lens"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "Y_c"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 7
}
...@@ -2227,6 +2227,101 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, import_non_existing_file) ...@@ -2227,6 +2227,101 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, import_non_existing_file)
} }
} }
// Imports a forward LSTM model that takes per-batch sequence lengths and checks the
// Y, Y_h and Y_c outputs for a batch holding sequences of different lengths.
// The model declares X with shape [2, 2, 1] and hidden_size 3; the sequence lengths
// fed in are {1, 2}.
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_mixed_seq)
{
    auto function = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/lstm_fwd_mixed_seq.prototxt"));

    // Unsigned on purpose: these values are used as sizes and compared with size().
    const std::size_t hidden_size{3};
    const std::size_t parameters_count{5};

    // X
    std::vector<float> in_x{1.f, 2.f, 10.f, 11.f};
    // W
    std::vector<float> in_w{0.1f, 0.2f, 0.3f, 0.4f, 1.f, 2.f, 3.f, 4.f, 10.f, 11.f, 12.f, 13.f};
    // R
    std::vector<float> in_r(4 * hidden_size * hidden_size, 0.1f);
    // B
    std::vector<float> in_b(8 * hidden_size, 0.0f);
    // sequence_lens
    std::vector<int> in_seq_lengths{1, 2};

    // Y: the entries for the first batch item at the second time step are expected to
    // be zero, because its sequence length is 1.
    std::vector<float> out_y_data{0.28828835f,
                                  0.36581863f,
                                  0.45679406f,
                                  0.34526032f,
                                  0.47220859f,
                                  0.55850911f,
                                  0.f,
                                  0.f,
                                  0.f,
                                  0.85882828f,
                                  0.90703777f,
                                  0.92382453f};
    // Y_h: for the first batch item this equals its Y values from the first time step.
    std::vector<float> out_y_h_data{
        0.28828835f, 0.36581863f, 0.45679406f, 0.85882828f, 0.90703777f, 0.92382453f};
    // Y_c
    std::vector<float> out_y_c_data{
        0.52497941f, 0.54983425f, 0.5744428f, 1.3249796f, 1.51063104f, 1.61451544f};

    Outputs expected_output;
    expected_output.emplace_back(out_y_data);
    expected_output.emplace_back(out_y_h_data);
    expected_output.emplace_back(out_y_c_data);

    auto backend = ngraph::runtime::Backend::create("${BACKEND_NAME}");
    auto parameters = function->get_parameters();
    // Fatal check: the parameters.at(N) calls below would throw on a size mismatch.
    ASSERT_EQ(parameters.size(), parameters_count);

    std::vector<std::shared_ptr<ngraph::runtime::Tensor>> arg_tensors;

    // Creates a backend tensor matching parameter p, fills it with v and queues it
    // as the next call argument.
    auto add_tensor = [&arg_tensors, &backend](const std::vector<float>& v,
                                               const std::shared_ptr<ngraph::op::Parameter>& p) {
        auto t = backend->create_tensor(p->get_element_type(), p->get_shape());
        copy_data(t, v);
        arg_tensors.push_back(t);
    };

    add_tensor(in_x, parameters.at(0));
    add_tensor(in_w, parameters.at(1));
    add_tensor(in_r, parameters.at(2));
    add_tensor(in_b, parameters.at(3));

    // Sequence lengths are integers, so they cannot go through the float-only helper.
    auto t_in_seq_lengths =
        backend->create_tensor(parameters.at(4)->get_element_type(), parameters.at(4)->get_shape());
    copy_data(t_in_seq_lengths, in_seq_lengths);
    arg_tensors.push_back(t_in_seq_lengths);

    auto results = function->get_results();
    std::vector<std::shared_ptr<ngraph::runtime::Tensor>> result_tensors(results.size());
    for (std::size_t i{0}; i < results.size(); ++i)
    {
        result_tensors.at(i) =
            backend->create_tensor(results.at(i)->get_element_type(), results.at(i)->get_shape());
    }

    auto handle = backend->compile(function);
    handle->call_with_validate(result_tensors, arg_tensors);

    Outputs outputs;
    for (const auto& rt : result_tensors)
    {
        outputs.push_back(read_vector<float>(rt));
    }

    // Fatal check: outputs.at(i) below would throw if fewer outputs were produced.
    ASSERT_EQ(outputs.size(), expected_output.size());
    for (std::size_t i{0}; i < expected_output.size(); ++i)
    {
        // We have to enlarge tolerance bits to 3 - it's only one bit more than default value.
        // The discrepancies may occur at most on 7th decimal position.
        EXPECT_TRUE(test::all_close_f(expected_output.at(i), outputs.at(i), 3));
    }
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_quantize_linear) NGRAPH_TEST(onnx_${BACKEND_NAME}, model_quantize_linear)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_model(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment