Commit c9737e83 authored by arogowie-intel's avatar arogowie-intel Committed by Scott Cyphers

[Py] Python wrappers for nGraph operations (#719)

* Add test for ng.absolute

- remove unnecesary repetitive call to as_node

* Add Acos operation with UT.

- update docsrings
- refactor and parameterize unit tests

* Review refactoring fix.

* Review refactoring fix part 2.

* Add Asin and Atan nGraph operation wrappers.
parent b6e1065e
...@@ -17,7 +17,10 @@ ...@@ -17,7 +17,10 @@
from ngraph.ops import absolute from ngraph.ops import absolute
from ngraph.ops import absolute as abs from ngraph.ops import absolute as abs
from ngraph.ops import acos
from ngraph.ops import add from ngraph.ops import add
from ngraph.ops import asin
from ngraph.ops import atan
from ngraph.ops import avg_pool from ngraph.ops import avg_pool
from ngraph.ops import broadcast from ngraph.ops import broadcast
from ngraph.ops import ceiling from ngraph.ops import ceiling
......
...@@ -20,19 +20,19 @@ import numpy as np ...@@ -20,19 +20,19 @@ import numpy as np
from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, NodeVector, \ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, NodeVector, \
Shape, Strides Shape, Strides
from ngraph.impl.op import Abs, Add, AvgPool, Broadcast, Ceiling, Concat, Constant, Convert, \ from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, Broadcast, Ceiling, Concat, \
Convolution, Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq, Less, LessEq, Log, Max, \ Constant, Convert, Convolution, Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq, Less, \
Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, Parameter, Product, \ LessEq, Log, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, Parameter,\
Reshape, Slice, Sqrt, Subtract, Sum, Tanh Product, Reshape, Slice, Sqrt, Subtract, Sum, Tanh
from typing import Iterable, List, Optional from typing import Iterable, List
from ngraph.utils.broadcasting import get_broadcast_axes from ngraph.utils.broadcasting import get_broadcast_axes
from ngraph.utils.decorators import nameable_op, binary_op, unary_op from ngraph.utils.decorators import nameable_op, binary_op, unary_op
from ngraph.utils.input_validation import assert_list_of_ints from ngraph.utils.input_validation import assert_list_of_ints
from ngraph.utils.reduction import get_reduction_axes from ngraph.utils.reduction import get_reduction_axes
from ngraph.utils.types import NumericType, NumericData, TensorShape, make_constant_node, \ from ngraph.utils.types import NumericType, NumericData, TensorShape, make_constant_node, \
as_node, NodeInput NodeInput
from ngraph.utils.types import get_element_type from ngraph.utils.types import get_element_type
...@@ -54,11 +54,48 @@ def constant(value, dtype=None, name=None): # type: (NumericData, NumericType, ...@@ -54,11 +54,48 @@ def constant(value, dtype=None, name=None): # type: (NumericData, NumericType,
# Unary ops # Unary ops
@unary_op @unary_op
def absolute(node, name=None): # type: (NodeInput, str) -> Node def absolute(node, name=None): # type: (NodeInput, str) -> Node
"""Return node which applies f(x) = abs(x) to the input node elementwise.""" """Return node which applies f(x) = abs(x) to the input node element-wise.
node = as_node(node)
:param node: One of: input node, array or scalar.
:param name: Optional new name for output node.
:return: New node with Abs operation applied on it.
"""
return Abs(node) return Abs(node)
@unary_op
def acos(node, name=None): # type: (NodeInput, str) -> Node
"""Apply inverse cosine function on the input node element-wise.
:param node: One of: input node, array or scalar.
:param name: Optional new name for output node.
:return: New node with arccos operation applied on it.
"""
return Acos(node)
@unary_op
def asin(node, name=None): # type: (NodeInput, str) -> Node
"""Apply inverse sine function on the input node element-wise.
:param node: One of: input node, array or scalar.
:param name: Optional new name for output node.
:return: New node with arcsin operation applied on it.
"""
return Asin(node)
@unary_op
def atan(node, name=None): # type: (NodeInput, str) -> Node
"""Apply inverse tangent function on the input node element-wise.
:param node: One of: input node, array or scalar.
:param name: Optional new name for output node.
:return: New node with arctan operation applied on it.
"""
return Atan(node)
@unary_op @unary_op
def sqrt(node, name=None): # type: (NodeInput, str) -> Node def sqrt(node, name=None): # type: (NodeInput, str) -> Node
"""Return node which applies square root to the input node elementwise.""" """Return node which applies square root to the input node elementwise."""
...@@ -128,7 +165,7 @@ def subtract(left_node, right_node, name=None): # type: (NodeInput, NodeInput, ...@@ -128,7 +165,7 @@ def subtract(left_node, right_node, name=None): # type: (NodeInput, NodeInput,
@binary_op @binary_op
def add(left_node, right_node, name=None): # type: (NodeInput, NodeInput, str) -> Node def add(left_node, right_node, name=None): # type: (NodeInput, NodeInput, str) -> Node
"""Return node which applies f(x) = A+B to the input nodes elementwise.""" """Return node which applies f(x) = A+B to the input nodes element-wise."""
return Add(left_node, right_node) return Add(left_node, right_node)
......
# ******************************************************************************
# Copyright 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import numpy as np
import pytest
import ngraph as ng
def _get_runtime():
manager_name = pytest.config.getoption('backend', default='CPU')
return ng.runtime(manager_name=manager_name)
def _run_unary_op_node(input_data, unary_op):
runtime = _get_runtime()
parameter_a = ng.parameter(input_data.shape, name='A', dtype=np.float32)
node = unary_op(parameter_a)
computation = runtime.computation(node, parameter_a)
return computation(input_data)
def _run_unary_op_numeric_data(input_data, unary_op):
runtime = _get_runtime()
node = unary_op(input_data)
computation = runtime.computation(node)
return computation()
@pytest.mark.parametrize('ng_api_fn, numpy_fn, input_data', [
(ng.absolute, np.abs, -1 + np.random.rand(2, 3, 4) * 2),
(ng.absolute, np.abs, np.float32(-3)),
(ng.acos, np.arccos, -1 + np.random.rand(2, 3, 4) * 2),
(ng.acos, np.arccos, np.float32(-0.5)),
(ng.asin, np.arcsin, -1 + np.random.rand(2, 3, 4) * 2),
(ng.asin, np.arcsin, np.float32(-0.5)),
(ng.atan, np.arctan, -100 + np.random.rand(2, 3, 4) * 200),
(ng.atan, np.arctan, np.float32(-0.5)),
])
def test_unary_op(ng_api_fn, numpy_fn, input_data):
expected = numpy_fn(input_data)
result = _run_unary_op_node(input_data, ng_api_fn)
assert np.allclose(result, expected)
result = _run_unary_op_numeric_data(input_data, ng_api_fn)
assert np.allclose(result, expected)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment