Commit 08ca928a authored by Michał Karzyński, committed by Scott Cyphers

[ONNX] Importer should use fused op for MatMul (#3842)

* [ONNX] Importer should use fused op for MatMul

* Fix a bug in fused matmul op

* Don't reshape matmul inputs to at least 2D anymore
parent 2f69f86c
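A side note on the last bullet above (my own illustration, not part of the commit): ONNX MatMul follows numpy.matmul semantics, where a 1-D operand is given a temporary axis for the multiplication and that axis is dropped from the result again. Padding inputs to rank 2 up front would leave the extra axis in the output shape, which is presumably why the reshape step is removed from decompose_op() below. A minimal sketch of the resulting rank rule, assuming numpy-style behaviour:

#include <algorithm>
#include <cstddef>

// Hypothetical helper, not part of nGraph: result rank of a numpy-style matmul.
// 1-D operands get a temporary axis for the multiplication which is removed
// again from the result.
std::size_t matmul_result_rank(std::size_t a_rank, std::size_t b_rank)
{
    std::size_t rank = std::max({a_rank, b_rank, std::size_t{2}});
    if (a_rank == 1) { --rank; } // prepended axis of A is dropped from the result
    if (b_rank == 1) { --rank; } // appended axis of B is dropped from the result
    return rank;
}
// e.g. matmul_result_rank(1, 2) == 1 (vector x matrix -> vector),
//      matmul_result_rank(1, 1) == 0 (dot product -> scalar).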
@@ -122,7 +122,6 @@ add_library(onnx_import STATIC
op/lrn.hpp
op/lstm.cpp
op/lstm.hpp
op/matmul.cpp
op/matmul.hpp
op/matmul_integer.cpp
op/matmul_integer.hpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "matmul.hpp"
#include "ngraph/builder/matmul_factory.hpp"
namespace ngraph
{
    namespace onnx_import
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector matmul(const Node& node)
                {
                    auto ng_inputs = node.get_ng_inputs();
                    auto factory = builder::MatmulFactory(
                        (OutputVector(std::begin(ng_inputs), std::end(ng_inputs))));
                    std::size_t left_rank{ng_inputs.at(0)->get_shape().size()};
                    std::size_t right_rank{ng_inputs.at(1)->get_shape().size()};

                    if (left_rank == 0 || right_rank == 0)
                    {
                        NGRAPH_WARN
                            << (node) << " "
                            << "ONNX standard doesn't allow scalar operands, however nGraph "
                               "accepts them. Consider use of element-wise multiplication instead "
                               "to conform with ONNX standard.";
                    }

                    return factory.make_matmul_op();
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
} // namespace ngraph
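A rough sketch of the alternative that the warning above suggests (my own illustration under assumed shapes, not code from this commit): for a scalar (rank-0) operand, an explicit broadcast followed by an element-wise multiply conforms with the ONNX standard instead of going through MatMul.

#include <memory>

#include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/shape.hpp"

// Hypothetical helper: scale a {2, 3} tensor by a scalar (rank-0) operand.
// The data shape is an arbitrary example.
std::shared_ptr<ngraph::Node> scale_by_scalar(const std::shared_ptr<ngraph::Node>& data,   // shape {2, 3}
                                              const std::shared_ptr<ngraph::Node>& scalar) // shape {}
{
    // Broadcast the scalar to the data shape, then multiply element-wise.
    auto broadcasted = std::make_shared<ngraph::op::Broadcast>(
        scalar, ngraph::Shape{2, 3}, ngraph::AxisSet{0, 1});
    return std::make_shared<ngraph::op::Multiply>(data, broadcasted);
}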
@@ -18,6 +18,7 @@
#include "core/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/fused/matmul.hpp"
namespace ngraph
{
@@ -27,7 +28,11 @@ namespace ngraph
        {
            namespace set_1
            {
                NodeVector matmul(const Node& node);
                NodeVector matmul(const Node& node)
                {
                    return {std::make_shared<ngraph::op::MatMul>(node.get_ng_inputs().at(0),
                                                                 node.get_ng_inputs().at(1))};
                }
            } // namespace set_1
        } // namespace op
......
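For reference, a small usage sketch of the fused op that the importer now emits (my own example, not part of the commit). The two transpose arguments are an assumption based on the m_transpose_a / m_transpose_b members used in decompose_op() below; the importer itself only passes the two inputs.

#include <memory>

#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/parameter.hpp"

// Hypothetical example: constructing the fused MatMul directly with transpose flags.
std::shared_ptr<ngraph::Node> make_example_matmul()
{
    auto A = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{4, 2});
    auto B = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{4, 3});
    // With transpose_a = true, A is treated as {2, 4}, so the result shape is {2, 3}.
    return std::make_shared<ngraph::op::MatMul>(A, B, /*transpose_a=*/true, /*transpose_b=*/false);
}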
@@ -64,25 +64,10 @@ NodeVector op::MatMul::decompose_op() const
    auto A = input_value(0);
    auto B = input_value(1);

    // Specification is expecting that A & B have at least 2 dimensions.
    // Missing dimensions are padded with 1.
    int a_rank = A.get_shape().size();
    if (a_rank < 2)
    {
        A = a_rank == 0 ? make_shared<op::Reshape>(A, AxisVector{}, Shape{1, 1})
                        : make_shared<op::Reshape>(A, AxisVector{1}, Shape{1, A.get_shape()[0]});
        a_rank = 2;
    }

    int b_rank = B.get_shape().size();
    if (b_rank < 2)
    {
        B = b_rank == 0 ? make_shared<op::Reshape>(B, AxisVector{}, Shape{1, 1})
                        : make_shared<op::Reshape>(B, AxisVector{1}, Shape{1, B.get_shape()[0]});
        b_rank = 2;
    }

    const auto a_rank = A.get_shape().size();
    const auto b_rank = B.get_shape().size();

    if (m_transpose_a)
    if (m_transpose_a && a_rank >= 2)
    {
        vector<size_t> axes_order(a_rank);
        // generate default axes_order.
@@ -92,7 +77,7 @@ NodeVector op::MatMul::decompose_op() const
        A = builder::reorder_axes(A, axes_order);
    }

    if (m_transpose_b)
    if (m_transpose_b && b_rank >= 2)
    {
        vector<size_t> axes_order(b_rank);
        iota(axes_order.begin(), axes_order.end(), 0);
......
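The transposition itself (the part elided by the hunk above) boils down to swapping the last two axes before calling builder::reorder_axes, which is only meaningful for rank >= 2; that is what the new a_rank >= 2 / b_rank >= 2 guards encode. A self-contained sketch of that axis-order computation (my own reconstruction, assuming the default iota order shown above):

#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical helper mirroring the elided lines: build an axes order that
// swaps the last two dimensions, e.g. rank 3 -> {0, 2, 1}.
std::vector<std::size_t> transpose_last_two_axes_order(std::size_t rank)
{
    std::vector<std::size_t> axes_order(rank);
    std::iota(axes_order.begin(), axes_order.end(), 0); // default order 0..rank-1
    if (rank >= 2)
    {
        std::swap(axes_order[rank - 2], axes_order[rank - 1]); // transpose last two dims
    }
    return axes_order;
}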