Commit aeaaf8fb authored by Michał Karzyński's avatar Michał Karzyński Committed by Scott Cyphers

Update ONNX importer Gemm to produce MatMul op (#3927)

* Update ONNX importer Gemm to produce MatMul op

* Address opset3 bug
parent 4dc9aa46
......@@ -17,8 +17,11 @@
#include <memory>
#include "gemm.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/multiply.hpp"
namespace ngraph
{
......@@ -45,18 +48,87 @@ namespace ngraph
input_b->get_element_type(), ngraph::Shape{}, {0});
}
double alpha = node.get_attribute_value<double>("alpha", 1);
double beta = node.get_attribute_value<double>("beta", 1);
const auto alpha = node.get_attribute_value<float>("alpha", 1);
const auto beta = node.get_attribute_value<float>("beta", 1);
bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
bool trans_b = node.get_attribute_value<int64_t>("transB", 0);
const auto alpha_node = ngraph::op::Constant::create(
element::Type_t::f32, Shape{}, std::vector<float>{alpha});
const auto beta_node = ngraph::op::Constant::create(
element::Type_t::f32, Shape{}, std::vector<float>{beta});
return NodeVector{std::make_shared<ngraph::op::Gemm>(
input_a, input_b, input_c, alpha, beta, trans_a, trans_b)};
const bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
const bool trans_b = node.get_attribute_value<int64_t>("transB", 0);
if (trans_a)
{
input_a = ngraph::builder::transpose(input_a);
}
if (trans_b)
{
input_b = ngraph::builder::transpose(input_b);
}
input_a = ngraph::builder::flatten(input_a, 1);
input_b = ngraph::builder::flatten(input_b, 1);
auto matmul_node = std::make_shared<ngraph::op::MatMul>(input_a, input_b);
auto alpha_times_product =
std::make_shared<ngraph::op::v1::Multiply>(alpha_node, matmul_node);
auto beta_times_input_c =
std::make_shared<ngraph::op::v1::Multiply>(beta_node, input_c);
return NodeVector{std::make_shared<ngraph::op::v1::Add>(alpha_times_product,
beta_times_input_c)};
}
} // namespace set_1
namespace set_6
{
/// Converts an ONNX Gemm node (opset 6) to nGraph ops:
///   Y = alpha * op(A) * op(B) + beta * C
/// where op() applies the optional transA/transB transpose, handled
/// natively by op::MatMul.
///
/// \param node  ONNX node carrying inputs A, B and optional C, plus the
///              alpha/beta/transA/transB attributes.
/// \return      Single-element NodeVector holding the result node.
NodeVector gemm(const Node& node)
{
    NodeVector inputs{node.get_ng_inputs()};
    std::shared_ptr<ngraph::Node> input_a = inputs.at(0);
    std::shared_ptr<ngraph::Node> input_b = inputs.at(1);
    std::shared_ptr<ngraph::Node> input_c;

    // Input C is optional; default to a scalar zero of B's element type
    // so the final Add broadcasts it over the product.
    if (inputs.size() == 3)
    {
        input_c = inputs.at(2);
    }
    else
    {
        input_c = ngraph::op::Constant::create(
            input_b->get_element_type(), ngraph::Shape{}, {0});
    }

    const auto alpha = node.get_attribute_value<float>("alpha", 1);
    const auto beta = node.get_attribute_value<float>("beta", 1);

    // BUGFIX: create the alpha/beta scalar constants with the inputs'
    // element type rather than hard-coded f32. op::v1::Multiply performs
    // no implicit element-type conversion, so a hard-coded f32 constant
    // fails type checking for f64/f16 models. Constant::create converts
    // the supplied float value to the target element type.
    const auto alpha_node = ngraph::op::Constant::create(
        input_b->get_element_type(), Shape{}, std::vector<float>{alpha});
    const auto beta_node = ngraph::op::Constant::create(
        input_b->get_element_type(), Shape{}, std::vector<float>{beta});

    const bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
    const bool trans_b = node.get_attribute_value<int64_t>("transB", 0);

    // Opset 6 lets MatMul apply the transposes directly — no explicit
    // transpose/flatten of the inputs is needed (unlike set_1).
    auto matmul_node =
        std::make_shared<ngraph::op::MatMul>(input_a, input_b, trans_a, trans_b);

    auto alpha_times_product =
        std::make_shared<ngraph::op::v1::Multiply>(alpha_node, matmul_node);
    auto beta_times_input_c =
        std::make_shared<ngraph::op::v1::Multiply>(beta_node, input_c);

    return NodeVector{std::make_shared<ngraph::op::v1::Add>(alpha_times_product,
                                                            beta_times_input_c)};
}
} // namespace set_6
} // namespace op
} // namespace onnx_import
......
......@@ -31,6 +31,12 @@ namespace ngraph
} // namespace set_1
namespace set_6
{
NodeVector gemm(const Node& node);
} // namespace set_6
} // namespace op
} // namespace onnx_import
......
......@@ -272,6 +272,7 @@ namespace ngraph
REGISTER_OPERATOR("Floor", 1, floor);
REGISTER_OPERATOR("Gather", 1, gather);
REGISTER_OPERATOR("Gemm", 1, gemm);
REGISTER_OPERATOR("Gemm", 6, gemm);
REGISTER_OPERATOR("GlobalAveragePool", 1, global_average_pool);
REGISTER_OPERATOR("GlobalLpPool", 1, global_lp_pool);
REGISTER_OPERATOR("GlobalMaxPool", 1, global_max_pool);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment