Commit 9ffb5145 authored by tsocha's avatar tsocha Committed by Nick Korovaiko

[Py] Enable ngraph-cpp ops in Python API (#820)

* Enable BatchNorm op

* Enable function call op

* Enable get output element op
parent eec19220
...@@ -22,6 +22,7 @@ from ngraph.ops import add ...@@ -22,6 +22,7 @@ from ngraph.ops import add
from ngraph.ops import asin from ngraph.ops import asin
from ngraph.ops import atan from ngraph.ops import atan
from ngraph.ops import avg_pool from ngraph.ops import avg_pool
from ngraph.ops import batch_norm
from ngraph.ops import broadcast from ngraph.ops import broadcast
from ngraph.ops import ceiling from ngraph.ops import ceiling
from ngraph.ops import ceiling as ceil from ngraph.ops import ceiling as ceil
...@@ -35,7 +36,9 @@ from ngraph.ops import divide ...@@ -35,7 +36,9 @@ from ngraph.ops import divide
from ngraph.ops import dot from ngraph.ops import dot
from ngraph.ops import equal from ngraph.ops import equal
from ngraph.ops import exp from ngraph.ops import exp
from ngraph.ops import function_call
from ngraph.ops import floor from ngraph.ops import floor
from ngraph.ops import get_output_element
from ngraph.ops import greater from ngraph.ops import greater
from ngraph.ops import greater_eq from ngraph.ops import greater_eq
from ngraph.ops import less from ngraph.ops import less
......
...@@ -20,11 +20,12 @@ import numpy as np ...@@ -20,11 +20,12 @@ import numpy as np
from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, NodeVector, \ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, NodeVector, \
Shape, Strides Shape, Strides
from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, Broadcast, Ceiling, Concat, \ from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, BatchNorm, Broadcast, Ceiling,\
Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq,\ Concat, Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, \
Less, LessEq, Log, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, \ FunctionCall, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, Max, Maximum, MaxPool, \
OneHot, Pad, Parameter, Product, Power, Relu, ReplaceSlice, Reshape, Reverse, Select, \ Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Pad, Parameter, Product, Power, Relu, \
Sign, Sin, Sinh, Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, Sqrt, Subtract, Sum, \
Tan, Tanh
from typing import Iterable, List from typing import Iterable, List
...@@ -761,3 +762,33 @@ def reverse(node, reversed_axes, name=None): # type: (Node, List[int], str) -> ...@@ -761,3 +762,33 @@ def reverse(node, reversed_axes, name=None): # type: (Node, List[int], str) ->
:return: The new node with reversed axes. :return: The new node with reversed axes.
""" """
return Reverse(node, AxisSet(reversed_axes)) return Reverse(node, AxisSet(reversed_axes))
@nameable_op
def batch_norm(eps,  # type: float
               gamma,  # type: Node
               beta,  # type: Node
               data,  # type: Node
               mean=None,  # type: Node
               variance=None,  # type: Node
               training=False,  # type: bool
               name=None,  # type: str
               ):
    # type: (...) -> Node
    """Return a batch normalization node.

    When neither `mean` nor `variance` is supplied, the 4-argument BatchNorm
    constructor is used (statistics computed by the op itself). When both are
    supplied, the 7-argument constructor is used with the given statistics.

    :param eps: epsilon added to the variance for numerical stability.
    :param gamma: scale node.
    :param beta: offset (shift) node.
    :param data: input data node to normalize.
    :param mean: optional pre-computed mean node.
    :param variance: optional pre-computed variance node.
    :param training: whether the op runs in training mode (only used with
                     explicit mean/variance).
    :param name: optional name for the output node.
    :return: the new BatchNorm node.
    :raises ValueError: if only one of `mean`/`variance` is provided.
    """
    if mean is None and variance is None:
        return BatchNorm(eps, gamma, beta, data)
    # Guard against passing None into the 7-argument C++ constructor, which
    # would otherwise surface as an opaque pybind11 type error.
    if mean is None or variance is None:
        raise ValueError('batch_norm requires either both mean and variance '
                         'or neither of them.')
    return BatchNorm(eps, gamma, beta, data, mean, variance, training)
@nameable_op
def function_call(function_to_call, args):  # type: (Node, NodeVector) -> Node
    """Produce a FunctionCall node invoking `function_to_call` on `args`.

    :param function_to_call: the nGraph Function to invoke.
    :param args: vector of nodes supplying the function's arguments.
    :return: node representing the result of the call.
    """
    call_node = FunctionCall(function_to_call, args)
    return call_node
@nameable_op
def get_output_element(data, index):  # type: (Node, int) -> Node
    """Return the `n`th element of the input tuple.

    :param data: node producing a tuple (multi-output) value.
    :param index: zero-based position of the element to extract.
    :return: node yielding the selected output element.
    """
    selected = GetOutputElement(data, index)
    return selected
...@@ -33,6 +33,14 @@ void regclass_pyngraph_op_BatchNorm(py::module m) ...@@ -33,6 +33,14 @@ void regclass_pyngraph_op_BatchNorm(py::module m)
const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&>()); const std::shared_ptr<ngraph::Node>&>());
batch_norm.def(py::init<double,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool&>());
} }
void regclass_pyngraph_op_BatchNormBackprop(py::module m) void regclass_pyngraph_op_BatchNormBackprop(py::module m)
......
...@@ -19,6 +19,7 @@ import json ...@@ -19,6 +19,7 @@ import json
import ngraph as ng import ngraph as ng
from test.ngraph.util import get_runtime, run_op_node from test.ngraph.util import get_runtime, run_op_node
from ngraph.impl import Function, NodeVector
@pytest.mark.parametrize('dtype', [np.float32, np.float64, @pytest.mark.parametrize('dtype', [np.float32, np.float64,
...@@ -48,6 +49,26 @@ def test_simple_computation_on_ndarrays(dtype): ...@@ -48,6 +49,26 @@ def test_simple_computation_on_ndarrays(dtype):
assert np.allclose(result, np.array([[630, 704], [782, 864]], dtype=dtype)) assert np.allclose(result, np.array([[630, 704], [782, 864]], dtype=dtype))
def test_function_call():
    """Check that ng.function_call evaluates (A + B) * C elementwise."""
    backend = get_runtime()
    dtype = int
    shape = [2, 2]

    # Build three named 2x2 integer parameters.
    names = ['A', 'B', 'C']
    params = [ng.parameter(shape, dtype=dtype, name=n) for n in names]
    param_a, param_b, param_c = params

    # Wrap (A + B) * C in a Function and invoke it via function_call.
    expression = (param_a + param_b) * param_c
    addmul_fn = Function(NodeVector([expression]), params, 'addmul')
    call_node = ng.function_call(addmul_fn, NodeVector(params))
    computation = backend.computation(call_node, param_a, param_b, param_c)

    input_a = np.array([[1, 2], [3, 4]], dtype=dtype)
    input_b = np.array([[5, 6], [7, 8]], dtype=dtype)
    input_c = np.array([[9, 10], [11, 12]], dtype=dtype)
    output = computation(input_a, input_b, input_c)

    expected = np.array([[54, 80], [110, 144]], dtype=dtype)
    assert np.allclose(output, expected)
def test_serialization(): def test_serialization():
dtype = np.float32 dtype = np.float32
manager_name = pytest.config.getoption('backend', default='CPU') manager_name = pytest.config.getoption('backend', default='CPU')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment