Unverified Commit 72fa2015 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #3331 from NervanaSystems/etusien/gemm

[Py] Added gemm operator to Python API.
parents c70c2798 c305329e
...@@ -38,6 +38,7 @@ ngraph.ops ...@@ -38,6 +38,7 @@ ngraph.ops
exp exp
floor floor
gelu gelu
gemm
get_output_element get_output_element
greater greater
greater_eq greater_eq
......
...@@ -51,6 +51,7 @@ from ngraph.ops import equal ...@@ -51,6 +51,7 @@ from ngraph.ops import equal
from ngraph.ops import exp from ngraph.ops import exp
from ngraph.ops import floor from ngraph.ops import floor
from ngraph.ops import gelu from ngraph.ops import gelu
from ngraph.ops import gemm
from ngraph.ops import get_output_element from ngraph.ops import get_output_element
from ngraph.ops import greater from ngraph.ops import greater
from ngraph.ops import greater_eq from ngraph.ops import greater_eq
......
...@@ -75,6 +75,7 @@ from _pyngraph.op import Equal ...@@ -75,6 +75,7 @@ from _pyngraph.op import Equal
from _pyngraph.op import Exp from _pyngraph.op import Exp
from _pyngraph.op import Floor from _pyngraph.op import Floor
from _pyngraph.op import Gelu from _pyngraph.op import Gelu
from _pyngraph.op import Gemm
from _pyngraph.op import GetOutputElement from _pyngraph.op import GetOutputElement
from _pyngraph.op import Greater from _pyngraph.op import Greater
from _pyngraph.op import GreaterEq from _pyngraph.op import GreaterEq
......
...@@ -23,8 +23,8 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio ...@@ -23,8 +23,8 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \ from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \ BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \
Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Elu, Equal, Exp, Floor, \ Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Elu, Equal, Exp, Floor, \
Gelu, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \ Gelu, Gemm, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, \
Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, \ MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, \
Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, \ Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, \
Sqrt, Subtract, Sum, Tan, Tanh, TopK Sqrt, Subtract, Sum, Tan, Tanh, TopK
...@@ -520,6 +520,46 @@ def broadcast_to(node, new_shape, axis=None, name=None): ...@@ -520,6 +520,46 @@ def broadcast_to(node, new_shape, axis=None, name=None):
return Broadcast(node, Shape(new_shape), get_broadcast_axes(new_shape, node.shape, axis)) return Broadcast(node, Shape(new_shape), get_broadcast_axes(new_shape, node.shape, axis))
@nameable_op
def gemm(A, # type: Node
B, # type: Node
C, # type: Node
alpha, # type: ScalarData
beta, # type: ScalarData
transA, # type: bool
transB, # type: bool
name=None, # type: str
):
# type: (...) -> Node
r"""Perform General matrix-matrix multiplication on input tensors A, B and C.
Computes:
.. math:: Y = alpha\cdot A'\cdot B' + beta\cdot C
:code:`A'` is the transpose of matrix :code:`A` with shape (M, K),
if :code:`transA` is :code:`True`, otherwise :code:`A` with shape (K, N).
:code:`B'` is the transpose of matrix :code:`B` with shape (K, N),
if :code:`transB` is :code:`True`, otherwise :code:`B` with shape (N, K).
:code:`C`: Matrix broadcastable to shape (M, N).
:code:`Y`: Matrix with shape (M, N).
:param A: The node with input tensor A.
:param B: The node with input tensor B.
:param C: The node with input tensor C.
:param alpha: Scalar multiplier for the product of input tensors A * B.
:param beta: Scalar multiplier for input tensor C.
:param transA: Whether A should be transposed. Boolean value.
:param transB: Whether B should be transposed. Boolean value.
:param name: Optional name for the output node.
:return: Return node with tensor of shape (M, N).
"""
return Gemm(A, B, C, alpha, beta, transA, transB)
@nameable_op @nameable_op
def convert(node, new_type, name=None): # type: (Node, NumericType, str) -> Node def convert(node, new_type, name=None): # type: (Node, NumericType, str) -> Node
"""Return node which casts input node values to specified type.""" """Return node which casts input node values to specified type."""
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/gemm.hpp"
#include "pyngraph/ops/fused/gemm.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Gemm(py::module m)
{
py::class_<ngraph::op::Gemm, std::shared_ptr<ngraph::op::Gemm>, ngraph::op::Op> gemm(m, "Gemm");
gemm.doc() = "ngraph.impl.op.Gemm wraps ngraph::op::Gemm";
gemm.def(py::init<const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
double&,
double&,
bool&,
bool&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Gemm(py::module m);
...@@ -55,6 +55,7 @@ void regmodule_pyngraph_op(py::module m_op) ...@@ -55,6 +55,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Exp(m_op); regclass_pyngraph_op_Exp(m_op);
regclass_pyngraph_op_Floor(m_op); regclass_pyngraph_op_Floor(m_op);
regclass_pyngraph_op_Gelu(m_op); regclass_pyngraph_op_Gelu(m_op);
regclass_pyngraph_op_Gemm(m_op);
regclass_pyngraph_op_GetOutputElement(m_op); regclass_pyngraph_op_GetOutputElement(m_op);
regclass_pyngraph_op_Greater(m_op); regclass_pyngraph_op_Greater(m_op);
regclass_pyngraph_op_GreaterEq(m_op); regclass_pyngraph_op_GreaterEq(m_op);
......
...@@ -45,6 +45,7 @@ ...@@ -45,6 +45,7 @@
#include "pyngraph/ops/fused/clamp.hpp" #include "pyngraph/ops/fused/clamp.hpp"
#include "pyngraph/ops/fused/elu.hpp" #include "pyngraph/ops/fused/elu.hpp"
#include "pyngraph/ops/fused/gelu.hpp" #include "pyngraph/ops/fused/gelu.hpp"
#include "pyngraph/ops/fused/gemm.hpp"
#include "pyngraph/ops/get_output_element.hpp" #include "pyngraph/ops/get_output_element.hpp"
#include "pyngraph/ops/greater.hpp" #include "pyngraph/ops/greater.hpp"
#include "pyngraph/ops/greater_eq.hpp" #include "pyngraph/ops/greater_eq.hpp"
......
...@@ -185,6 +185,7 @@ sources = [ ...@@ -185,6 +185,7 @@ sources = [
'pyngraph/ops/exp.cpp', 'pyngraph/ops/exp.cpp',
'pyngraph/ops/floor.cpp', 'pyngraph/ops/floor.cpp',
'pyngraph/ops/fused/gelu.cpp', 'pyngraph/ops/fused/gelu.cpp',
'pyngraph/ops/fused/gemm.cpp',
'pyngraph/ops/greater.cpp', 'pyngraph/ops/greater.cpp',
'pyngraph/ops/greater_eq.cpp', 'pyngraph/ops/greater_eq.cpp',
'pyngraph/ops/less.cpp', 'pyngraph/ops/less.cpp',
......
...@@ -69,6 +69,43 @@ def test_elu_operator_with_scalar(): ...@@ -69,6 +69,43 @@ def test_elu_operator_with_scalar():
assert np.allclose(result, expected) assert np.allclose(result, expected)
def test_gemm_operator():
runtime = get_runtime()
shape_a = [3, 2]
shape_b = [3, 2]
shape_c = [2, 1]
value_a = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
value_b = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
value_c = np.array([[13], [14]], dtype=np.float32)
parameter_a = ng.parameter(shape_a, name='A', dtype=np.float32)
parameter_b = ng.parameter(shape_b, name='B', dtype=np.float32)
parameter_c = ng.parameter(shape_c, name='C', dtype=np.float32)
alpha_value = np.float32(3)
beta_value = np.float32(3)
transA = True
transB = False
model = ng.gemm(parameter_a, parameter_b, parameter_c, alpha_value, beta_value, transA, transB)
computation = runtime.computation(model, parameter_a, parameter_b, parameter_c)
result = computation(value_a, value_b, value_c)
# expected = value_alpha * value_a' * value_b + value_beta * value_c
value_a = value_a.transpose()
a_mul_a = np.multiply(alpha_value, value_a)
aa_mul_b = np.dot(a_mul_a, value_b)
b_mul_c = np.dot(beta_value, value_c)
expected = np.add(aa_mul_b, b_mul_c)
assert np.allclose(result, expected)
def test_gelu_operator_with_parameters(): def test_gelu_operator_with_parameters():
runtime = get_runtime() runtime = get_runtime()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment