Commit fa6c2a60 authored by arogowie-intel's avatar arogowie-intel Committed by Scott Cyphers

[Py] Python nGraph operations wrappers. (#821)

* Add/update Python wrappers for nGraph operations.

- NotEqual, OneHot, Power, Sqrt, Relu, Sign, Sin, Sinh, Tan, Subtract, Select, Tanh, Sum, Reduce,
Softmax, ReplaceSlice, Reverse
- Add UT for Relu, Sign, Sin, Sinh, Sqrt, Tan, Tanh,

* Add UT for cases when Cos and Sin are giving incorrect results.

* Alphabetically sorted imports.

* Small refactoring.

- Update docstrings
- Remove unnecesary auxiliary local variable.
parent 84c692cf
......@@ -38,27 +38,37 @@ from ngraph.ops import exp
from ngraph.ops import floor
from ngraph.ops import greater
from ngraph.ops import greater_eq
from ngraph.ops import log
from ngraph.ops import less
from ngraph.ops import less_eq
from ngraph.ops import log
from ngraph.ops import logical_not
from ngraph.ops import max
from ngraph.ops import maximum
from ngraph.ops import max_pool
from ngraph.ops import maximum
from ngraph.ops import min
from ngraph.ops import minimum
from ngraph.ops import multiply
from ngraph.ops import negative
from ngraph.ops import not_equal
from ngraph.ops import one_hot
from ngraph.ops import pad
from ngraph.ops import parameter
from ngraph.ops import power
from ngraph.ops import prod
from ngraph.ops import relu
from ngraph.ops import replace_slice
from ngraph.ops import reshape
from ngraph.ops import reverse
from ngraph.ops import select
from ngraph.ops import sign
from ngraph.ops import sin
from ngraph.ops import sinh
from ngraph.ops import slice
from ngraph.ops import softmax
from ngraph.ops import sqrt
from ngraph.ops import subtract
from ngraph.ops import sum
from ngraph.ops import tan
from ngraph.ops import tanh
from ngraph.runtime import runtime
......@@ -35,10 +35,13 @@ sys.setdlopenflags(flags)
from _pyngraph.op import Abs
from _pyngraph.op import Acos
from _pyngraph.op import Add
from _pyngraph.op import AllReduce
from _pyngraph.op import Asin
from _pyngraph.op import Atan
from _pyngraph.op import AvgPool
from _pyngraph.op import AvgPoolBackprop
from _pyngraph.op import BatchNorm
from _pyngraph.op import BatchNormBackprop
from _pyngraph.op import Broadcast
from _pyngraph.op import Ceiling
from _pyngraph.op import Concat
......@@ -54,28 +57,33 @@ from _pyngraph.op import Dot
from _pyngraph.op import Equal
from _pyngraph.op import Exp
from _pyngraph.op import Floor
from _pyngraph.op import FunctionCall
from _pyngraph.op import GetOutputElement
from _pyngraph.op import Greater
from _pyngraph.op import GreaterEq
from _pyngraph.op import Less
from _pyngraph.op import LessEq
from _pyngraph.op import Log
from _pyngraph.op import Max
from _pyngraph.op import Maximum
from _pyngraph.op import MaxPool
from _pyngraph.op import MaxPoolBackprop
from _pyngraph.op import Maximum
from _pyngraph.op import Min
from _pyngraph.op import Minimum
from _pyngraph.op import Multiply
from _pyngraph.op import Negative
from _pyngraph.op import NotEqual
from _pyngraph.op import Not
from _pyngraph.op import NotEqual
from _pyngraph.op import OneHot
from _pyngraph.op import Op
from _pyngraph.op import Pad
from _pyngraph.op import Parameter
from _pyngraph.op import ParameterVector
from _pyngraph.op import Power
from _pyngraph.op import Product
from _pyngraph.op import Reduce
from _pyngraph.op import Relu
from _pyngraph.op import ReluBackprop
from _pyngraph.op import ReplaceSlice
from _pyngraph.op import Reshape
from _pyngraph.op import Reverse
......@@ -84,17 +92,9 @@ from _pyngraph.op import Sign
from _pyngraph.op import Sin
from _pyngraph.op import Sinh
from _pyngraph.op import Slice
from _pyngraph.op import Softmax
from _pyngraph.op import Sqrt
from _pyngraph.op import Subtract
from _pyngraph.op import Sum
from _pyngraph.op import Tan
from _pyngraph.op import Tanh
from _pyngraph.op import Relu
from _pyngraph.op import ReluBackprop
from _pyngraph.op import Product
from _pyngraph.op import AllReduce
from _pyngraph.op import FunctionCall
from _pyngraph.op import GetOutputElement
from _pyngraph.op import BatchNorm
from _pyngraph.op import BatchNormBackprop
from _pyngraph.op import Softmax
......@@ -23,7 +23,8 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, N
from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, Broadcast, Ceiling, Concat, \
Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq,\
Less, LessEq, Log, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, \
Pad, Parameter, Product, Reshape, Slice, Softmax, Sqrt, Subtract, Sum, Tanh
OneHot, Pad, Parameter, Product, Power, Relu, ReplaceSlice, Reshape, Reverse, Select, \
Sign, Sin, Sinh, Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh
from typing import Iterable, List
......@@ -126,7 +127,12 @@ def cosh(node, name=None): # type: (NodeInput, str) -> Node
@unary_op
def sqrt(node, name=None): # type: (NodeInput, str) -> Node
"""Return node which applies square root to the input node elementwise."""
"""Return node which applies square root to the input node element-wise.
:param node: One of: input node, array or scalar.
:param name: Optional new name for output node.
:return: The new node with sqrt operation applied element-wise.
"""
return Sqrt(node)
......@@ -192,6 +198,62 @@ def reshape(node, input_order, output_shape, name=None):
return Reshape(node, AxisVector(input_order), Shape(output_shape))
@unary_op
def relu(node, name=None): # type: (NodeInput, str) -> Node
"""Perform rectified linear unit operation on input node element-wise.
:param node: One of: input node, array or scalar.
:param name: The optional ouptut node name.
:return: The new node performing relu operation on its input element-wise.
"""
return Relu(node)
@unary_op
def sign(node, name=None): # type: (NodeInput, str) -> Node
"""Perform element-wise sign operation.
:param node: One of: input node, array or scalar.
:param name: The optional new name for ouptut node.
:return: The node with mapped elements of the input tensor to -1 (if it is negative),
0 (if it is zero), or 1 (if it is positive).
"""
return Sign(node)
@unary_op
def sin(node, name=None): # type: (NodeInput, str) -> Node
"""Apply sine function on the input node element-wise.
:param node: One of: input node, array or scalar.
:param name: Optional new name for output node.
:return: New node with sin operation applied on it.
"""
return Sin(node)
@unary_op
def sinh(node, name=None): # type: (NodeInput, str) -> Node
"""Apply hyperbolic sine function on the input node element-wise.
:param node: One of: input node, array or scalar.
:param name: Optional new name for output node.
:return: New node with sin operation applied on it.
"""
return Sinh(node)
@unary_op
def tan(node, name=None): # type: (NodeInput, str) -> Node
"""Apply tangent function on the input node element-wise.
:param node: One of: input node, array or scalar.
:param name: Optional new name for output node.
:return: New node with tan operation applied on it.
"""
return Tan(node)
# Binary ops
@binary_op
def divide(left_node, right_node, name=None): # type: (NodeInput, NodeInput, str) -> Node
......@@ -213,7 +275,13 @@ def multiply(left_node, right_node, name=None): # type: (NodeInput, NodeInput,
@binary_op
def subtract(left_node, right_node, name=None): # type: (NodeInput, NodeInput, str) -> Node
"""Return node which applies f(x) = A-B to the input nodes elementwise."""
"""Return node which applies f(x) = A-B to the input nodes element-wise.
:param left_node: The node providing data for left hand side of operator.
:param right_node: The node providing data for right hand side of operator.
:param name: The optional name for output node.
:return: The new output node performing subtraction operation on both tensors element-wise.
"""
return Subtract(left_node, right_node)
......@@ -235,6 +303,18 @@ def maximum(left_node, right_node, name=None): # type: (NodeInput, NodeInput, s
return Maximum(left_node, right_node)
@binary_op
def power(left_node, right_node, name=None): # type: (NodeInput, NodeInput, str) -> Node
"""Return node which perform element-wise exponentiation operation.
:param left_node: The node providing the base of operation.
:param right_node: The node providing the exponent of operation.
:param name: The optional name for the new output node.
:return: The new node performing element-wise exponentiation operation on input nodes.
"""
return Power(left_node, right_node)
# Logical ops
@binary_op
def equal(left_node, right_node, name=None): # type: (NodeInput, NodeInput, str) -> Node
......@@ -250,7 +330,13 @@ def equal(left_node, right_node, name=None): # type: (NodeInput, NodeInput, str
@binary_op
def not_equal(left_node, right_node, name=None): # type: (NodeInput, NodeInput, str) -> Node
"""Return node which checks if input nodes are unequal elementwise."""
"""Return node which checks if input nodes are unequal element-wise.
:param left_node: The first input node for not-equal operation.
:param right_node: The second input node for not-equal operation.
:param name: The optional name for output new node.
:return: The node performing element-wise inequality check.
"""
return NotEqual(left_node, right_node)
......@@ -350,10 +436,31 @@ def convert(node, new_type, name=None): # type: (Node, NumericType, str) -> Nod
return Convert(node, new_element_type)
@nameable_op
def select(selection_node, input_node1, input_node2, name=None):
# type: (Node, Node, Node, str) -> Node
"""Perform an element-wise selection operation on input tensors.
:param selection_node: The node providing selection values of `bool` type.
:param input_node1: The node providing data to be selected if respective `selection_node`
item value is `True`.
:param input_node2: The node providing data to be selected if respective `selection_node`
item value is `False`.
:param name: The optional new name for output node.
:return: The new node with values selected according to provided arguments.
"""
return Select(selection_node, input_node1, input_node2)
# Non-linear ops
@unary_op
def tanh(node, name=None): # type: (Node, str) -> Node
"""Return node which applies tanh to the input node elementwise."""
"""Return node which applies hyperbolic tangent to the input node element-wise.
:param node: One of: input node, array or scalar.
:param name: Optional new name for output node.
:return: New node with tanh operation applied on it.
"""
return Tanh(node)
......@@ -479,12 +586,14 @@ def max_pool(x, # type: Node
@nameable_op
def sum(node, reduction_axes=None, name=None):
# type: (Node, Iterable[int], str) -> Node
"""Element-wise sums the input tensor, eliminating the specified reduction axes.
"""Perform element-wise sums of the input tensor, eliminating the specified reduction axes.
:param node: The node providing data for operation.
:param reduction_axes: The axes to eliminate through summation.
:param name: The optional new name for ouptut node.
:return: The new node performing summation along `reduction_axes` element-wise.
"""
reduction_axes = get_reduction_axes(node, reduction_axes)
return Sum(node, AxisSet(reduction_axes))
return Sum(node, AxisSet(get_reduction_axes(node, reduction_axes)))
@nameable_op
......@@ -496,8 +605,7 @@ def max(node, reduction_axes=None, name=None):
:param reduction_axes: The axes to eliminate through max operation.
:param name: Optional name for output node.
"""
reduction_axes = get_reduction_axes(node, reduction_axes)
return Max(node, AxisSet(reduction_axes))
return Max(node, AxisSet(get_reduction_axes(node, reduction_axes)))
@nameable_op
......@@ -509,8 +617,7 @@ def min(node, reduction_axes=None, name=None):
:param reduction_axes: The axes to eliminate through min operation.
:param name: Optional name for output node.
"""
reduction_axes = get_reduction_axes(node, reduction_axes)
return Min(node, AxisSet(reduction_axes))
return Min(node, AxisSet(get_reduction_axes(node, reduction_axes)))
@nameable_op
......@@ -521,9 +628,9 @@ def prod(node, reduction_axes=None, name=None):
:param node: The tensor we want to product-reduce.
:param reduction_axes: The axes to eliminate through product operation.
:param name: Optional name for output node.
:return: The new node performing product-reduction operation.
"""
reduction_axes = get_reduction_axes(node, reduction_axes)
return Product(node, AxisSet(reduction_axes))
return Product(node, AxisSet(get_reduction_axes(node, reduction_axes)))
# reshape ops
......@@ -560,8 +667,15 @@ def concat(nodes, axis, name=None): # type: (List[Node], int, str) -> Node
@nameable_op
def softmax(node, axes): # type: (Node, Iterable[int]) -> Node
"""Softmax operation on input tensor."""
def softmax(node, axes, name=None): # type: (Node, Iterable[int], str) -> Node
"""Apply softmax operation on each element of input tensor.
:param node: The tensor providing input data.
:param axes: The list of axes indices which are used to calculate divider of
the softmax function.
:param name: The optional new name for output node.
:return: The new node with softmax operation applied on each element.
"""
if type(axes) is not set:
axes = set(axes)
return Softmax(node, AxisSet(axes))
......@@ -595,3 +709,55 @@ def pad(data_batch, # type: Node
padding_in = [0] * dim_count
return Pad(data_batch, value, Shape(padding_below), Shape(padding_above), Shape(padding_in))
@nameable_op
def one_hot(node, shape, one_hot_axis, name=None): # type: (Node, TensorShape, int, str) -> Node
"""Create node performing one-hot encoding on input data.
:param node: The input node providing data for operation.
:param shape: The output node shape including the new one-hot axis.
:param one_hot_axis: The index within the output shape of the new one-hot axis.
:param name: The optional name for new output node.
:return: New node performing one-hot operation.
"""
return OneHot(node, Shape(shape), one_hot_axis)
@nameable_op
def replace_slice(dest_node, # type: Node
src_node, # type: Node
lower_bounds, # type: List[int]
upper_bounds, # type: List[int]
strides=None, # type: List[int]
name=None, # type: str
):
# type: (...) -> Node
"""Return a copy of `dest_node` with the specified slice overwritten by the `src_node` data.
:param dest_node: The node providing data to be overwritten by the specified slice.
:param src_node: The node providing data for overwriting.
:param lower_bounds: The (inclusive) lower-bound coordinates for the replaced slice.
:param upper_bounds: The (exclusive) upper-bound coordinates for the replaced slice.
:param strides: The strides for the replaced slice.
:param name: The optional name for the output new node.
:return: The new node with copy of `dest_node` with the specified slice overwritten
by the `src_node`.
"""
if strides is None:
return ReplaceSlice(dest_node, src_node, Coordinate(lower_bounds), Coordinate(upper_bounds))
else:
return ReplaceSlice(dest_node, src_node, Coordinate(lower_bounds), Coordinate(upper_bounds),
Strides(strides))
@nameable_op
def reverse(node, reversed_axes, name=None): # type: (Node, List[int], str) -> Node
"""Perform axis-reverse operation.
:param node: The input node on which operation will be carried out.
:param reversed_axes: The list of indices of axes to be reversed.
:param name: The optional name of the output node.
:return: The new node with reversed axes.
"""
return Reverse(node, AxisSet(reversed_axes))
......@@ -20,19 +20,40 @@ import ngraph as ng
from test.ngraph.util import run_op_numeric_data, run_op_node
@pytest.mark.xfail(reason='Results mismatch when passing created Constant node from raw data.')
@pytest.mark.parametrize('ng_api_fn, numpy_fn, range_start, range_end', [
(ng.cos, np.cos, -100., 100.),
(ng.sin, np.sin, -100., 100.),
])
def test_unary_op_array_err(ng_api_fn, numpy_fn, range_start, range_end):
np.random.seed(133391)
input_data = range_start + np.random.rand(2, 3, 4) * (range_end - range_start)
expected = numpy_fn(input_data)
result = run_op_numeric_data(input_data, ng_api_fn)
assert np.allclose(result, expected)
@pytest.mark.parametrize('ng_api_fn, numpy_fn, range_start, range_end', [
(ng.absolute, np.abs, -1, 1),
(ng.abs, np.abs, -1, 1),
(ng.acos, np.arccos, -1, 1),
(ng.asin, np.arcsin, -1, 1),
(ng.atan, np.arctan, -100, 100),
(ng.ceiling, np.ceil, -100, 100),
(ng.ceil, np.ceil, -100, 100),
(ng.cos, np.cos, -np.pi, np.pi),
(ng.cosh, np.cosh, -np.pi, np.pi),
(ng.exp, np.exp, -100, 100),
(ng.floor, np.floor, -100, 100),
(ng.log, np.log, 0, 100),
(ng.atan, np.arctan, -100., 100.),
(ng.ceiling, np.ceil, -100., 100.),
(ng.ceil, np.ceil, -100., 100.),
(ng.cos, np.cos, -np.pi * 2., np.pi * 2.),
(ng.cosh, np.cosh, -100., 100.),
(ng.exp, np.exp, -100., 100.),
(ng.floor, np.floor, -100., 100.),
(ng.log, np.log, 0, 100.),
(ng.relu, lambda x: np.maximum(0, x), -100., 100.),
(ng.sign, np.sign, -100., 100.),
(ng.sin, np.sin, -np.pi * 2., np.pi * 2.),
(ng.sinh, np.sinh, -100., 100.),
(ng.sqrt, np.sqrt, 0., 100.),
(ng.tan, np.tan, -1., 1.),
(ng.tanh, np.tanh, -100., 100.),
])
def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end):
np.random.seed(133391)
......@@ -59,6 +80,13 @@ def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end):
(ng.exp, np.exp, np.float32(1.5)),
(ng.floor, np.floor, np.float32(1.5)),
(ng.log, np.log, np.float32(1.5)),
(ng.relu, lambda x: np.maximum(0, x), np.float32(-0.125)),
(ng.sign, np.sign, np.float32(0.)),
(ng.sin, np.sin, np.float32(np.pi / 4.0)),
(ng.sinh, np.sinh, np.float32(0.)),
(ng.sqrt, np.sqrt, np.float32(3.5)),
(ng.tan, np.tan, np.float32(np.pi / 4.0)),
(ng.tanh, np.tanh, np.float32(0.1234)),
])
def test_unary_op_scalar(ng_api_fn, numpy_fn, input_data):
expected = numpy_fn(input_data)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment