Commit 7e78232a authored by Ewa Tusień's avatar Ewa Tusień Committed by Michał Karzyński

[ONNX] Add support for optional C input in Gemm op (#3821)

parent b9d7b7d2
......@@ -17,6 +17,7 @@
#include <memory>
#include "gemm.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/gemm.hpp"
namespace ngraph
......@@ -30,9 +31,19 @@ namespace ngraph
NodeVector gemm(const Node& node)
{
NodeVector inputs{node.get_ng_inputs()};
auto input_a = inputs.at(0);
auto input_b = inputs.at(1);
auto input_c = inputs.at(2);
std::shared_ptr<ngraph::Node> input_a = inputs.at(0);
std::shared_ptr<ngraph::Node> input_b = inputs.at(1);
std::shared_ptr<ngraph::Node> input_c;
if (inputs.size() == 3)
{
input_c = inputs.at(2);
}
else
{
input_c = ngraph::op::Constant::create(
input_b->get_element_type(), ngraph::Shape{}, {0});
}
double alpha = node.get_attribute_value<double>("alpha", 1);
double beta = node.get_attribute_value<double>("beta", 1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment