Commit 4bd05b6f authored by tsocha's avatar tsocha Committed by Robert Kimball

[ONNX] MatMul operator (#1493)

* [ONNX] MatMul operator

* Add NL on EOF

* Review fix pt. 1
parent 497e5d39
...@@ -39,6 +39,7 @@ add_library(onnx_import STATIC ...@@ -39,6 +39,7 @@ add_library(onnx_import STATIC
op/conv.cpp op/conv.cpp
op/gemm.cpp op/gemm.cpp
op/gemm.hpp op/gemm.hpp
op/matmul.hpp
op/mul.hpp op/mul.hpp
op/relu.hpp op/relu.hpp
op/split.cpp op/split.cpp
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/dot.hpp"
#include "core/node.hpp"
namespace ngraph
{
    namespace onnx_import
    {
        namespace op
        {
            /// \brief Convert an ONNX MatMul node into an nGraph Dot operation.
            ///
            /// \param node  The ONNX node carrying exactly two input tensors
            ///              (the left and right matmul operands).
            /// \return      A single-element NodeVector holding the Dot node.
            ///
            /// NOTE(review): Dot is constructed without explicit reduction axes;
            /// presumably this covers the plain 2-D matrix product case of ONNX
            /// MatMul — confirm whether N-D/batched inputs are handled upstream.
            inline NodeVector matmul(const Node& node)
            {
                const auto inputs = node.get_ng_inputs();
                auto product = std::make_shared<ngraph::op::Dot>(inputs.at(0), inputs.at(1));
                return {product};
            }
        } // namespace op
    } // namespace onnx_import
} // namespace ngraph
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "op/constant.hpp" #include "op/constant.hpp"
#include "op/conv.hpp" #include "op/conv.hpp"
#include "op/gemm.hpp" #include "op/gemm.hpp"
#include "op/matmul.hpp"
#include "op/mul.hpp" #include "op/mul.hpp"
#include "op/relu.hpp" #include "op/relu.hpp"
#include "op/split.hpp" #include "op/split.hpp"
...@@ -76,6 +77,7 @@ namespace ngraph ...@@ -76,6 +77,7 @@ namespace ngraph
m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1)); m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1));
m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1)); m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1));
m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1)); m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1));
m_map.emplace("MatMul", std::bind(op::matmul, std::placeholders::_1));
m_map.emplace("Mul", std::bind(op::mul, std::placeholders::_1)); m_map.emplace("Mul", std::bind(op::mul, std::placeholders::_1));
m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1)); m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1));
m_map.emplace("Split", std::bind(op::split, std::placeholders::_1)); m_map.emplace("Split", std::bind(op::split, std::placeholders::_1));
......
 backend-test:b

a
bc"MatMultest_matmul_2dZ
a


Z
b


b
c


B
\ No newline at end of file
...@@ -148,7 +148,7 @@ namespace ...@@ -148,7 +148,7 @@ namespace
} }
} // namespace } // namespace
TEST(onnx, mode_conv2d_strides_padding) TEST(onnx, model_conv2d_strides_padding)
{ {
// Convolution with strides=2 and padding=1 // Convolution with strides=2 and padding=1
auto function = ngraph::onnx_import::import_onnx_function( auto function = ngraph::onnx_import::import_onnx_function(
...@@ -268,3 +268,24 @@ TEST(onnx, model_gemm_abc) ...@@ -268,3 +268,24 @@ TEST(onnx, model_gemm_abc)
auto result_vectors = execute(function, inputs, "INTERPRETER"); auto result_vectors = execute(function, inputs, "INTERPRETER");
EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front())); EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front()));
} }
// Validate the ONNX MatMul importer end-to-end: import a serialized model,
// run it on the INTERPRETER backend, and compare against a hand-computed
// 3x4 * 4x3 matrix product.
TEST(onnx, model_matmul)
{
    auto function =
        onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/matmul.onnx"));

    std::vector<std::vector<float>> inputs;
    inputs.push_back(
        test::NDArray<float, 2>({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}).get_vector());
    inputs.push_back(
        test::NDArray<float, 2>({{13, 14, 15}, {16, 17, 18}, {19, 20, 21}, {22, 23, 24}})
            .get_vector());

    // Row-by-column dot products of the two matrices above.
    auto expected_output =
        test::NDArray<float, 2>({{190, 200, 210}, {470, 496, 522}, {750, 792, 834}}).get_vector();

    auto outputs = execute(function, inputs, "INTERPRETER");
    EXPECT_TRUE(test::all_close_f(expected_output, outputs.front()));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment