Commit 52043f64 authored by tsocha's avatar tsocha Committed by Scott Cyphers

[ONNX] Gemm operator (#1465)

* Enable Mul OP

* Reshape, broadcasting utils and Gemm op

* Style check

* Review fix pt. 1

* Review fix pt. 2

* Reuse documentation
parent 69ee7e7a
# ******************************************************************************
# Copyright (c) 2017-2018 Intel Corporation
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
......@@ -50,6 +50,9 @@ add_library(onnx_import STATIC
op/constant.cpp
op/constant.hpp
op/conv.cpp
op/gemm.cpp
op/gemm.hpp
op/mul.hpp
op/relu.hpp
op/split.cpp
op/split.hpp
......@@ -57,7 +60,9 @@ add_library(onnx_import STATIC
utils/broadcasting.cpp
utils/broadcasting.hpp
utils/convpool.cpp
utils/convpool.hpp)
utils/convpool.hpp
utils/reshape.cpp
utils/reshape.hpp)
add_dependencies(onnx_import onnx_import_interface)
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "op/gemm.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/utils/broadcasting.hpp"
#include "ngraph/frontend/onnx_import/utils/reshape.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
// Convert an ONNX Gemm node into an nGraph subgraph computing
// alpha * A * B + beta * C, optionally transposing A and/or B first.
NodeVector gemm(const Node& node)
{
    NodeVector args{node.get_ng_inputs()};
    std::shared_ptr<ngraph::Node> matrix_a{args.at(0)};
    std::shared_ptr<ngraph::Node> matrix_b{args.at(1)};
    std::shared_ptr<ngraph::Node> matrix_c{args.at(2)};

    // Scalar multipliers; both default to 1 per the ONNX Gemm spec.
    const double alpha{node.get_attribute_value<double>("alpha", 1)};
    const double beta{node.get_attribute_value<double>("beta", 1)};

    // transA/transB request transposition of the respective input.
    if (node.get_attribute_value<int64_t>("transA", 0))
    {
        matrix_a = transpose(matrix_a);
    }
    if (node.get_attribute_value<int64_t>("transB", 0))
    {
        matrix_b = transpose(matrix_b);
    }

    // code from python not implemented in c++ yet.
    // reshape_for_matmul(node, matrix_a, matrix_b);

    std::shared_ptr<ngraph::Node> product =
        std::make_shared<ngraph::op::Dot>(matrix_a, matrix_b);

    // Scale the product by alpha: broadcast the scalar constant to the
    // product's shape, then multiply element-wise.
    std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
        product->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
    alpha_node = make_broadcast_node(alpha_node, product->get_shape());
    product = std::make_shared<ngraph::op::Multiply>(alpha_node, product);

    // Scale C by beta in the same fashion.
    std::shared_ptr<ngraph::Node> beta_node = std::make_shared<ngraph::op::Constant>(
        matrix_c->get_element_type(), ngraph::Shape{}, std::vector<double>{beta});
    beta_node = make_broadcast_node(beta_node, matrix_c->get_shape());
    matrix_c = std::make_shared<ngraph::op::Multiply>(beta_node, matrix_c);

    return {std::make_shared<ngraph::op::Add>(product, matrix_c)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <memory>
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
/**
 * @brief Convert an ONNX Gemm operation to an nGraph node vector.
 *
 * @param node: The ONNX node representing the Gemm operation.
 *
 * @return: Vector with a single node computing alpha * A * B + beta * C,
 *          where A and B may first be transposed (transA/transB attributes).
 */
NodeVector gemm(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/multiply.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
/**
 * @brief Convert an ONNX Mul operation to an nGraph Multiply node.
 *
 * @param node: The ONNX node representing the Mul operation.
 *
 * @return: Vector with a single node computing the element-wise product
 *          of the first two inputs.
 */
inline NodeVector mul(const Node& node)
{
    auto inputs = node.get_ng_inputs();
    auto product = std::make_shared<ngraph::op::Multiply>(inputs.at(0), inputs.at(1));
    return {product};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -22,6 +22,8 @@
#include "op/batch_norm.hpp"
#include "op/constant.hpp"
#include "op/conv.hpp"
#include "op/gemm.hpp"
#include "op/mul.hpp"
#include "op/relu.hpp"
#include "op/split.hpp"
#include "ops_bridge.hpp"
......@@ -73,6 +75,8 @@ namespace ngraph
std::bind(op::batch_norm, std::placeholders::_1));
m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1));
m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1));
m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1));
m_map.emplace("Mul", std::bind(op::mul, std::placeholders::_1));
m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1));
m_map.emplace("Split", std::bind(op::split, std::placeholders::_1));
}
......
......@@ -17,6 +17,8 @@
#pragma once
#include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
......@@ -65,6 +67,12 @@ namespace ngraph
output_shape, input_shape, output_shape.size() - input_shape.size());
}
/**
 * @brief Wrap a node in a Broadcast op producing a tensor of new_shape.
 *
 * @param node: The node whose output tensor is broadcast.
 * @param new_shape: The target shape (must be broadcast-compatible with the
 *                   node's shape as determined by calculate_broadcast_axes).
 *
 * @return: New Broadcast node with output shape new_shape.
 */
inline std::shared_ptr<ngraph::Node> make_broadcast_node(
    const std::shared_ptr<ngraph::Node>& node, const ngraph::Shape& new_shape)
{
    // Pass new_shape by const reference: the original signature copied the
    // Shape (a vector of dimensions) on every call for no benefit.
    return std::make_shared<ngraph::op::Broadcast>(
        node, new_shape, calculate_broadcast_axes(new_shape, node->get_shape()));
}
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/reshape.hpp"
#include "utils/reshape.hpp"
namespace ngraph
{
namespace onnx_import
{
/**
 * @brief Permute axes according to specified axes_order parameter.
 *
 * @param node: The node which axes we want to permute.
 * @param axes_order: The permutation of node tensor axes; an empty vector
 *                    requests the identity permutation.
 *
 * @return: New Reshape node with permuted axes.
 */
std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node,
                                           std::vector<size_t> axes_order = {})
{
    ngraph::Shape out_shape = node->get_shape();
    if (axes_order.empty())
    {
        // Identity permutation: 0, 1, ..., rank-1.
        axes_order.resize(out_shape.size());
        std::iota(std::begin(axes_order), std::end(axes_order), 0);
    }
    else
    {
        // size_t index: the original used `int`, mixing signed and unsigned
        // in the loop condition.
        for (std::size_t i = 0; i < axes_order.size(); ++i)
        {
            out_shape[i] = node->get_shape().at(axes_order.at(i));
        }
    }
    auto axis_vector = ngraph::AxisVector{axes_order.begin(), axes_order.end()};
    return std::make_shared<ngraph::op::Reshape>(node, axis_vector, out_shape);
}
/**
 * @brief Return a node producing the transposed tensor (axes in reversed order).
 *
 * @param node: The node whose output tensor is transposed.
 *
 * @return: New node with reversed dimensions.
 */
std::shared_ptr<ngraph::Node> transpose(const std::shared_ptr<ngraph::Node>& node)
{
    const std::size_t rank = node->get_shape().size();
    // Build the reversed permutation rank-1, rank-2, ..., 0 directly.
    std::vector<std::size_t> reversed_axes(rank);
    for (std::size_t i = 0; i < rank; ++i)
    {
        reversed_axes[i] = rank - 1 - i;
    }
    return reorder_axes(node, reversed_axes);
}
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
/**
 * @brief Permute axes according to specified axes_order parameter.
 *
 * @param node: The node which axes we want to permute.
 * @param axes_order: The permutation of node tensor axes; an empty vector
 *                    requests the identity permutation.
 *
 * @return: New node with permuted axes.
 */
std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node,
                                           std::vector<size_t> axes_order);
/**
* @brief Return transposed tensor (with axes in reversed order).
*
* @param node: Input tensor we want to transpose.
*
* @return: New node with reversed dimensions.
*/
std::shared_ptr<ngraph::Node> transpose(const std::shared_ptr<ngraph::Node>& node);
} // namespace onnx_import
} // namespace ngraph
......@@ -54,6 +54,24 @@ TEST(onnx, model_add_abc_initializers)
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
// Verify the ONNX importer on a model computing Add(Mul(A, B), C) element-wise.
TEST(onnx, model_addmul_abc)
{
    auto function{ngraph::onnx_import::import_onnx_function(
        ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/addmul_abc.onnx"))};

    // Inputs A, B, C in model order.
    std::vector<std::vector<float>> inputs;
    inputs.emplace_back(ngraph::test::NDArray<float, 3>({{{9, 10}}, {{11, 12}}}).get_vector());
    inputs.emplace_back(ngraph::test::NDArray<float, 3>({{{5, 6}}, {{7, 8}}}).get_vector());
    inputs.emplace_back(ngraph::test::NDArray<float, 3>({{{1, 2}}, {{3, 4}}}).get_vector());

    // Expected: A * B + C, e.g. 9 * 5 + 1 = 46, 10 * 6 + 2 = 62.
    // (The original declared an unused `ngraph::Shape shape{1, 2, 2}` whose
    // dimensions did not even match the data above; removed.)
    auto expected_output = ngraph::test::NDArray<float, 3>({{{46, 62}}, {{80, 100}}}).get_vector();

    auto result_vectors = execute(function, inputs, "INTERPRETER");
    EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front()));
}
TEST(onnx, model_split_equal_parts_default)
{
Model model{onnx_import::load_onnx_model(
......@@ -219,3 +237,34 @@ TEST(onnx, model_relu)
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
// Verify the ONNX importer on a Gemm node computing alpha * A * B + beta * C.
// NOTE(review): the alpha/beta attribute values live inside the serialized
// gemm_abc.onnx fixture; the expected outputs imply alpha == beta == 0.5
// (e.g. 0.5 * 679 + 0.5 * 1 == 340) -- confirm against the model file.
TEST(onnx, model_gemm_abc)
{
auto function{ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/gemm_abc.onnx"))};
// Inputs: A is 3x6, B is 6x4, C is 3x4 (all ones).
std::vector<std::vector<float>> inputs;
inputs.emplace_back(ngraph::test::NDArray<float, 2>(
{{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}, {13, 14, 15, 16, 17, 18}})
.get_vector());
inputs.emplace_back(ngraph::test::NDArray<float, 2>({{19, 20, 21, 22},
{23, 24, 25, 26},
{27, 28, 29, 30},
{31, 32, 33, 34},
{35, 36, 37, 38},
{39, 40, 41, 42}})
.get_vector());
inputs.emplace_back(
ngraph::test::NDArray<float, 2>({{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}).get_vector());
// Expected 3x4 result of the scaled matrix product plus scaled C.
auto expected_output =
ngraph::test::NDArray<float, 2>(
{{340, 350.5, 361, 371.5}, {862, 890.5, 919, 947.5}, {1384, 1430.5, 1477, 1523.5}})
.get_vector();
auto result_vectors = execute(function, inputs, "INTERPRETER");
EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front()));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment