Commit fa6c2a60 authored by arogowie-intel's avatar arogowie-intel Committed by Scott Cyphers

[Py] Python nGraph operations wrappers. (#821)

* Add/update Python wrappers for nGraph operations.

- NotEqual, OneHot, Power, Sqrt, Relu, Sign, Sin, Sinh, Tan, Subtract, Select, Tanh, Sum, Reduce,
Softmax, ReplaceSlice, Reverse
- Add UT for Relu, Sign, Sin, Sinh, Sqrt, Tan, Tanh,

* Add UT for cases when Cos and Sin are giving incorrect results.

* Alphabetically sorted imports.

* Small refactoring.

- Update docstrings
- Remove unnecesary auxiliary local variable.
parent 84c692cf
......@@ -38,27 +38,37 @@ from ngraph.ops import exp
from ngraph.ops import floor
from ngraph.ops import greater
from ngraph.ops import greater_eq
from ngraph.ops import log
from ngraph.ops import less
from ngraph.ops import less_eq
from ngraph.ops import log
from ngraph.ops import logical_not
from ngraph.ops import max
from ngraph.ops import maximum
from ngraph.ops import max_pool
from ngraph.ops import maximum
from ngraph.ops import min
from ngraph.ops import minimum
from ngraph.ops import multiply
from ngraph.ops import negative
from ngraph.ops import not_equal
from ngraph.ops import one_hot
from ngraph.ops import pad
from ngraph.ops import parameter
from ngraph.ops import power
from ngraph.ops import prod
from ngraph.ops import relu
from ngraph.ops import replace_slice
from ngraph.ops import reshape
from ngraph.ops import reverse
from ngraph.ops import select
from ngraph.ops import sign
from ngraph.ops import sin
from ngraph.ops import sinh
from ngraph.ops import slice
from ngraph.ops import softmax
from ngraph.ops import sqrt
from ngraph.ops import subtract
from ngraph.ops import sum
from ngraph.ops import tan
from ngraph.ops import tanh
from ngraph.runtime import runtime
......@@ -35,10 +35,13 @@ sys.setdlopenflags(flags)
from _pyngraph.op import Abs
from _pyngraph.op import Acos
from _pyngraph.op import Add
from _pyngraph.op import AllReduce
from _pyngraph.op import Asin
from _pyngraph.op import Atan
from _pyngraph.op import AvgPool
from _pyngraph.op import AvgPoolBackprop
from _pyngraph.op import BatchNorm
from _pyngraph.op import BatchNormBackprop
from _pyngraph.op import Broadcast
from _pyngraph.op import Ceiling
from _pyngraph.op import Concat
......@@ -54,28 +57,33 @@ from _pyngraph.op import Dot
from _pyngraph.op import Equal
from _pyngraph.op import Exp
from _pyngraph.op import Floor
from _pyngraph.op import FunctionCall
from _pyngraph.op import GetOutputElement
from _pyngraph.op import Greater
from _pyngraph.op import GreaterEq
from _pyngraph.op import Less
from _pyngraph.op import LessEq
from _pyngraph.op import Log
from _pyngraph.op import Max
from _pyngraph.op import Maximum
from _pyngraph.op import MaxPool
from _pyngraph.op import MaxPoolBackprop
from _pyngraph.op import Maximum
from _pyngraph.op import Min
from _pyngraph.op import Minimum
from _pyngraph.op import Multiply
from _pyngraph.op import Negative
from _pyngraph.op import NotEqual
from _pyngraph.op import Not
from _pyngraph.op import NotEqual
from _pyngraph.op import OneHot
from _pyngraph.op import Op
from _pyngraph.op import Pad
from _pyngraph.op import Parameter
from _pyngraph.op import ParameterVector
from _pyngraph.op import Power
from _pyngraph.op import Product
from _pyngraph.op import Reduce
from _pyngraph.op import Relu
from _pyngraph.op import ReluBackprop
from _pyngraph.op import ReplaceSlice
from _pyngraph.op import Reshape
from _pyngraph.op import Reverse
......@@ -84,17 +92,9 @@ from _pyngraph.op import Sign
from _pyngraph.op import Sin
from _pyngraph.op import Sinh
from _pyngraph.op import Slice
from _pyngraph.op import Softmax
from _pyngraph.op import Sqrt
from _pyngraph.op import Subtract
from _pyngraph.op import Sum
from _pyngraph.op import Tan
from _pyngraph.op import Tanh
from _pyngraph.op import Relu
from _pyngraph.op import ReluBackprop
from _pyngraph.op import Product
from _pyngraph.op import AllReduce
from _pyngraph.op import FunctionCall
from _pyngraph.op import GetOutputElement
from _pyngraph.op import BatchNorm
from _pyngraph.op import BatchNormBackprop
from _pyngraph.op import Softmax
This diff is collapsed.
......@@ -20,19 +20,40 @@ import ngraph as ng
from test.ngraph.util import run_op_numeric_data, run_op_node
@pytest.mark.xfail(reason='Results mismatch when passing created Constant node from raw data.')
@pytest.mark.parametrize('ng_api_fn, numpy_fn, range_start, range_end', [
(ng.cos, np.cos, -100., 100.),
(ng.sin, np.sin, -100., 100.),
])
def test_unary_op_array_err(ng_api_fn, numpy_fn, range_start, range_end):
np.random.seed(133391)
input_data = range_start + np.random.rand(2, 3, 4) * (range_end - range_start)
expected = numpy_fn(input_data)
result = run_op_numeric_data(input_data, ng_api_fn)
assert np.allclose(result, expected)
@pytest.mark.parametrize('ng_api_fn, numpy_fn, range_start, range_end', [
(ng.absolute, np.abs, -1, 1),
(ng.abs, np.abs, -1, 1),
(ng.acos, np.arccos, -1, 1),
(ng.asin, np.arcsin, -1, 1),
(ng.atan, np.arctan, -100, 100),
(ng.ceiling, np.ceil, -100, 100),
(ng.ceil, np.ceil, -100, 100),
(ng.cos, np.cos, -np.pi, np.pi),
(ng.cosh, np.cosh, -np.pi, np.pi),
(ng.exp, np.exp, -100, 100),
(ng.floor, np.floor, -100, 100),
(ng.log, np.log, 0, 100),
(ng.atan, np.arctan, -100., 100.),
(ng.ceiling, np.ceil, -100., 100.),
(ng.ceil, np.ceil, -100., 100.),
(ng.cos, np.cos, -np.pi * 2., np.pi * 2.),
(ng.cosh, np.cosh, -100., 100.),
(ng.exp, np.exp, -100., 100.),
(ng.floor, np.floor, -100., 100.),
(ng.log, np.log, 0, 100.),
(ng.relu, lambda x: np.maximum(0, x), -100., 100.),
(ng.sign, np.sign, -100., 100.),
(ng.sin, np.sin, -np.pi * 2., np.pi * 2.),
(ng.sinh, np.sinh, -100., 100.),
(ng.sqrt, np.sqrt, 0., 100.),
(ng.tan, np.tan, -1., 1.),
(ng.tanh, np.tanh, -100., 100.),
])
def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end):
np.random.seed(133391)
......@@ -59,6 +80,13 @@ def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end):
(ng.exp, np.exp, np.float32(1.5)),
(ng.floor, np.floor, np.float32(1.5)),
(ng.log, np.log, np.float32(1.5)),
(ng.relu, lambda x: np.maximum(0, x), np.float32(-0.125)),
(ng.sign, np.sign, np.float32(0.)),
(ng.sin, np.sin, np.float32(np.pi / 4.0)),
(ng.sinh, np.sinh, np.float32(0.)),
(ng.sqrt, np.sqrt, np.float32(3.5)),
(ng.tan, np.tan, np.float32(np.pi / 4.0)),
(ng.tanh, np.tanh, np.float32(0.1234)),
])
def test_unary_op_scalar(ng_api_fn, numpy_fn, input_data):
expected = numpy_fn(input_data)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment