Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
f86c0557
Commit
f86c0557
authored
Jul 30, 2019
by
Ewa21
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Py] Added gemm operator to Python API.
parent
c1220108
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
146 additions
and
1 deletion
+146
-1
ngraph.ops.rst
doc/sphinx/source/python_api/_autosummary/ngraph.ops.rst
+1
-0
__init__.py
python/ngraph/__init__.py
+1
-0
__init__.py
python/ngraph/impl/op/__init__.py
+1
-0
ops.py
python/ngraph/ops.py
+45
-1
gemm.cpp
python/pyngraph/ops/fused/gemm.cpp
+36
-0
gemm.hpp
python/pyngraph/ops/fused/gemm.hpp
+23
-0
regmodule_pyngraph_op.cpp
python/pyngraph/ops/regmodule_pyngraph_op.cpp
+1
-0
regmodule_pyngraph_op.hpp
python/pyngraph/ops/regmodule_pyngraph_op.hpp
+1
-0
setup.py
python/setup.py
+1
-0
test_ops_fused.py
python/test/ngraph/test_ops_fused.py
+36
-0
No files found.
doc/sphinx/source/python_api/_autosummary/ngraph.ops.rst
View file @
f86c0557
...
...
@@ -36,6 +36,7 @@ ngraph.ops
equal
exp
floor
gemm
get_output_element
greater
greater_eq
...
...
python/ngraph/__init__.py
View file @
f86c0557
...
...
@@ -49,6 +49,7 @@ from ngraph.ops import elu
from
ngraph.ops
import
equal
from
ngraph.ops
import
exp
from
ngraph.ops
import
floor
from
ngraph.ops
import
gemm
from
ngraph.ops
import
get_output_element
from
ngraph.ops
import
greater
from
ngraph.ops
import
greater_eq
...
...
python/ngraph/impl/op/__init__.py
View file @
f86c0557
...
...
@@ -73,6 +73,7 @@ from _pyngraph.op import Elu
from
_pyngraph.op
import
Equal
from
_pyngraph.op
import
Exp
from
_pyngraph.op
import
Floor
from
_pyngraph.op
import
Gemm
from
_pyngraph.op
import
GetOutputElement
from
_pyngraph.op
import
Greater
from
_pyngraph.op
import
GreaterEq
...
...
python/ngraph/ops.py
View file @
f86c0557
...
...
@@ -23,7 +23,7 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
from
ngraph.impl.op
import
Abs
,
Acos
,
Add
,
And
,
Asin
,
ArgMax
,
ArgMin
,
Atan
,
AvgPool
,
\
BatchNormTraining
,
BatchNormInference
,
Broadcast
,
Ceiling
,
Concat
,
Constant
,
Convert
,
\
Convolution
,
ConvolutionBackpropData
,
Cos
,
Cosh
,
Divide
,
Dot
,
Elu
,
Equal
,
Exp
,
Floor
,
\
GetOutputElement
,
Greater
,
GreaterEq
,
Less
,
LessEq
,
Log
,
LRN
,
Max
,
Maximum
,
MaxPool
,
\
Ge
mm
,
Ge
tOutputElement
,
Greater
,
GreaterEq
,
Less
,
LessEq
,
Log
,
LRN
,
Max
,
Maximum
,
MaxPool
,
\
Min
,
Minimum
,
Multiply
,
Negative
,
Not
,
NotEqual
,
OneHot
,
Or
,
Pad
,
Parameter
,
Product
,
\
Power
,
Relu
,
ReplaceSlice
,
Reshape
,
Reverse
,
Select
,
Sign
,
Sin
,
Sinh
,
Slice
,
Softmax
,
\
Sqrt
,
Subtract
,
Sum
,
Tan
,
Tanh
,
TopK
...
...
@@ -520,6 +520,50 @@ def broadcast_to(node, new_shape, axis=None, name=None):
return
Broadcast
(
node
,
Shape
(
new_shape
),
get_broadcast_axes
(
new_shape
,
node
.
shape
,
axis
))
@nameable_op
def gemm(A,                      # type: Node
         B,                      # type: Node
         C,                      # type: Node
         alpha,                  # type: ScalarData
         beta,                   # type: ScalarData
         transA,                 # type: bool
         transB,                 # type: bool
         name=None,              # type: str
         ):
    # type: (...) -> Node
    r"""Perform General matrix-matrix multiplication on input tensors A, B and C.

    Computes:

    .. math:: Y = alpha\cdot A'\cdot B' + beta\cdot C

    :code:`A'` has shape (M, K): it is the transpose of :code:`A` if
    :code:`transA` is :code:`True`, otherwise :code:`A` itself.

    :code:`B'` has shape (K, N): it is the transpose of :code:`B` if
    :code:`transB` is :code:`True`, otherwise :code:`B` itself.

    :code:`C`: Matrix broadcastable to shape (M, N).

    :code:`Y`: Output matrix with shape (M, N).

    For more information refer to:
    `Low-memory GEMM-based convolution algorithms for deep neural networks
    <https://arxiv.org/pdf/1709.03395.pdf>`_

    :param A: The node with input tensor A.
    :param B: The node with input tensor B.
    :param C: The node with input tensor C.
    :param alpha: Scalar multiplier for the product of input tensors A * B.
    :param beta: Scalar multiplier for input tensor C.
    :param transA: Whether A should be transposed. Boolean value.
    :param transB: Whether B should be transposed. Boolean value.
    :param name: Optional name for the output node.
    :return: Return node with tensor of shape (M, N).
    """
    return Gemm(A, B, C, alpha, beta, transA, transB)
@nameable_op
def
convert
(
node
,
new_type
,
name
=
None
):
# type: (Node, NumericType, str) -> Node
"""Return node which casts input node values to specified type."""
...
...
python/pyngraph/ops/fused/gemm.cpp
0 → 100644
View file @
f86c0557
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/gemm.hpp"
#include "pyngraph/ops/fused/gemm.hpp"
namespace py = pybind11;

// Register ngraph::op::Gemm with the pybind11 module so it is available
// from Python as ngraph.impl.op.Gemm.
void regclass_pyngraph_op_Gemm(py::module m)
{
    // Expose Gemm as a subclass of ngraph::op::Op, held by shared_ptr
    // (the holder type used throughout the pyngraph bindings).
    py::class_<ngraph::op::Gemm, std::shared_ptr<ngraph::op::Gemm>, ngraph::op::Op> gemm(
        m, "Gemm");
    gemm.doc() = "ngraph.impl.op.Gemm wraps ngraph::op::Gemm";
    // Constructor: Gemm(A, B, C, alpha, beta, transA, transB).
    // NOTE(review): the scalar/flag parameters are bound as non-const lvalue
    // references (double&, bool&) rather than by value — presumably matching
    // the C++ Gemm constructor signature; confirm against ngraph/op/fused/gemm.hpp.
    gemm.def(py::init<const std::shared_ptr<ngraph::Node>&,
                      const std::shared_ptr<ngraph::Node>&,
                      const std::shared_ptr<ngraph::Node>&,
                      double&,
                      double&,
                      bool&,
                      bool&>());
}
python/pyngraph/ops/fused/gemm.hpp
0 → 100644
View file @
f86c0557
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace
py
=
pybind11
;
void
regclass_pyngraph_op_Gemm
(
py
::
module
m
);
python/pyngraph/ops/regmodule_pyngraph_op.cpp
View file @
f86c0557
...
...
@@ -53,6 +53,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Equal
(
m_op
);
regclass_pyngraph_op_Exp
(
m_op
);
regclass_pyngraph_op_Floor
(
m_op
);
regclass_pyngraph_op_Gemm
(
m_op
);
regclass_pyngraph_op_GetOutputElement
(
m_op
);
regclass_pyngraph_op_Greater
(
m_op
);
regclass_pyngraph_op_GreaterEq
(
m_op
);
...
...
python/pyngraph/ops/regmodule_pyngraph_op.hpp
View file @
f86c0557
...
...
@@ -43,6 +43,7 @@
#include "pyngraph/ops/equal.hpp"
#include "pyngraph/ops/exp.hpp"
#include "pyngraph/ops/floor.hpp"
#include "pyngraph/ops/fused/gemm.hpp"
#include "pyngraph/ops/get_output_element.hpp"
#include "pyngraph/ops/greater.hpp"
#include "pyngraph/ops/greater_eq.hpp"
...
...
python/setup.py
View file @
f86c0557
...
...
@@ -183,6 +183,7 @@ sources = [
'pyngraph/ops/equal.cpp'
,
'pyngraph/ops/exp.cpp'
,
'pyngraph/ops/floor.cpp'
,
'pyngraph/ops/fused/gemm.cpp'
,
'pyngraph/ops/greater.cpp'
,
'pyngraph/ops/greater_eq.cpp'
,
'pyngraph/ops/less.cpp'
,
...
...
python/test/ngraph/test_ops_fused.py
View file @
f86c0557
...
...
@@ -67,3 +67,39 @@ def test_elu_operator_with_scalar():
result
=
computation
(
data_value
)
expected
=
np
.
array
([[
-
2.9797862
,
1.
],
[
-
2.5939941
,
3.
]],
dtype
=
np
.
float32
)
assert
np
.
allclose
(
result
,
expected
)
def test_gemm_operator():
    """Compare the gemm op against a NumPy reference: Y = alpha * A' . B + beta * C."""
    runtime = get_runtime()

    shape_a = [3, 2]
    shape_b = [3, 2]
    shape_c = [2, 1]

    value_a = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
    value_b = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
    value_c = np.array([[13], [14]], dtype=np.float32)

    parameter_a = ng.parameter(shape_a, name='A', dtype=np.float32)
    parameter_b = ng.parameter(shape_b, name='B', dtype=np.float32)
    parameter_c = ng.parameter(shape_c, name='C', dtype=np.float32)

    alpha_value = np.float32(3)
    beta_value = np.float32(3)
    transA = True
    transB = False

    model = ng.gemm(parameter_a, parameter_b, parameter_c,
                    alpha_value, beta_value, transA, transB)
    computation = runtime.computation(model, parameter_a, parameter_b, parameter_c)
    result = computation(value_a, value_b, value_c)

    # Reference computation: expected = alpha * A' . B + beta * C
    # (transA=True transposes A; transB=False leaves B as-is).
    a_transposed = value_a.transpose()
    alpha_times_a = np.multiply(alpha_value, a_transposed)
    product_ab = np.dot(alpha_times_a, value_b)
    beta_times_c = np.dot(beta_value, value_c)
    expected = np.add(product_ab, beta_times_c)

    assert np.allclose(result, expected)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment