Unverified Commit 22d4285f authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Fixed GEMM conversion (#4412)

* Fixed GEMM conversion

* Renaming
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 596bb169
...@@ -73,15 +73,20 @@ namespace ngraph ...@@ -73,15 +73,20 @@ namespace ngraph
input_a = ngraph::builder::opset1::flatten(input_a, 1); input_a = ngraph::builder::opset1::flatten(input_a, 1);
input_b = ngraph::builder::opset1::flatten(input_b, 1); input_b = ngraph::builder::opset1::flatten(input_b, 1);
auto matmul_node = std::make_shared<default_opset::MatMul>(input_a, input_b); std::shared_ptr<ngraph::Node> matmul_node =
std::make_shared<default_opset::MatMul>(input_a, input_b);
if (alpha != 1)
{
matmul_node =
std::make_shared<default_opset::Multiply>(matmul_node, alpha_node);
}
auto alpha_times_product =
std::make_shared<default_opset::Multiply>(alpha_node, matmul_node);
auto beta_times_input_c = auto beta_times_input_c =
std::make_shared<default_opset::Multiply>(beta_node, input_c); std::make_shared<default_opset::Multiply>(beta_node, input_c);
return NodeVector{std::make_shared<default_opset::Add>(alpha_times_product, return NodeVector{
beta_times_input_c)}; std::make_shared<default_opset::Add>(matmul_node, beta_times_input_c)};
} }
} // namespace set_1 } // namespace set_1
...@@ -116,16 +121,20 @@ namespace ngraph ...@@ -116,16 +121,20 @@ namespace ngraph
const bool trans_a = node.get_attribute_value<int64_t>("transA", 0); const bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
const bool trans_b = node.get_attribute_value<int64_t>("transB", 0); const bool trans_b = node.get_attribute_value<int64_t>("transB", 0);
auto matmul_node = std::shared_ptr<ngraph::Node> matmul_node =
std::make_shared<default_opset::MatMul>(input_a, input_b, trans_a, trans_b); std::make_shared<default_opset::MatMul>(input_a, input_b, trans_a, trans_b);
auto alpha_times_product = if (alpha != 1)
std::make_shared<default_opset::Multiply>(alpha_node, matmul_node); {
matmul_node =
std::make_shared<default_opset::Multiply>(matmul_node, alpha_node);
}
auto beta_times_input_c = auto beta_times_input_c =
std::make_shared<default_opset::Multiply>(beta_node, input_c); std::make_shared<default_opset::Multiply>(beta_node, input_c);
return NodeVector{std::make_shared<default_opset::Add>(alpha_times_product, return NodeVector{
beta_times_input_c)}; std::make_shared<default_opset::Add>(matmul_node, beta_times_input_c)};
} }
} // namespace set_6 } // namespace set_6
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment