Commit 9ffb5145 authored by tsocha's avatar tsocha Committed by Nick Korovaiko

[Py] Enable ngraph-cpp ops in Python API (#820)

* Enable BatchNorm op

* Enable function call op

* Enable get output element op
parent eec19220
......@@ -22,6 +22,7 @@ from ngraph.ops import add
from ngraph.ops import asin
from ngraph.ops import atan
from ngraph.ops import avg_pool
from ngraph.ops import batch_norm
from ngraph.ops import broadcast
from ngraph.ops import ceiling
from ngraph.ops import ceiling as ceil
......@@ -35,7 +36,9 @@ from ngraph.ops import divide
from ngraph.ops import dot
from ngraph.ops import equal
from ngraph.ops import exp
from ngraph.ops import function_call
from ngraph.ops import floor
from ngraph.ops import get_output_element
from ngraph.ops import greater
from ngraph.ops import greater_eq
from ngraph.ops import less
......
......@@ -20,11 +20,12 @@ import numpy as np
from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, NodeVector, \
Shape, Strides
from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, Broadcast, Ceiling, Concat, \
Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq,\
Less, LessEq, Log, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, \
OneHot, Pad, Parameter, Product, Power, Relu, ReplaceSlice, Reshape, Reverse, Select, \
Sign, Sin, Sinh, Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh
from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, BatchNorm, Broadcast, Ceiling,\
Concat, Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, \
FunctionCall, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, Max, Maximum, MaxPool, \
Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Pad, Parameter, Product, Power, Relu, \
ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, Sqrt, Subtract, Sum, \
Tan, Tanh
from typing import Iterable, List
......@@ -761,3 +762,33 @@ def reverse(node, reversed_axes, name=None): # type: (Node, List[int], str) ->
:return: The new node with reversed axes.
"""
return Reverse(node, AxisSet(reversed_axes))
@nameable_op
def batch_norm(eps, # type: float
gamma, # type: Node
beta, # type: Node
data, # type: Node
mean=None, # type: Node
variance=None, # type: Node
training=False, # type: bool
name=None, # type: str
):
# type: (...) -> Node
"""Return batch normalization node."""
if mean is None and variance is None:
return BatchNorm(eps, gamma, beta, data)
else:
return BatchNorm(eps, gamma, beta, data, mean, variance, training)
@nameable_op
def function_call(function_to_call, args): # type: (Node, NodeVector) -> Node
"""Return Function call op."""
return FunctionCall(function_to_call, args)
@nameable_op
def get_output_element(data, index): # type: (Node, int) -> Node
"""Return the `n`th element of the input tuple."""
return GetOutputElement(data, index)
......@@ -33,6 +33,14 @@ void regclass_pyngraph_op_BatchNorm(py::module m)
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&>());
batch_norm.def(py::init<double,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool&>());
}
void regclass_pyngraph_op_BatchNormBackprop(py::module m)
......
......@@ -19,6 +19,7 @@ import json
import ngraph as ng
from test.ngraph.util import get_runtime, run_op_node
from ngraph.impl import Function, NodeVector
@pytest.mark.parametrize('dtype', [np.float32, np.float64,
......@@ -48,6 +49,26 @@ def test_simple_computation_on_ndarrays(dtype):
assert np.allclose(result, np.array([[630, 704], [782, 864]], dtype=dtype))
def test_function_call():
runtime = get_runtime()
dtype = int
shape = [2, 2]
parameter_a = ng.parameter(shape, dtype=dtype, name='A')
parameter_b = ng.parameter(shape, dtype=dtype, name='B')
parameter_c = ng.parameter(shape, dtype=dtype, name='C')
parameter_list = [parameter_a, parameter_b, parameter_c]
ops = ((parameter_a + parameter_b) * parameter_c)
func = Function(NodeVector([ops]), parameter_list, 'addmul')
fc = ng.function_call(func, NodeVector(parameter_list))
computation = runtime.computation(fc, parameter_a, parameter_b, parameter_c)
value_a = np.array([[1, 2], [3, 4]], dtype=dtype)
value_b = np.array([[5, 6], [7, 8]], dtype=dtype)
value_c = np.array([[9, 10], [11, 12]], dtype=dtype)
result = computation(value_a, value_b, value_c)
assert np.allclose(result, np.array([[54, 80], [110, 144]], dtype=dtype))
def test_serialization():
dtype = np.float32
manager_name = pytest.config.getoption('backend', default='CPU')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment