Commit 9244e45b authored by tsocha's avatar tsocha Committed by Sang Ik Lee

[Fused Op] Move Gemm operator from onnx import to ngraph fused ops (#2853)

* Move transpose and flatten to ngraph op utils dir

* Move gemm operator into ngraph fused ops

* Style fix

* Add Gemm to serializer

* Add type_prop test for gemm

* Use Gemm default values

* Add UT for Gemm

* Fix comments

* Little cleanup

* Remove artifact headers

* Fix gemm documentation

* Skip gemm test on GPU

* Add test for broadcasting input C

* Review fix pt. 1

* Fix typo
parent 391d50e0
...@@ -59,7 +59,7 @@ set (SRC ...@@ -59,7 +59,7 @@ set (SRC
descriptor/tensor.hpp descriptor/tensor.hpp
dimension.cpp dimension.cpp
dimension.hpp dimension.hpp
distributed.cpp distributed.cpp
distributed.hpp distributed.hpp
except.hpp except.hpp
file_util.cpp file_util.cpp
...@@ -278,6 +278,8 @@ set (SRC ...@@ -278,6 +278,8 @@ set (SRC
op/fused/depth_to_space.hpp op/fused/depth_to_space.hpp
op/fused/elu.cpp op/fused/elu.cpp
op/fused/elu.hpp op/fused/elu.hpp
op/fused/gemm.cpp
op/fused/gemm.hpp
op/fused/group_conv.hpp op/fused/group_conv.hpp
op/fused/group_conv.cpp op/fused/group_conv.cpp
op/fused/prelu.cpp op/fused/prelu.cpp
......
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= data->get_shape().size())) ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= data->get_shape().size()))
<< "provided 'axis' attribute is not valid."; << "provided 'axis' attribute is not valid.";
return {reshape::flatten(data, axis)}; return {ngraph::op::util::flatten(data, axis)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -14,14 +14,10 @@ ...@@ -14,14 +14,10 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "op/gemm.hpp" #include <memory>
#include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/utils/reshape.hpp" #include "gemm.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -41,48 +37,11 @@ namespace ngraph ...@@ -41,48 +37,11 @@ namespace ngraph
double alpha = node.get_attribute_value<double>("alpha", 1); double alpha = node.get_attribute_value<double>("alpha", 1);
double beta = node.get_attribute_value<double>("beta", 1); double beta = node.get_attribute_value<double>("beta", 1);
auto trans_a = node.get_attribute_value<int64_t>("transA", 0); bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
auto trans_b = node.get_attribute_value<int64_t>("transB", 0); bool trans_b = node.get_attribute_value<int64_t>("transB", 0);
if (trans_a != 0)
{
input_a = reshape::transpose(input_a);
}
if (trans_b != 0)
{
input_b = reshape::transpose(input_b);
}
input_a = reshape::flatten(input_a, 1);
input_b = reshape::flatten(input_b, 1);
// A' * B'
std::shared_ptr<ngraph::Node> a_dot_b =
std::make_shared<ngraph::op::Dot>(input_a, input_b);
// alpha
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(a_dot_b->get_element_type(),
a_dot_b->get_shape(),
std::vector<double>{alpha});
// alpha * A' * B'
a_dot_b = std::make_shared<ngraph::op::Multiply>(alpha_node, a_dot_b);
// beta * C
std::shared_ptr<ngraph::Node> beta_node =
std::make_shared<ngraph::op::Constant>(input_c->get_element_type(),
input_c->get_shape(),
std::vector<double>{beta});
input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c);
// alpha * A' * B' + beta * C return NodeVector{std::make_shared<ngraph::op::Gemm>(
NodeVector broadcasted_nodes = input_a, input_b, input_c, alpha, beta, trans_a, trans_b)};
ngraph::op::numpy_style_broadcast({a_dot_b, input_c});
// The ONNX documentation says that `input_c` should be "unidirectional broadcastable"
// to the `a_dot_b` tensor. Since numpy style broadcasting is bidirectional, below we
// only use the second output from above broadcasting. In other words we want to
// preserve the shape of original `a_dot_b` tensor.
return {std::make_shared<ngraph::op::Add>(a_dot_b, broadcasted_nodes.at(1))};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -16,8 +16,6 @@ ...@@ -16,8 +16,6 @@
#pragma once #pragma once
#include <memory>
#include "core/node.hpp" #include "core/node.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
......
...@@ -387,11 +387,11 @@ namespace ngraph ...@@ -387,11 +387,11 @@ namespace ngraph
// * - Denotes dot product. // * - Denotes dot product.
// Xt*(W^T) -- for [iofc] gates. // Xt*(W^T) -- for [iofc] gates.
auto Xt_W = auto Xt_W = std::make_shared<ngraph::op::Dot>(
std::make_shared<ngraph::op::Dot>(in_x, reshape::transpose(m_W)); in_x, ngraph::op::util::transpose(m_W));
// Ht-1*(R^T) -- for [iofc] gates. // Ht-1*(R^T) -- for [iofc] gates.
auto Ht_R = auto Ht_R = std::make_shared<ngraph::op::Dot>(
std::make_shared<ngraph::op::Dot>(H_t, reshape::transpose(m_R)); H_t, ngraph::op::util::transpose(m_R));
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates. // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates.
auto gates = add(Xt_W, add(Ht_R, bias)); auto gates = add(Xt_W, add(Ht_R, bias));
......
...@@ -38,7 +38,7 @@ namespace ngraph ...@@ -38,7 +38,7 @@ namespace ngraph
node.get_attribute_value<std::vector<std::size_t>>("perm", {}); node.get_attribute_value<std::vector<std::size_t>>("perm", {});
return {(permute_axes.empty()) return {(permute_axes.empty())
? reshape::transpose(data) ? ngraph::op::util::transpose(data)
: ngraph::op::util::reorder_axes(data, permute_axes)}; : ngraph::op::util::reorder_axes(data, permute_axes)};
} }
......
...@@ -62,29 +62,6 @@ namespace ngraph ...@@ -62,29 +62,6 @@ namespace ngraph
} // namespace anonymous } // namespace anonymous
std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node,
int axis)
{
auto data_shape = node->get_shape();
// First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of input tensor.
// The last dimension is the product of the rest of input tensor dimensions: [d_{axis}, ..., d_n]
size_t first_dim_size = std::accumulate(std::begin(data_shape),
std::next(std::begin(data_shape), axis),
1UL,
std::multiplies<std::size_t>());
size_t last_dim_size = std::accumulate(std::next(std::begin(data_shape), axis),
std::end(data_shape),
1UL,
std::multiplies<std::size_t>());
return std::make_shared<ngraph::op::Reshape>(
node,
ngraph::get_default_order(data_shape.size()),
Shape{first_dim_size, last_dim_size});
}
std::vector<std::size_t> infer_dimensions(const std::string& node_name, std::vector<std::size_t> infer_dimensions(const std::string& node_name,
const std::vector<std::size_t>& input_shape, const std::vector<std::size_t>& input_shape,
const std::vector<std::size_t>& output_shape) const std::vector<std::size_t>& output_shape)
...@@ -140,14 +117,6 @@ namespace ngraph ...@@ -140,14 +117,6 @@ namespace ngraph
return inferred_dims; return inferred_dims;
} }
std::shared_ptr<ngraph::Node> transpose(const std::shared_ptr<ngraph::Node>& node)
{
std::vector<size_t> axes_order(node->get_shape().size());
std::iota(std::begin(axes_order), std::end(axes_order), 0);
std::reverse(std::begin(axes_order), std::end(axes_order));
return ngraph::op::util::reorder_axes(node, axes_order);
}
std::shared_ptr<ngraph::Node> squeeze(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> squeeze(const std::shared_ptr<ngraph::Node>& node,
std::vector<std::size_t> axes) std::vector<std::size_t> axes)
{ {
......
...@@ -32,15 +32,6 @@ namespace ngraph ...@@ -32,15 +32,6 @@ namespace ngraph
{ {
namespace reshape namespace reshape
{ {
/// \brief Flatten the input tensor into a 2D matrix.
///
/// \param node The tensor to be flattened.
/// \param axis The axis dividing shape.
///
/// \return The new node being a 2D matrix representing flattened input node.
std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node,
int axis);
/// \brief Infer `output_shape` dimension values. /// \brief Infer `output_shape` dimension values.
/// ///
/// \par Inferention rules /// \par Inferention rules
...@@ -59,13 +50,6 @@ namespace ngraph ...@@ -59,13 +50,6 @@ namespace ngraph
const std::vector<std::size_t>& input_shape, const std::vector<std::size_t>& input_shape,
const std::vector<std::size_t>& output_shape); const std::vector<std::size_t>& output_shape);
/// \brief Return transposed tensor (with axes in reversed order).
///
/// \param node Input tensor we want to transpose
///
/// \return: New node with reversed dimensions.
std::shared_ptr<ngraph::Node> transpose(const std::shared_ptr<ngraph::Node>& node);
/// \brief Remove empty axes from input tensor. /// \brief Remove empty axes from input tensor.
/// ///
/// \param[in] node The node to be squeezed. /// \param[in] node The node to be squeezed.
......
...@@ -97,6 +97,7 @@ ...@@ -97,6 +97,7 @@
#include "ngraph/op/fused/conv_fused.hpp" #include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/depth_to_space.hpp" #include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp" #include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/group_conv.hpp" #include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/prelu.hpp" #include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/space_to_depth.hpp" #include "ngraph/op/fused/space_to_depth.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/util/reshape.hpp"
using namespace std;
using namespace ngraph;
op::Gemm::Gemm(const std::shared_ptr<ngraph::Node>& A,
const std::shared_ptr<ngraph::Node>& B,
const std::shared_ptr<ngraph::Node>& C,
double alpha,
double beta,
bool transA,
bool transB)
: FusedOp("Gemm", {A, B, C})
, m_alpha{alpha}
, m_beta{beta}
, m_transA{transA}
, m_transB{transB}
{
constructor_validate_and_infer_types();
}
NodeVector op::Gemm::decompose_op() const
{
auto A = get_argument(0);
auto B = get_argument(1);
auto C = get_argument(2);
if (m_transA)
{
A = ngraph::op::util::transpose(A);
}
if (m_transB)
{
B = ngraph::op::util::transpose(B);
}
A = ngraph::op::util::flatten(A, 1);
B = ngraph::op::util::flatten(B, 1);
// A' * B'
std::shared_ptr<ngraph::Node> a_dot_b = std::make_shared<ngraph::op::Dot>(A, B);
// alpha
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
a_dot_b->get_element_type(), a_dot_b->get_shape(), std::vector<double>{m_alpha});
// alpha * A' * B'
a_dot_b = std::make_shared<ngraph::op::Multiply>(alpha_node, a_dot_b);
// beta * C
std::shared_ptr<ngraph::Node> beta_node = std::make_shared<ngraph::op::Constant>(
C->get_element_type(), C->get_shape(), std::vector<double>{m_beta});
C = std::make_shared<ngraph::op::Multiply>(beta_node, C);
// alpha * A' * B' + beta * C
NodeVector broadcasted_nodes = ngraph::op::numpy_style_broadcast({a_dot_b, C});
// The input tensor `C` should be "unidirectionally broadcastable" to the `a_dot_b` tensor.
// Numpy style broadcast is bidirectional, so we only use the second output from broadcasting.
return {std::make_shared<ngraph::op::Add>(a_dot_b, broadcasted_nodes.at(1))};
}
shared_ptr<Node> op::Gemm::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Gemm>(
new_args.at(0), new_args.at(1), new_args.at(2), m_alpha, m_beta, m_transA, m_transB);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Operator performing General Matrix multiplication.
///
/// \note More information: https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3
///
/// A' = transpose(A) if transA else A
/// B' = transpose(B) if transB else B
///
/// Compute Y = alpha * A' * B' + beta * C
///
class Gemm : public ngraph::op::util::FusedOp
{
public:
/// \brief Constructs an Gemm operation.
///
/// \param A Input tensor A
/// \param B Input tensor B
/// \param C Input tensor C
/// \param alpha Scalar multiplier for the product of input tensors A * B
/// \param beta Scalar multiplier for input tensor C
/// \param transA Whether A should be transposed
/// \param transB Whether B should be transposed
Gemm(const std::shared_ptr<ngraph::Node>& A,
const std::shared_ptr<ngraph::Node>& B,
const std::shared_ptr<ngraph::Node>& C,
double alpha = 1.0,
double beta = 1.0,
bool transA = false,
bool transB = false);
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
double get_alpha() const { return m_alpha; }
double get_beta() const { return m_beta; }
bool get_transA() const { return m_transA; }
bool get_transB() const { return m_transB; }
private:
double m_alpha;
double m_beta;
bool m_transA;
bool m_transB;
};
}
}
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
// //
NGRAPH_OP(Elu, ngraph::op) NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op) NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(ConvolutionBias, ngraph::op) NGRAPH_OP(ConvolutionBias, ngraph::op)
NGRAPH_OP(ConvolutionBiasAdd, ngraph::op) NGRAPH_OP(ConvolutionBiasAdd, ngraph::op)
......
...@@ -25,32 +25,54 @@ ...@@ -25,32 +25,54 @@
#include "reshape.hpp" #include "reshape.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std;
std::shared_ptr<Node> op::util::reshape(const std::shared_ptr<Node>& node, shared_ptr<Node> op::util::reshape(const shared_ptr<Node>& node, const Shape& shape)
const AxisVector& axis_order,
const Shape& shape)
{ {
return std::make_shared<op::Reshape>( return make_shared<op::Reshape>(node, get_default_order(node->get_shape().size()), shape);
node, ngraph::get_default_order(node->get_shape().size()), shape);
} }
std::shared_ptr<Node> op::util::reorder_axes(const std::shared_ptr<Node>& node, shared_ptr<Node> op::util::reorder_axes(const shared_ptr<Node>& node,
std::vector<std::size_t> axes_order = {}) vector<size_t> axes_order = {})
{ {
Shape out_shape = node->get_shape(); Shape out_shape = node->get_shape();
if (axes_order.empty()) if (axes_order.empty())
{ {
axes_order.resize(out_shape.size()); axes_order.resize(out_shape.size());
std::iota(std::begin(axes_order), std::end(axes_order), 0); iota(begin(axes_order), end(axes_order), 0);
} }
else else
{ {
for (std::size_t i = 0; i < axes_order.size(); ++i) for (size_t i = 0; i < axes_order.size(); ++i)
{ {
out_shape[i] = node->get_shape().at(axes_order.at(i)); out_shape[i] = node->get_shape().at(axes_order.at(i));
} }
} }
auto axis_vector = AxisVector{std::begin(axes_order), std::end(axes_order)}; auto axis_vector = AxisVector{begin(axes_order), end(axes_order)};
return std::make_shared<op::Reshape>(node, axis_vector, out_shape); return make_shared<op::Reshape>(node, axis_vector, out_shape);
}
shared_ptr<Node> op::util::transpose(const shared_ptr<Node>& node)
{
vector<size_t> axes_order(node->get_shape().size());
iota(begin(axes_order), end(axes_order), 0);
reverse(begin(axes_order), end(axes_order));
return op::util::reorder_axes(node, axes_order);
}
shared_ptr<Node> op::util::flatten(const shared_ptr<Node>& node, int axis)
{
auto data_shape = node->get_shape();
// First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of input tensor.
// The last dimension is the product of the rest of input tensor dimensions: [d_{axis}, ..., d_n]
size_t first_dim_size =
accumulate(begin(data_shape), next(begin(data_shape), axis), 1UL, multiplies<size_t>());
size_t last_dim_size =
accumulate(next(begin(data_shape), axis), end(data_shape), 1UL, multiplies<size_t>());
return make_shared<op::Reshape>(
node, get_default_order(data_shape.size()), Shape{first_dim_size, last_dim_size});
} }
...@@ -34,15 +34,8 @@ namespace ngraph ...@@ -34,15 +34,8 @@ namespace ngraph
/// \return The node representing a Reshape operation. /// \return The node representing a Reshape operation.
/// ///
std::shared_ptr<ngraph::Node> reshape(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> reshape(const std::shared_ptr<ngraph::Node>& node,
const AxisVector& axis_order,
const Shape& shape); const Shape& shape);
inline std::shared_ptr<ngraph::Node> reshape(const std::shared_ptr<ngraph::Node>& node,
const Shape& shape)
{
return reshape(node, ngraph::get_default_order(node->get_shape().size()), shape);
}
/// \brief Permute axes according to specified axes_order parameter. /// \brief Permute axes according to specified axes_order parameter.
/// ///
/// \param node The node which axes we want to permute. /// \param node The node which axes we want to permute.
...@@ -51,6 +44,22 @@ namespace ngraph ...@@ -51,6 +44,22 @@ namespace ngraph
/// \return: New node with permuted axes. /// \return: New node with permuted axes.
std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node,
std::vector<std::size_t> axes_order); std::vector<std::size_t> axes_order);
/// \brief Return transposed tensor (with axes in reversed order).
///
/// \param node Input tensor we want to transpose
///
/// \return: New node with reversed dimensions.
std::shared_ptr<ngraph::Node> transpose(const std::shared_ptr<ngraph::Node>& node);
/// \brief Flatten the input tensor into a 2D matrix.
///
/// \param node The tensor to be flattened.
/// \param axis The axis dividing shape.
///
/// \return The new node being a 2D matrix representing flattened input node.
std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node,
int axis);
} // namespace util } // namespace util
} // namespace op } // namespace op
} // namespace ngraph } // namespace ngraph
...@@ -153,3 +153,5 @@ gather_nd_batch_2d_from_3d ...@@ -153,3 +153,5 @@ gather_nd_batch_2d_from_3d
gather_scalar_indices_no_axis gather_scalar_indices_no_axis
gather_scalar_indices gather_scalar_indices
gather_nd_single_indices gather_nd_single_indices
gemm
gemm_broadcast_input_C
...@@ -79,6 +79,7 @@ ...@@ -79,6 +79,7 @@
#include "ngraph/op/fused/conv_fused.hpp" #include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/depth_to_space.hpp" #include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp" #include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/group_conv.hpp" #include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/space_to_depth.hpp" #include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
...@@ -1983,6 +1984,7 @@ shared_ptr<runtime::Executable> ...@@ -1983,6 +1984,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::Erf: case OP_TYPEID::Erf:
case OP_TYPEID::Gather: case OP_TYPEID::Gather:
case OP_TYPEID::GatherND: case OP_TYPEID::GatherND:
case OP_TYPEID::Gemm:
case OP_TYPEID::GenerateMask: case OP_TYPEID::GenerateMask:
case OP_TYPEID::PRelu: case OP_TYPEID::PRelu:
case OP_TYPEID::Passthrough: case OP_TYPEID::Passthrough:
......
...@@ -64,3 +64,5 @@ gather_nd_batch_2d_from_3d ...@@ -64,3 +64,5 @@ gather_nd_batch_2d_from_3d
gather_scalar_indices_no_axis gather_scalar_indices_no_axis
gather_scalar_indices gather_scalar_indices
gather_nd_single_indices gather_nd_single_indices
gemm
gemm_broadcast_input_C
...@@ -68,6 +68,7 @@ ...@@ -68,6 +68,7 @@
#include "ngraph/op/fused/conv_fused.hpp" #include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/depth_to_space.hpp" #include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp" #include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/group_conv.hpp" #include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/prelu.hpp" #include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/space_to_depth.hpp" #include "ngraph/op/fused/space_to_depth.hpp"
...@@ -940,6 +941,16 @@ static shared_ptr<ngraph::Function> ...@@ -940,6 +941,16 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::GatherND>(args[0], args[1]); node = make_shared<op::GatherND>(args[0], args[1]);
break; break;
} }
case OP_TYPEID::Gemm:
{
auto alpha = node_js.at("alpha").get<double>();
auto beta = node_js.at("beta").get<double>();
auto transA = node_js.at("transA").get<bool>();
auto transB = node_js.at("transB").get<bool>();
node =
make_shared<op::Gemm>(args[0], args[1], args[2], alpha, beta, transA, transB);
break;
}
case OP_TYPEID::GenerateMask: case OP_TYPEID::GenerateMask:
{ {
auto output_shape = node_js.at("output_shape").get<vector<size_t>>(); auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
...@@ -1803,6 +1814,15 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1803,6 +1814,15 @@ static json write(const Node& n, bool binary_constant_data)
node["n"] = tmp->get_n(); node["n"] = tmp->get_n();
break; break;
} }
case OP_TYPEID::Gemm:
{
auto tmp = dynamic_cast<const op::Gemm*>(&n);
node["alpha"] = tmp->get_alpha();
node["beta"] = tmp->get_beta();
node["transA"] = tmp->get_transA();
node["transB"] = tmp->get_transB();
break;
}
case OP_TYPEID::GenerateMask: case OP_TYPEID::GenerateMask:
{ {
auto tmp = dynamic_cast<const op::GenerateMask*>(&n); auto tmp = dynamic_cast<const op::GenerateMask*>(&n);
......
...@@ -337,3 +337,43 @@ NGRAPH_TEST(${BACKEND_NAME}, depth_to_space) ...@@ -337,3 +337,43 @@ NGRAPH_TEST(${BACKEND_NAME}, depth_to_space)
22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f}); 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f});
test_case.run(); test_case.run();
} }
NGRAPH_TEST(${BACKEND_NAME}, gemm)
{
auto A = make_shared<op::Parameter>(element::f64, Shape{3, 6});
auto B = make_shared<op::Parameter>(element::f64, Shape{6, 4});
auto C = make_shared<op::Parameter>(element::f64, Shape{3, 4});
auto gemm_func = make_shared<op::Gemm>(A, B, C);
auto function = make_shared<Function>(NodeVector{gemm_func}, ParameterVector{A, B, C});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
// A
test_case.add_input<double>(vector<double>(18, 1));
// B
test_case.add_input<double>(vector<double>(24, 2));
// C
test_case.add_input<double>(vector<double>(12, 0));
//output
test_case.add_expected_output<double>(Shape{3, 4}, vector<double>(12, 12));
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, gemm_broadcast_input_C)
{
auto A = make_shared<op::Parameter>(element::f64, Shape{3, 6});
auto B = make_shared<op::Parameter>(element::f64, Shape{6, 4});
auto C = make_shared<op::Parameter>(element::f64, Shape{});
auto gemm_func = make_shared<op::Gemm>(A, B, C, 0.5);
auto function = make_shared<Function>(NodeVector{gemm_func}, ParameterVector{A, B, C});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
// A
test_case.add_input<double>(vector<double>(18, 1));
// B
test_case.add_input<double>(vector<double>(24, 2));
// C
test_case.add_input<double>(vector<double>{1});
//output
test_case.add_expected_output<double>(Shape{3, 4}, vector<double>(12, 7));
test_case.run();
}
...@@ -13826,3 +13826,23 @@ TEST(type_prop, group_conv_invalid_groups) ...@@ -13826,3 +13826,23 @@ TEST(type_prop, group_conv_invalid_groups)
FAIL() << "Deduced type check failed for unexpected reason"; FAIL() << "Deduced type check failed for unexpected reason";
} }
} }
TEST(type_prop, gemm)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{3, 6});
auto B = make_shared<op::Parameter>(element::f32, Shape{6, 4});
auto C = make_shared<op::Parameter>(element::f32, Shape{3, 4});
auto gemm_func = make_shared<op::Gemm>(A, B, C);
EXPECT_EQ(gemm_func->get_element_type(), element::f32);
EXPECT_EQ(gemm_func->get_shape(), (Shape{3, 4}));
}
TEST(type_prop, gemm_broadcast_input_C)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{3, 6});
auto B = make_shared<op::Parameter>(element::f32, Shape{6, 4});
auto C = make_shared<op::Parameter>(element::f32, Shape{});
auto gemm_func = make_shared<op::Gemm>(A, B, C);
EXPECT_EQ(gemm_func->get_element_type(), element::f32);
EXPECT_EQ(gemm_func->get_shape(), (Shape{3, 4}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment