Commit 4996acc4 authored by tsocha's avatar tsocha Committed by Scott Cyphers

[ONNX] Fix Gemm operator by adding reshape (#2619)

parent aa52081c
...@@ -54,8 +54,8 @@ namespace ngraph ...@@ -54,8 +54,8 @@ namespace ngraph
input_b = reshape::transpose(input_b); input_b = reshape::transpose(input_b);
} }
// code from python not implemented in c++ yet. input_a = reshape::flatten(input_a, 1);
// reshape_for_matmul(node, input_a, input_b); input_b = reshape::flatten(input_b, 1);
// A' * B' // A' * B'
std::shared_ptr<ngraph::Node> a_dot_b = std::shared_ptr<ngraph::Node> a_dot_b =
...@@ -64,18 +64,16 @@ namespace ngraph ...@@ -64,18 +64,16 @@ namespace ngraph
// alpha // alpha
std::shared_ptr<ngraph::Node> alpha_node = std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(a_dot_b->get_element_type(), std::make_shared<ngraph::op::Constant>(a_dot_b->get_element_type(),
ngraph::Shape{}, a_dot_b->get_shape(),
std::vector<double>{alpha}); std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, a_dot_b->get_shape());
// alpha * A' * B' // alpha * A' * B'
a_dot_b = std::make_shared<ngraph::op::Multiply>(alpha_node, a_dot_b); a_dot_b = std::make_shared<ngraph::op::Multiply>(alpha_node, a_dot_b);
// beta * C // beta * C
std::shared_ptr<ngraph::Node> beta_node = std::shared_ptr<ngraph::Node> beta_node =
std::make_shared<ngraph::op::Constant>(input_c->get_element_type(), std::make_shared<ngraph::op::Constant>(input_c->get_element_type(),
ngraph::Shape{}, input_c->get_shape(),
std::vector<double>{beta}); std::vector<double>{beta});
beta_node = make_broadcast_node(beta_node, input_c->get_shape());
input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c); input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c);
// alpha * A' * B' + beta * C // alpha * A' * B' + beta * C
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment