Unverified Commit 72fa2015 authored by Scott Cyphers, committed by GitHub

Merge pull request #3331 from NervanaSystems/etusien/gemm

[Py] Added gemm operator to Python API.
parents c70c2798 c305329e
@@ -38,6 +38,7 @@ ngraph.ops
exp
floor
gelu
gemm
get_output_element
greater
greater_eq
......
@@ -51,6 +51,7 @@ from ngraph.ops import equal
from ngraph.ops import exp
from ngraph.ops import floor
from ngraph.ops import gelu
from ngraph.ops import gemm
from ngraph.ops import get_output_element
from ngraph.ops import greater
from ngraph.ops import greater_eq
......
@@ -75,6 +75,7 @@ from _pyngraph.op import Equal
from _pyngraph.op import Exp
from _pyngraph.op import Floor
from _pyngraph.op import Gelu
from _pyngraph.op import Gemm
from _pyngraph.op import GetOutputElement
from _pyngraph.op import Greater
from _pyngraph.op import GreaterEq
......
@@ -23,8 +23,8 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \
Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Elu, Equal, Exp, Floor, \
Gelu, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \
Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, \
Gelu, Gemm, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, \
MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, \
Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, \
Sqrt, Subtract, Sum, Tan, Tanh, TopK
@@ -520,6 +520,46 @@ def broadcast_to(node, new_shape, axis=None, name=None):
    return Broadcast(node, Shape(new_shape), get_broadcast_axes(new_shape, node.shape, axis))
@nameable_op
def gemm(A,  # type: Node
         B,  # type: Node
         C,  # type: Node
         alpha,  # type: ScalarData
         beta,  # type: ScalarData
         transA,  # type: bool
         transB,  # type: bool
         name=None,  # type: str
         ):
    # type: (...) -> Node
    r"""Perform general matrix-matrix multiplication (GEMM) on input tensors A, B and C.

    Computes:

    .. math:: Y = \alpha \cdot A' \cdot B' + \beta \cdot C

    :code:`A'` is the transpose of matrix :code:`A` if :code:`transA` is :code:`True`,
    otherwise :code:`A` itself; :code:`A'` has shape (M, K).

    :code:`B'` is the transpose of matrix :code:`B` if :code:`transB` is :code:`True`,
    otherwise :code:`B` itself; :code:`B'` has shape (K, N).

    :code:`C`: Matrix broadcastable to shape (M, N).

    :code:`Y`: Output matrix with shape (M, N).

    :param A: The node with input tensor A.
    :param B: The node with input tensor B.
    :param C: The node with input tensor C.
    :param alpha: Scalar multiplier for the product :code:`A' * B'`.
    :param beta: Scalar multiplier for input tensor C.
    :param transA: Whether A should be transposed. Boolean value.
    :param transB: Whether B should be transposed. Boolean value.
    :param name: Optional name for the output node.
    :return: Return node with tensor of shape (M, N).
    """
    return Gemm(A, B, C, alpha, beta, transA, transB)
@nameable_op
def convert(node, new_type, name=None): # type: (Node, NumericType, str) -> Node
"""Return node which casts input node values to specified type."""
......
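For reference, the computation expressed by the new gemm wrapper above can be reproduced in plain NumPy. This is a minimal sketch, not part of the nGraph API; the function and argument names are illustrative:

import numpy as np

def gemm_reference(A, B, C, alpha, beta, transA, transB):
    # Y = alpha * A' @ B' + beta * C, with C broadcast to shape (M, N)
    A_prime = A.T if transA else A  # shape (M, K)
    B_prime = B.T if transB else B  # shape (K, N)
    return alpha * (A_prime @ B_prime) + beta * C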
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/gemm.hpp"
#include "pyngraph/ops/fused/gemm.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Gemm(py::module m)
{
    py::class_<ngraph::op::Gemm, std::shared_ptr<ngraph::op::Gemm>, ngraph::op::Op> gemm(m, "Gemm");
    gemm.doc() = "ngraph.impl.op.Gemm wraps ngraph::op::Gemm";
    gemm.def(py::init<const std::shared_ptr<ngraph::Node>&,
                      const std::shared_ptr<ngraph::Node>&,
                      const std::shared_ptr<ngraph::Node>&,
                      double&,
                      double&,
                      bool&,
                      bool&>());
}
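Once this binding is registered, the raw class is reachable from Python as ngraph.impl.op.Gemm and can be constructed directly, mirroring the py::init signature above. A minimal sketch, assuming float32 Parameter nodes for the three inputs (shapes chosen for illustration):

from ngraph.impl import Shape, Type
from ngraph.impl.op import Gemm, Parameter

A = Parameter(Type.f32, Shape([2, 3]))
B = Parameter(Type.f32, Shape([3, 4]))
C = Parameter(Type.f32, Shape([2, 4]))

# Arguments follow the bound constructor: (A, B, C, alpha, beta, transA, transB)
node = Gemm(A, B, C, 1.0, 1.0, False, False)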
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Gemm(py::module m);
@@ -55,6 +55,7 @@ void regmodule_pyngraph_op(py::module m_op)
    regclass_pyngraph_op_Exp(m_op);
    regclass_pyngraph_op_Floor(m_op);
    regclass_pyngraph_op_Gelu(m_op);
    regclass_pyngraph_op_Gemm(m_op);
    regclass_pyngraph_op_GetOutputElement(m_op);
    regclass_pyngraph_op_Greater(m_op);
    regclass_pyngraph_op_GreaterEq(m_op);
......
@@ -45,6 +45,7 @@
#include "pyngraph/ops/fused/clamp.hpp"
#include "pyngraph/ops/fused/elu.hpp"
#include "pyngraph/ops/fused/gelu.hpp"
#include "pyngraph/ops/fused/gemm.hpp"
#include "pyngraph/ops/get_output_element.hpp"
#include "pyngraph/ops/greater.hpp"
#include "pyngraph/ops/greater_eq.hpp"
......
@@ -185,6 +185,7 @@ sources = [
    'pyngraph/ops/exp.cpp',
    'pyngraph/ops/floor.cpp',
    'pyngraph/ops/fused/gelu.cpp',
    'pyngraph/ops/fused/gemm.cpp',
    'pyngraph/ops/greater.cpp',
    'pyngraph/ops/greater_eq.cpp',
    'pyngraph/ops/less.cpp',
......
@@ -69,6 +69,43 @@ def test_elu_operator_with_scalar():
    assert np.allclose(result, expected)
def test_gemm_operator():
    runtime = get_runtime()

    shape_a = [3, 2]
    shape_b = [3, 2]
    shape_c = [2, 1]

    value_a = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
    value_b = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
    value_c = np.array([[13], [14]], dtype=np.float32)

    parameter_a = ng.parameter(shape_a, name='A', dtype=np.float32)
    parameter_b = ng.parameter(shape_b, name='B', dtype=np.float32)
    parameter_c = ng.parameter(shape_c, name='C', dtype=np.float32)

    alpha_value = np.float32(3)
    beta_value = np.float32(3)
    transA = True
    transB = False

    model = ng.gemm(parameter_a, parameter_b, parameter_c, alpha_value, beta_value, transA, transB)
    computation = runtime.computation(model, parameter_a, parameter_b, parameter_c)
    result = computation(value_a, value_b, value_c)

    # expected = alpha_value * value_a.T @ value_b + beta_value * value_c
    value_a = value_a.transpose()  # transA is True
    alpha_mul_a = np.multiply(alpha_value, value_a)
    alpha_a_dot_b = np.dot(alpha_mul_a, value_b)
    beta_mul_c = np.multiply(beta_value, value_c)
    expected = np.add(alpha_a_dot_b, beta_mul_c)
    assert np.allclose(result, expected)
def test_gelu_operator_with_parameters():
    runtime = get_runtime()
......