Commit ecbe0042 authored by Michał Karzyński, committed by Scott Cyphers

[ONNX] Refactor LSTM tests to use NgraphTestCase (#2827)

* [ONNX] Refactor LSTM tests to use NgraphTestCase

* Enable passing instance values to comparator

* Review comments
parent a6df47e3
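
For orientation before the diff: this refactor replaces the hand-rolled backend/tensor plumbing in the ONNX LSTM tests with the NgraphTestCase helper. Below is a minimal sketch of the test pattern the commit converges on, using only calls that appear in this diff (import_onnx_model, NgraphTestCase, add_input, add_expected_output, set_tolerance, run); the test name, model file, shapes and values are illustrative placeholders, not assets from the repository.

NGRAPH_TEST(onnx_${BACKEND_NAME}, model_example)
{
    // Import the ONNX model under test ("example.prototxt" is a placeholder file name).
    auto function = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/example.prototxt"));

    auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");

    // Inputs are supplied in the order of the function's parameters.
    test_case.add_input<float>({1.f, 2.f, 3.f, 4.f});

    // Expected outputs carry an explicit shape; comparison is dispatched on element type.
    // The values here are illustrative only.
    test_case.add_expected_output<float>(Shape{1, 4}, {2.f, 4.f, 6.f, 8.f});

    // Optionally widen the bit tolerance used by the floating-point comparator.
    test_case.set_tolerance(3);

    test_case.run();
}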
@@ -332,6 +332,7 @@ namespace ngraph
REGISTER_OPERATOR("Xor", 1, logical_xor);
}
#undef REGISTER_OPERATOR
} // namespace onnx_import
} // namespace ngraph
@@ -30,6 +30,7 @@
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
@@ -37,92 +38,73 @@ using namespace ngraph;
static std::string s_manifest = "${MANIFEST}";
using Inputs = std::vector<std::vector<float>>;
using Outputs = std::vector<std::vector<float>>;
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_with_clip)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/lstm_fwd_with_clip.prototxt"));
Inputs inputs{};
// X
inputs.emplace_back(std::vector<float>{-0.455351, -0.276391, -0.185934, -0.269585});
// W
inputs.emplace_back(std::vector<float>{-0.494659f,
0.0453352f,
-0.487793f,
0.417264f,
-0.0175329f,
0.489074f,
-0.446013f,
0.414029f,
-0.0091708f,
-0.255364f,
-0.106952f,
-0.266717f,
-0.0888852f,
-0.428709f,
-0.283349f,
0.208792f});
// R
inputs.emplace_back(std::vector<float>{0.146626f,
-0.0620289f,
-0.0815302f,
0.100482f,
-0.219535f,
-0.306635f,
-0.28515f,
-0.314112f,
-0.228172f,
0.405972f,
0.31576f,
0.281487f,
-0.394864f,
0.42111f,
-0.386624f,
-0.390225f});
// B
inputs.emplace_back(std::vector<float>{0.381619f,
0.0323954f,
-0.14449f,
0.420804f,
-0.258721f,
0.45056f,
-0.250755f,
0.0967895f,
0.0f,
0.0f,
0.0f,
0.0f,
0.0f,
0.0f,
0.0f,
0.0f});
// P
inputs.emplace_back(std::vector<float>{0.2345f, 0.5235f, 0.4378f, 0.3475f, 0.8927f, 0.3456f});
Outputs expected_output{};
// Y_data
expected_output.emplace_back(
std::vector<float>{-0.02280854f, 0.02744377f, -0.03516197f, 0.03875681f});
// Y_h_data
expected_output.emplace_back(std::vector<float>{-0.03516197f, 0.03875681f});
// Y_c_data
expected_output.emplace_back(std::vector<float>{-0.07415761f, 0.07395997f});
Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(outputs.size() == expected_output.size());
for (std::size_t i{0}; i < expected_output.size(); ++i)
{
// We have to enlarge tolerance bits to 3 - it's only one bit more than default value.
// The discrepancies may occur at most on 7th decimal position.
EXPECT_TRUE(test::all_close_f(expected_output.at(i), outputs.at(i), 3));
}
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>({-0.455351f, -0.276391f, -0.185934f, -0.269585f}); // X
test_case.add_input<float>({-0.494659f, // W
0.0453352f,
-0.487793f,
0.417264f,
-0.0175329f,
0.489074f,
-0.446013f,
0.414029f,
-0.0091708f,
-0.255364f,
-0.106952f,
-0.266717f,
-0.0888852f,
-0.428709f,
-0.283349f,
0.208792f});
test_case.add_input<float>({0.146626f, // R
-0.0620289f,
-0.0815302f,
0.100482f,
-0.219535f,
-0.306635f,
-0.28515f,
-0.314112f,
-0.228172f,
0.405972f,
0.31576f,
0.281487f,
-0.394864f,
0.42111f,
-0.386624f,
-0.390225f});
test_case.add_input<float>({0.381619f, // B
0.0323954f,
-0.14449f,
0.420804f,
-0.258721f,
0.45056f,
-0.250755f,
0.0967895f,
0.0f,
0.0f,
0.0f,
0.0f,
0.0f,
0.0f,
0.0f,
0.0f});
test_case.add_input<float>({0.2345f, 0.5235f, 0.4378f, 0.3475f, 0.8927f, 0.3456f}); // P
test_case.add_expected_output<float>(
Shape{2, 1, 1, 2}, {-0.02280854f, 0.02744377f, -0.03516197f, 0.03875681f}); // Y_data
test_case.add_expected_output<float>(Shape{1, 1, 2}, {-0.03516197f, 0.03875681f}); // Y_h_data
test_case.add_expected_output<float>(Shape{1, 1, 2}, {-0.07415761f, 0.07395997f}); // Y_c_data
// We have to increase the tolerance to 3 bits - only one bit more than the default value.
// Discrepancies may occur at most at the 7th decimal place.
test_case.set_tolerance(3);
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_mixed_seq)
@@ -130,92 +112,37 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_mixed_seq)
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/lstm_fwd_mixed_seq.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
int hidden_size{3};
int parameters_cout{5};
// X
std::vector<float> in_x{1.f, 2.f, 10.f, 11.f};
// W
std::vector<float> in_w{0.1f, 0.2f, 0.3f, 0.4f, 1.f, 2.f, 3.f, 4.f, 10.f, 11.f, 12.f, 13.f};
// R
std::vector<float> in_r(4 * hidden_size * hidden_size, 0.1f);
// B
std::vector<float> in_b(8 * hidden_size, 0.0f);
std::vector<int> in_seq_lengths{1, 2};
std::vector<float> out_y_data{0.28828835f,
0.36581863f,
0.45679406f,
0.34526032f,
0.47220859f,
0.55850911f,
0.f,
0.f,
0.f,
0.85882828f,
0.90703777f,
0.92382453f};
std::vector<float> out_y_h_data{
0.28828835f, 0.36581863f, 0.45679406f, 0.85882828f, 0.90703777f, 0.92382453f};
std::vector<float> out_y_c_data{
0.52497941f, 0.54983425f, 0.5744428f, 1.3249796f, 1.51063104f, 1.61451544f};
Outputs expected_output;
expected_output.emplace_back(out_y_data);
expected_output.emplace_back(out_y_h_data);
expected_output.emplace_back(out_y_c_data);
auto backend = ngraph::runtime::Backend::create("${BACKEND_NAME}");
auto parameters = function->get_parameters();
EXPECT_TRUE(parameters.size() == parameters_cout);
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> arg_tensors;
auto add_tensor = [&arg_tensors, &backend](const std::vector<float>& v,
const std::shared_ptr<ngraph::op::Parameter>& p) {
auto t = backend->create_tensor(p->get_element_type(), p->get_shape());
copy_data(t, v);
arg_tensors.push_back(t);
};
add_tensor(in_x, parameters.at(0));
add_tensor(in_w, parameters.at(1));
add_tensor(in_r, parameters.at(2));
add_tensor(in_b, parameters.at(3));
auto t_in_seq_lengths =
backend->create_tensor(parameters.at(4)->get_element_type(), parameters.at(4)->get_shape());
copy_data(t_in_seq_lengths, in_seq_lengths);
arg_tensors.push_back(t_in_seq_lengths);
auto results = function->get_results();
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> result_tensors(results.size());
for (std::size_t i{0}; i < results.size(); ++i)
{
result_tensors.at(i) =
backend->create_tensor(results.at(i)->get_element_type(), results.at(i)->get_shape());
}
auto handle = backend->compile(function);
handle->call_with_validate(result_tensors, arg_tensors);
Outputs outputs;
for (auto rt : result_tensors)
{
outputs.push_back(read_vector<float>(rt));
}
EXPECT_TRUE(outputs.size() == expected_output.size());
for (std::size_t i{0}; i < expected_output.size(); ++i)
{
// We have to enlarge tolerance bits to 3 - it's only one bit more than default value.
// The discrepancies may occur at most on 7th decimal position.
EXPECT_TRUE(test::all_close_f(expected_output.at(i), outputs.at(i), 3));
}
test_case.add_input<float>({1.f, 2.f, 10.f, 11.f}); // X
test_case.add_input<float>(
{0.1f, 0.2f, 0.3f, 0.4f, 1.f, 2.f, 3.f, 4.f, 10.f, 11.f, 12.f, 13.f}); // W
test_case.add_input(std::vector<float>(4 * hidden_size * hidden_size, 0.1f)); // R
test_case.add_input(std::vector<float>(8 * hidden_size, 0.0f)); // B
test_case.add_input<int>({1, 2}); // seq_lengths
test_case.add_expected_output<float>(Shape{2, 1, 2, 3},
{0.28828835f, // Y_data
0.36581863f,
0.45679406f,
0.34526032f,
0.47220859f,
0.55850911f,
0.f,
0.f,
0.f,
0.85882828f,
0.90703777f,
0.92382453f});
test_case.add_expected_output<float>(
Shape{1, 2, 3},
{0.28828835f, 0.36581863f, 0.45679406f, 0.85882828f, 0.90703777f, 0.92382453f}); // Y_h_data
test_case.add_expected_output<float>(
Shape{1, 2, 3},
{0.52497941f, 0.54983425f, 0.5744428f, 1.3249796f, 1.51063104f, 1.61451544f}); // Y_c_data
// We have to increase the tolerance to 3 bits - only one bit more than the default value.
// Discrepancies may occur at most at the 7th decimal place.
test_case.set_tolerance(3);
test_case.run();
}
@@ -19,19 +19,6 @@
#include "gtest/gtest.h"
#include "ngraph/assertion.hpp"
std::map<ngraph::element::Type_t, ngraph::test::NgraphTestCase::value_comparator_function>
ngraph::test::NgraphTestCase::m_value_comparators = {
{ngraph::element::Type_t::f32, NgraphTestCase::compare_values<float>},
{ngraph::element::Type_t::f64, NgraphTestCase::compare_values<double>},
{ngraph::element::Type_t::i8, NgraphTestCase::compare_values<int8_t>},
{ngraph::element::Type_t::i16, NgraphTestCase::compare_values<int16_t>},
{ngraph::element::Type_t::i32, NgraphTestCase::compare_values<int32_t>},
{ngraph::element::Type_t::i64, NgraphTestCase::compare_values<int64_t>},
{ngraph::element::Type_t::u8, NgraphTestCase::compare_values<uint8_t>},
{ngraph::element::Type_t::u16, NgraphTestCase::compare_values<uint16_t>},
{ngraph::element::Type_t::u32, NgraphTestCase::compare_values<uint32_t>},
{ngraph::element::Type_t::u64, NgraphTestCase::compare_values<uint64_t>}};
void ngraph::test::NgraphTestCase::run()
{
const auto& function_results = m_function->get_results();
......
@@ -38,6 +38,7 @@ namespace ngraph
{
}
void set_tolerance(int tolerance_bits) { m_tolerance_bits = tolerance_bits; }
template <typename T>
void add_input(const std::vector<T>& values)
{
@@ -122,21 +123,20 @@ namespace ngraph
}
void run();
protected:
private:
template <typename T>
static typename std::enable_if<std::is_floating_point<T>::value,
::testing::AssertionResult>::type
typename std::enable_if<std::is_floating_point<T>::value,
::testing::AssertionResult>::type
compare_values(const std::shared_ptr<ngraph::op::Constant>& expected_results,
const std::shared_ptr<ngraph::runtime::Tensor>& results)
{
const auto expected = expected_results->get_vector<T>();
const auto result = read_vector<T>(results);
return ngraph::test::all_close_f(expected, result);
return ngraph::test::all_close_f(expected, result, m_tolerance_bits);
}
template <typename T>
static typename std::enable_if<std::is_integral<T>::value,
::testing::AssertionResult>::type
typename std::enable_if<std::is_integral<T>::value, ::testing::AssertionResult>::type
compare_values(const std::shared_ptr<ngraph::op::Constant>& expected_results,
const std::shared_ptr<ngraph::runtime::Tensor>& results)
{
@@ -149,6 +149,33 @@ namespace ngraph
const std::shared_ptr<ngraph::op::Constant>&,
const std::shared_ptr<ngraph::runtime::Tensor>&)>;
#define REGISTER_COMPARATOR(element_type_, type_) \
{ \
ngraph::element::Type_t::element_type_, std::bind(&NgraphTestCase::compare_values<type_>, \
this, \
std::placeholders::_1, \
std::placeholders::_2) \
}
std::map<ngraph::element::Type_t,
std::function<::testing::AssertionResult(
const std::shared_ptr<ngraph::op::Constant>&,
const std::shared_ptr<ngraph::runtime::Tensor>&)>>
m_value_comparators = {
REGISTER_COMPARATOR(f32, float),
REGISTER_COMPARATOR(f64, double),
REGISTER_COMPARATOR(i8, int8_t),
REGISTER_COMPARATOR(i16, int16_t),
REGISTER_COMPARATOR(i32, int32_t),
REGISTER_COMPARATOR(i64, int64_t),
REGISTER_COMPARATOR(u8, uint8_t),
REGISTER_COMPARATOR(u16, uint16_t),
REGISTER_COMPARATOR(u32, uint32_t),
REGISTER_COMPARATOR(u64, uint64_t),
};
#undef REGISTER_COMPARATOR
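// Design note (editor's reading of this diff): the comparator map is now a non-static
// data member whose entries are created with std::bind(..., this, ...), so each stored
// comparator closes over the test-case instance and can honour per-instance state such
// as m_tolerance_bits set via set_tolerance() - the "passing instance values to
// comparator" change from the commit message.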
protected:
std::shared_ptr<Function> m_function;
std::shared_ptr<runtime::Backend> m_backend;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors;
@@ -156,7 +183,7 @@ namespace ngraph
std::vector<std::shared_ptr<ngraph::op::Constant>> m_expected_outputs;
int m_input_index = 0;
int m_output_index = 0;
static std::map<ngraph::element::Type_t, value_comparator_function> m_value_comparators;
int m_tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS;
};
}
}