//***************************************************************************** // Copyright 2017-2019 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //***************************************************************************** #include "gtest/gtest.h" #include "ngraph/ngraph.hpp" #include "util/all_close.hpp" #include "util/all_close_f.hpp" #include "util/known_element_types.hpp" #include "util/ndarray.hpp" #include "util/random.hpp" #include "util/test_control.hpp" #include "util/test_tools.hpp" using namespace std; using namespace ngraph; static string s_manifest = "${MANIFEST}"; // This test operates against the INTERPRETER backend as a reference, so it is // disabled if INTERPRETER is disabled. 
#if NGRAPH_INTERPRETER_ENABLE
// Checks BatchMatMul on the backend under test against a reference built from
// independent Dot ops evaluated on INTERPRETER.
//
// Reference graph: for each batch index k, Dot(A_k{2,3}, B_k{3,2}) -> {2,2}; the
// per-batch results are concatenated along axis 0 -> {batch_size * 2, 2}.
// Tested graph:    BatchMatMul(A{batch_size,2,3}, B{batch_size,3,2}) -> {batch_size,2,2}.
// Both outputs hold the same values in the same row-major order, so the flattened
// result vectors can be compared element-wise.
NGRAPH_TEST(${BACKEND_NAME}, batch_mat_mul_forward)
{
    constexpr size_t batch_size = 3;
    const Shape shape_a{2, 3};
    const Shape shape_b{3, 2};

    // Builds one Dot(A, B) node and records its parameters, so the reference Function
    // can list every A parameter before every B parameter.
    auto make_dot = [&](ParameterVector& a_params, ParameterVector& b_params) {
        auto A = make_shared<op::Parameter>(element::f32, shape_a);
        auto B = make_shared<op::Parameter>(element::f32, shape_b);
        a_params.push_back(A);
        b_params.push_back(B);
        return make_shared<op::Dot>(A, B);
    };

    // Reference Function: Concat(Dot_0, ..., Dot_{n-1}) with parameters
    // (A_0, ..., A_{n-1}, B_0, ..., B_{n-1}).
    ParameterVector dot_a_params;
    ParameterVector dot_b_params;
    NodeVector dots;
    for (size_t k = 0; k < batch_size; ++k)
    {
        dots.push_back(make_dot(dot_a_params, dot_b_params));
    }
    auto dot_concat = make_shared<op::Concat>(dots, 0);
    ParameterVector dot_params(dot_a_params);
    dot_params.insert(dot_params.end(), dot_b_params.begin(), dot_b_params.end());
    auto ref_f = make_shared<Function>(dot_concat, dot_params);

    // Function under test: a single BatchMatMul over the stacked operands.
    auto A_batch =
        make_shared<op::Parameter>(element::f32, Shape{batch_size, shape_a[0], shape_a[1]});
    auto B_batch =
        make_shared<op::Parameter>(element::f32, Shape{batch_size, shape_b[0], shape_b[1]});
    auto batchmatmul = make_shared<op::BatchMatMul>(A_batch, B_batch);
    auto backend_f = make_shared<Function>(batchmatmul, ParameterVector{A_batch, B_batch});

    // Random reference inputs, one vector per parameter in dot_params order.
    test::Uniform<float> rng(-1.0f, 1.0f);
    vector<vector<float>> dot_args;
    dot_args.reserve(dot_params.size());
    for (const auto& param : dot_params)
    {
        vector<float> tensor_val(shape_size(param->get_shape()));
        rng.initialize(tensor_val);
        dot_args.push_back(std::move(tensor_val));
    }

    // Derive the batched inputs by concatenating the per-batch reference inputs, so
    // both graphs are guaranteed to see the same data. Previously a second,
    // identically-seeded generator was relied on to reproduce the same values.
    vector<vector<float>> batchmatmul_args(2);
    vector<float>& a_batch_val = batchmatmul_args[0];
    vector<float>& b_batch_val = batchmatmul_args[1];
    for (size_t k = 0; k < batch_size; ++k)
    {
        const vector<float>& a_k = dot_args.at(k);
        const vector<float>& b_k = dot_args.at(batch_size + k);
        a_batch_val.insert(a_batch_val.end(), a_k.begin(), a_k.end());
        b_batch_val.insert(b_batch_val.end(), b_k.begin(), b_k.end());
    }

    auto ref_results = execute(ref_f, dot_args, "INTERPRETER");
    auto backend_results = execute(backend_f, batchmatmul_args, "${BACKEND_NAME}");

    // Fail cleanly (instead of throwing from .at()) if the output counts disagree.
    ASSERT_EQ(ref_results.size(), backend_results.size());
    for (size_t i = 0; i < ref_results.size(); i++)
    {
        // NOTE(review): the 3 extra tolerance bits presumably absorb differences in
        // accumulation order between backends — confirm.
        EXPECT_TRUE(test::all_close_f(
            ref_results.at(i), backend_results.at(i), DEFAULT_FLOAT_TOLERANCE_BITS + 3));
    }
}
#endif