Commit 7e78232a authored by Ewa Tusień's avatar Ewa Tusień Committed by Michał Karzyński

[ONNX] Add support for optional C input in Gemm op (#3821)

parent b9d7b7d2
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <memory> #include <memory>
#include "gemm.hpp" #include "gemm.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/gemm.hpp" #include "ngraph/op/fused/gemm.hpp"
namespace ngraph namespace ngraph
...@@ -30,9 +31,19 @@ namespace ngraph ...@@ -30,9 +31,19 @@ namespace ngraph
NodeVector gemm(const Node& node) NodeVector gemm(const Node& node)
{ {
NodeVector inputs{node.get_ng_inputs()}; NodeVector inputs{node.get_ng_inputs()};
auto input_a = inputs.at(0); std::shared_ptr<ngraph::Node> input_a = inputs.at(0);
auto input_b = inputs.at(1); std::shared_ptr<ngraph::Node> input_b = inputs.at(1);
auto input_c = inputs.at(2); std::shared_ptr<ngraph::Node> input_c;
if (inputs.size() == 3)
{
input_c = inputs.at(2);
}
else
{
input_c = ngraph::op::Constant::create(
input_b->get_element_type(), ngraph::Shape{}, {0});
}
double alpha = node.get_attribute_value<double>("alpha", 1); double alpha = node.get_attribute_value<double>("alpha", 1);
double beta = node.get_attribute_value<double>("beta", 1); double beta = node.get_attribute_value<double>("beta", 1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment