Unverified Commit 22d4285f authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Fixed GEMM conversion (#4412)

* Fixed GEMM conversion

* Renaming
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 596bb169
......@@ -73,15 +73,20 @@ namespace ngraph
input_a = ngraph::builder::opset1::flatten(input_a, 1);
input_b = ngraph::builder::opset1::flatten(input_b, 1);
auto matmul_node = std::make_shared<default_opset::MatMul>(input_a, input_b);
std::shared_ptr<ngraph::Node> matmul_node =
std::make_shared<default_opset::MatMul>(input_a, input_b);
if (alpha != 1)
{
matmul_node =
std::make_shared<default_opset::Multiply>(matmul_node, alpha_node);
}
auto alpha_times_product =
std::make_shared<default_opset::Multiply>(alpha_node, matmul_node);
auto beta_times_input_c =
std::make_shared<default_opset::Multiply>(beta_node, input_c);
return NodeVector{std::make_shared<default_opset::Add>(alpha_times_product,
beta_times_input_c)};
return NodeVector{
std::make_shared<default_opset::Add>(matmul_node, beta_times_input_c)};
}
} // namespace set_1
......@@ -116,16 +121,20 @@ namespace ngraph
const bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
const bool trans_b = node.get_attribute_value<int64_t>("transB", 0);
auto matmul_node =
std::shared_ptr<ngraph::Node> matmul_node =
std::make_shared<default_opset::MatMul>(input_a, input_b, trans_a, trans_b);
auto alpha_times_product =
std::make_shared<default_opset::Multiply>(alpha_node, matmul_node);
if (alpha != 1)
{
matmul_node =
std::make_shared<default_opset::Multiply>(matmul_node, alpha_node);
}
auto beta_times_input_c =
std::make_shared<default_opset::Multiply>(beta_node, input_c);
return NodeVector{std::make_shared<default_opset::Add>(alpha_times_product,
beta_times_input_c)};
return NodeVector{
std::make_shared<default_opset::Add>(matmul_node, beta_times_input_c)};
}
} // namespace set_6
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment