Commit 77fb55ff authored by Adam Rogowiec's avatar Adam Rogowiec

Add UT for lstm model with large batch with clipping.

parent 22c4f3fb
ir_version: 4
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
output: ""
output: "Y_h"
op_type: "LSTM"
attribute {
name: "clip"
f: 4.0
type: FLOAT
}
attribute {
name: "direction"
s: "forward"
type: STRING
}
attribute {
name: "hidden_size"
i: 3
type: INT
}
attribute {
name: "input_forget"
i: 0
type: INT
}
}
name: "compute_graph"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 32
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 12
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 12
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 32
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 7
}
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <fstream> #include <fstream>
#include <iterator> #include <iterator>
#include <limits> #include <limits>
#include <numeric>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <vector> #include <vector>
...@@ -203,3 +204,66 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_hardsigmoid_activation) ...@@ -203,3 +204,66 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_hardsigmoid_activation)
test_case.set_tolerance(6); test_case.set_tolerance(6);
test_case.run(); test_case.run();
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_large_batch_with_clip)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/lstm_fwd_large_batch_with_clip.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
std::size_t seq_length = 2;
std::size_t batch_size = 32;
std::size_t input_size = 1;
std::size_t hidden_size = 3;
std::vector<float> in_X(seq_length * batch_size * input_size);
std::iota(std::begin(in_X), std::end(in_X), 1.f);
std::vector<float> in_R(4 * hidden_size * hidden_size, 0.1f);
// X
test_case.add_input<float>(in_X);
// W
test_case.add_input<float>(
{0.1f, 0.2f, 0.3f, 0.4f, 1.f, 2.f, 3.f, 4.f, 10.f, 11.f, 12.f, 13.f});
// R
test_case.add_input<float>(in_R);
// Y_h_data
test_case.add_expected_output<float>(
Shape{1, batch_size, hidden_size},
{0.88572926f, 0.89251395f, 0.89655037f,
0.89074291f, 0.90035688f, 0.90727429f,
0.89535827f, 0.90727429f, 0.91596163f,
0.89963124f, 0.91328279f, 0.9228067f,
0.90358195f, 0.91843507f, 0.92809163f,
0.90723279f, 0.9228067f, 0.93211437f,
0.91038955f, 0.92648469f, 0.93514718f,
0.91328279f, 0.92955856f, 0.93741938f,
0.91596163f, 0.93211437f, 0.9391149f,
0.91843507f, 0.93423112f, 0.94037686f,
0.92071318f, 0.9359791f, 0.94131462f,
0.9228067f, 0.93741938f, 0.94201073f,
0.92472679f, 0.9386042f, 0.94252713f,
0.92648469f, 0.9395777f, 0.94266769f,
0.92809163f, 0.94037686f, 0.94266769f,
0.92955856f, 0.94103248f, 0.94266769f,
0.93089609f, 0.94157007f, 0.94266769f,
0.93211437f, 0.94201073f, 0.94266769f,
0.93322302f, 0.94237184f, 0.94266769f,
0.93423112f, 0.94266769f, 0.94266769f,
0.93514718f, 0.94266769f, 0.94266769f,
0.9359791f, 0.94266769f, 0.94266769f,
0.93673424f, 0.94266769f, 0.94266769f,
0.93741938f, 0.94266769f, 0.94266769f,
0.93804079f, 0.94266769f, 0.94266769f,
0.9386042f, 0.94266769f, 0.94266769f,
0.9391149f, 0.94266769f, 0.94266769f,
0.9395777f, 0.94266769f, 0.94266769f,
0.93999702f, 0.94266769f, 0.94266769f,
0.94037686f, 0.94266769f, 0.94266769f,
0.94072091f, 0.94266769f, 0.94266769f,
0.94103248f, 0.94266769f, 0.94266769f});
test_case.run();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment