Commit 92c1d504 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Robert Kimball

[ONNX] Gemm fix (#1877)

* Fix gemm `input_c` broadcasting.

* Comments.

* Add comment
parent 4552d024
......@@ -59,26 +59,35 @@ namespace ngraph
// code from python not implemented in c++ yet.
// reshape_for_matmul(node, input_a, input_b);
// A' * B'
std::shared_ptr<ngraph::Node> a_dot_b =
std::make_shared<ngraph::op::Dot>(input_a, input_b);
// alpha
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(a_dot_b->get_element_type(),
ngraph::Shape{},
std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, a_dot_b->get_shape());
// alpha * A' * B'
a_dot_b = std::make_shared<ngraph::op::Multiply>(alpha_node, a_dot_b);
// beta * C
std::shared_ptr<ngraph::Node> beta_node =
std::make_shared<ngraph::op::Constant>(input_c->get_element_type(),
ngraph::Shape{},
std::vector<double>{beta});
beta_node = make_broadcast_node(beta_node, input_c->get_shape());
input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c);
input_c = make_broadcast_node(input_c, a_dot_b->get_shape());
return {std::make_shared<ngraph::op::Add>(a_dot_b, input_c)};
// alpha * A' * B' + beta * C
NodeVector broadcasted_nodes =
numpy_style_broadcast_for_binary_operation(a_dot_b, input_c);
// The ONNX documentation says that `input_c` should be "unidirectional broadcastable"
// to the `a_dot_b` tensor. Since numpy style broadcasting is bidirectional, below we
// only use the second output from above broadcasting. In other words we want to
// preserve the shape of original `a_dot_b` tensor.
return {std::make_shared<ngraph::op::Add>(a_dot_b, broadcasted_nodes.at(1))};
}
} // namespace set_1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment