Commit 54bbb154 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Add BatchMatMul unit test to all builds (#2996)

* add BatchMatMul unit test to all builds

* Check that interpreter is available

* fix error

* enable tests for features not yet implemented
parent 308dc966
......@@ -136,6 +136,8 @@ create_tensor_2_input
create_tensor_2_output
# Not implemented
batch_mat_mul_forward
backwards_batchmatmul_tensor2_tensor2
erf
zero_sized_erf
model_erf
......
......@@ -40,6 +40,8 @@ pad_reflect_2d
pad_reflect_2d_with_neg
# Not implemented
batch_mat_mul_forward
backwards_batchmatmul_tensor2_tensor2
erf
gather_no_axis
gather
......
......@@ -855,8 +855,6 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_dot_tensor3_tensor3)
EXPECT_TRUE(autodiff_numeric_compare<float>(backend.get(), make_graph, {x0, x1}, .01f, .01f));
}
#if defined(AUTODIFF_BACKEND_CPU) || defined(AUTODIFF_BACKEND_INTERPRETER)
// XXX lfeng: remove backend check once all backends support this
NGRAPH_TEST(${BACKEND_NAME}, backwards_batchmatmul_tensor2_tensor2)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
......@@ -875,31 +873,6 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_batchmatmul_tensor2_tensor2)
EXPECT_TRUE(autodiff_numeric_compare<float>(backend.get(), make_graph, {x0, x1}, .01f, .01f));
}
#endif
#if defined(AUTODIFF_BACKEND_CPU) && defined(NGRAPH_JSON_ENABLE)
NGRAPH_TEST(${BACKEND_NAME}, backwards_batchmatmultranspose_tensor2_tensor2)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
std::string backend_name = "${BACKEND_NAME}";
const std::string file_name("mxnet/batch_dot_3.json");
auto f = make_function_from_file(file_name);
test::Uniform<float> rng(-1.0f, 1.0f);
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> args;
for (shared_ptr<op::Parameter> param : f->get_parameters())
{
args.push_back(rng.initialize(backend->create_tensor<float>(param->get_shape())));
}
auto g = make_function_from_file(file_name);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.run_passes(g);
EXPECT_TRUE(autodiff_numeric_compare<float>(backend.get(), f, g, args, .01f, .01f));
}
#endif
NGRAPH_TEST(${BACKEND_NAME}, backwards_exp)
{
......
......@@ -37,10 +37,6 @@
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
// clang-format off
#define BACKEND_TEST_${BACKEND_NAME}
// clang-format on
using namespace std;
using namespace ngraph;
......@@ -7413,9 +7409,8 @@ NGRAPH_TEST(${BACKEND_NAME}, quantize_dynamic_offset)
read_vector<output_c_type>(y));
}
#if defined(BACKEND_TEST_CPU) || defined(BACKEND_TEST_INTERPRETER)
// XXX lfeng: remove backend check once all backends support this
TEST(${BACKEND_NAME}, batch_mat_mul_forward)
#if NGRAPH_INTERPRETER_ENABLE
NGRAPH_TEST(${BACKEND_NAME}, batch_mat_mul_forward)
{
auto make_dot = [](ParameterVector& a_params, ParameterVector& b_params) {
Shape shape_a{2, 3};
......@@ -7475,14 +7470,7 @@ TEST(${BACKEND_NAME}, batch_mat_mul_forward)
EXPECT_TRUE(test::all_close(ref_results.at(i), backend_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
#endif
// clang-format off
#ifdef BACKEND_TEST_${BACKEND_NAME}
#undef BACKEND_TEST_${BACKEND_NAME}
#endif
// clang-format on
NGRAPH_TEST(${BACKEND_NAME}, validate_function_for_dynamic_shape)
{
......
......@@ -3972,4 +3972,28 @@ TEST(cpu_fusion, validate_fuse_gru_inputs)
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
#if defined(AUTODIFF_BACKEND_CPU) && defined(NGRAPH_JSON_ENABLE)
NGRAPH_TEST(cpu_fusion, backwards_batchmatmultranspose_tensor2_tensor2)
{
auto backend = runtime::Backend::create("CPU");
const std::string file_name("mxnet/batch_dot_3.json");
auto f = make_function_from_file(file_name);
test::Uniform<float> rng(-1.0f, 1.0f);
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> args;
for (shared_ptr<op::Parameter> param : f->get_parameters())
{
args.push_back(rng.initialize(backend->create_tensor<float>(param->get_shape())));
}
auto g = make_function_from_file(file_name);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.run_passes(g);
EXPECT_TRUE(autodiff_numeric_compare<float>(backend.get(), f, g, args, .01f, .01f));
}
#endif
#endif
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment