Commit fa6c2a60 authored by arogowie-intel's avatar arogowie-intel Committed by Scott Cyphers

[Py] Python nGraph operations wrappers. (#821)

* Add/update Python wrappers for nGraph operations.

- NotEqual, OneHot, Power, Sqrt, Relu, Sign, Sin, Sinh, Tan, Subtract, Select, Tanh, Sum, Reduce,
Softmax, ReplaceSlice, Reverse
- Add UT for Relu, Sign, Sin, Sinh, Sqrt, Tan, Tanh,

* Add UT for cases when Cos and Sin are giving incorrect results.

* Alphabetically sorted imports.

* Small refactoring.

- Update docstrings
- Remove unnecesary auxiliary local variable.
parent 84c692cf
...@@ -38,27 +38,37 @@ from ngraph.ops import exp ...@@ -38,27 +38,37 @@ from ngraph.ops import exp
from ngraph.ops import floor from ngraph.ops import floor
from ngraph.ops import greater from ngraph.ops import greater
from ngraph.ops import greater_eq from ngraph.ops import greater_eq
from ngraph.ops import log
from ngraph.ops import less from ngraph.ops import less
from ngraph.ops import less_eq from ngraph.ops import less_eq
from ngraph.ops import log
from ngraph.ops import logical_not from ngraph.ops import logical_not
from ngraph.ops import max from ngraph.ops import max
from ngraph.ops import maximum
from ngraph.ops import max_pool from ngraph.ops import max_pool
from ngraph.ops import maximum
from ngraph.ops import min from ngraph.ops import min
from ngraph.ops import minimum from ngraph.ops import minimum
from ngraph.ops import multiply from ngraph.ops import multiply
from ngraph.ops import negative from ngraph.ops import negative
from ngraph.ops import not_equal from ngraph.ops import not_equal
from ngraph.ops import one_hot
from ngraph.ops import pad from ngraph.ops import pad
from ngraph.ops import parameter from ngraph.ops import parameter
from ngraph.ops import power
from ngraph.ops import prod from ngraph.ops import prod
from ngraph.ops import relu
from ngraph.ops import replace_slice
from ngraph.ops import reshape from ngraph.ops import reshape
from ngraph.ops import reverse
from ngraph.ops import select
from ngraph.ops import sign
from ngraph.ops import sin
from ngraph.ops import sinh
from ngraph.ops import slice from ngraph.ops import slice
from ngraph.ops import softmax from ngraph.ops import softmax
from ngraph.ops import sqrt from ngraph.ops import sqrt
from ngraph.ops import subtract from ngraph.ops import subtract
from ngraph.ops import sum from ngraph.ops import sum
from ngraph.ops import tan
from ngraph.ops import tanh from ngraph.ops import tanh
from ngraph.runtime import runtime from ngraph.runtime import runtime
...@@ -35,10 +35,13 @@ sys.setdlopenflags(flags) ...@@ -35,10 +35,13 @@ sys.setdlopenflags(flags)
from _pyngraph.op import Abs from _pyngraph.op import Abs
from _pyngraph.op import Acos from _pyngraph.op import Acos
from _pyngraph.op import Add from _pyngraph.op import Add
from _pyngraph.op import AllReduce
from _pyngraph.op import Asin from _pyngraph.op import Asin
from _pyngraph.op import Atan from _pyngraph.op import Atan
from _pyngraph.op import AvgPool from _pyngraph.op import AvgPool
from _pyngraph.op import AvgPoolBackprop from _pyngraph.op import AvgPoolBackprop
from _pyngraph.op import BatchNorm
from _pyngraph.op import BatchNormBackprop
from _pyngraph.op import Broadcast from _pyngraph.op import Broadcast
from _pyngraph.op import Ceiling from _pyngraph.op import Ceiling
from _pyngraph.op import Concat from _pyngraph.op import Concat
...@@ -54,28 +57,33 @@ from _pyngraph.op import Dot ...@@ -54,28 +57,33 @@ from _pyngraph.op import Dot
from _pyngraph.op import Equal from _pyngraph.op import Equal
from _pyngraph.op import Exp from _pyngraph.op import Exp
from _pyngraph.op import Floor from _pyngraph.op import Floor
from _pyngraph.op import FunctionCall
from _pyngraph.op import GetOutputElement
from _pyngraph.op import Greater from _pyngraph.op import Greater
from _pyngraph.op import GreaterEq from _pyngraph.op import GreaterEq
from _pyngraph.op import Less from _pyngraph.op import Less
from _pyngraph.op import LessEq from _pyngraph.op import LessEq
from _pyngraph.op import Log from _pyngraph.op import Log
from _pyngraph.op import Max from _pyngraph.op import Max
from _pyngraph.op import Maximum
from _pyngraph.op import MaxPool from _pyngraph.op import MaxPool
from _pyngraph.op import MaxPoolBackprop from _pyngraph.op import MaxPoolBackprop
from _pyngraph.op import Maximum
from _pyngraph.op import Min from _pyngraph.op import Min
from _pyngraph.op import Minimum from _pyngraph.op import Minimum
from _pyngraph.op import Multiply from _pyngraph.op import Multiply
from _pyngraph.op import Negative from _pyngraph.op import Negative
from _pyngraph.op import NotEqual
from _pyngraph.op import Not from _pyngraph.op import Not
from _pyngraph.op import NotEqual
from _pyngraph.op import OneHot from _pyngraph.op import OneHot
from _pyngraph.op import Op from _pyngraph.op import Op
from _pyngraph.op import Pad from _pyngraph.op import Pad
from _pyngraph.op import Parameter from _pyngraph.op import Parameter
from _pyngraph.op import ParameterVector from _pyngraph.op import ParameterVector
from _pyngraph.op import Power from _pyngraph.op import Power
from _pyngraph.op import Product
from _pyngraph.op import Reduce from _pyngraph.op import Reduce
from _pyngraph.op import Relu
from _pyngraph.op import ReluBackprop
from _pyngraph.op import ReplaceSlice from _pyngraph.op import ReplaceSlice
from _pyngraph.op import Reshape from _pyngraph.op import Reshape
from _pyngraph.op import Reverse from _pyngraph.op import Reverse
...@@ -84,17 +92,9 @@ from _pyngraph.op import Sign ...@@ -84,17 +92,9 @@ from _pyngraph.op import Sign
from _pyngraph.op import Sin from _pyngraph.op import Sin
from _pyngraph.op import Sinh from _pyngraph.op import Sinh
from _pyngraph.op import Slice from _pyngraph.op import Slice
from _pyngraph.op import Softmax
from _pyngraph.op import Sqrt from _pyngraph.op import Sqrt
from _pyngraph.op import Subtract from _pyngraph.op import Subtract
from _pyngraph.op import Sum from _pyngraph.op import Sum
from _pyngraph.op import Tan from _pyngraph.op import Tan
from _pyngraph.op import Tanh from _pyngraph.op import Tanh
from _pyngraph.op import Relu
from _pyngraph.op import ReluBackprop
from _pyngraph.op import Product
from _pyngraph.op import AllReduce
from _pyngraph.op import FunctionCall
from _pyngraph.op import GetOutputElement
from _pyngraph.op import BatchNorm
from _pyngraph.op import BatchNormBackprop
from _pyngraph.op import Softmax
This diff is collapsed.
...@@ -20,19 +20,40 @@ import ngraph as ng ...@@ -20,19 +20,40 @@ import ngraph as ng
from test.ngraph.util import run_op_numeric_data, run_op_node from test.ngraph.util import run_op_numeric_data, run_op_node
@pytest.mark.xfail(reason='Results mismatch when passing created Constant node from raw data.')
@pytest.mark.parametrize('ng_api_fn, numpy_fn, range_start, range_end', [
(ng.cos, np.cos, -100., 100.),
(ng.sin, np.sin, -100., 100.),
])
def test_unary_op_array_err(ng_api_fn, numpy_fn, range_start, range_end):
np.random.seed(133391)
input_data = range_start + np.random.rand(2, 3, 4) * (range_end - range_start)
expected = numpy_fn(input_data)
result = run_op_numeric_data(input_data, ng_api_fn)
assert np.allclose(result, expected)
@pytest.mark.parametrize('ng_api_fn, numpy_fn, range_start, range_end', [ @pytest.mark.parametrize('ng_api_fn, numpy_fn, range_start, range_end', [
(ng.absolute, np.abs, -1, 1), (ng.absolute, np.abs, -1, 1),
(ng.abs, np.abs, -1, 1), (ng.abs, np.abs, -1, 1),
(ng.acos, np.arccos, -1, 1), (ng.acos, np.arccos, -1, 1),
(ng.asin, np.arcsin, -1, 1), (ng.asin, np.arcsin, -1, 1),
(ng.atan, np.arctan, -100, 100), (ng.atan, np.arctan, -100., 100.),
(ng.ceiling, np.ceil, -100, 100), (ng.ceiling, np.ceil, -100., 100.),
(ng.ceil, np.ceil, -100, 100), (ng.ceil, np.ceil, -100., 100.),
(ng.cos, np.cos, -np.pi, np.pi), (ng.cos, np.cos, -np.pi * 2., np.pi * 2.),
(ng.cosh, np.cosh, -np.pi, np.pi), (ng.cosh, np.cosh, -100., 100.),
(ng.exp, np.exp, -100, 100), (ng.exp, np.exp, -100., 100.),
(ng.floor, np.floor, -100, 100), (ng.floor, np.floor, -100., 100.),
(ng.log, np.log, 0, 100), (ng.log, np.log, 0, 100.),
(ng.relu, lambda x: np.maximum(0, x), -100., 100.),
(ng.sign, np.sign, -100., 100.),
(ng.sin, np.sin, -np.pi * 2., np.pi * 2.),
(ng.sinh, np.sinh, -100., 100.),
(ng.sqrt, np.sqrt, 0., 100.),
(ng.tan, np.tan, -1., 1.),
(ng.tanh, np.tanh, -100., 100.),
]) ])
def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end): def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end):
np.random.seed(133391) np.random.seed(133391)
...@@ -59,6 +80,13 @@ def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end): ...@@ -59,6 +80,13 @@ def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end):
(ng.exp, np.exp, np.float32(1.5)), (ng.exp, np.exp, np.float32(1.5)),
(ng.floor, np.floor, np.float32(1.5)), (ng.floor, np.floor, np.float32(1.5)),
(ng.log, np.log, np.float32(1.5)), (ng.log, np.log, np.float32(1.5)),
(ng.relu, lambda x: np.maximum(0, x), np.float32(-0.125)),
(ng.sign, np.sign, np.float32(0.)),
(ng.sin, np.sin, np.float32(np.pi / 4.0)),
(ng.sinh, np.sinh, np.float32(0.)),
(ng.sqrt, np.sqrt, np.float32(3.5)),
(ng.tan, np.tan, np.float32(np.pi / 4.0)),
(ng.tanh, np.tanh, np.float32(0.1234)),
]) ])
def test_unary_op_scalar(ng_api_fn, numpy_fn, input_data): def test_unary_op_scalar(ng_api_fn, numpy_fn, input_data):
expected = numpy_fn(input_data) expected = numpy_fn(input_data)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment