Commit c80a1076 authored by arogowie-intel's avatar arogowie-intel Committed by Nick Korovaiko

[Py] Add python wrapper for nGraph Reduce operation. (#827)

* Add python wrapper for nGraph Reduce operation.

- Add UT.

* Refactoring.

- Add UT case with default reduction on all axes.

* Extend `reduce` operation signature to also accept `Function` object.

- Add UT case.

* Fix formatting errors.
parent e7cf2662
...@@ -60,6 +60,7 @@ from ngraph.ops import power ...@@ -60,6 +60,7 @@ from ngraph.ops import power
from ngraph.ops import prod from ngraph.ops import prod
from ngraph.ops import relu from ngraph.ops import relu
from ngraph.ops import replace_slice from ngraph.ops import replace_slice
from ngraph.ops import reduce
from ngraph.ops import reshape from ngraph.ops import reshape
from ngraph.ops import reverse from ngraph.ops import reverse
from ngraph.ops import select from ngraph.ops import select
......
...@@ -17,15 +17,15 @@ ...@@ -17,15 +17,15 @@
"""Factory functions for all ngraph ops.""" """Factory functions for all ngraph ops."""
import numpy as np import numpy as np
from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, NodeVector, \ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Function, Node, \
Shape, Strides NodeVector, Shape, Strides
from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, BatchNorm, Broadcast, Ceiling,\ from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, BatchNorm, Broadcast, Ceiling, \
Concat, Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, \ Concat, Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, \
FunctionCall, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, Max, Maximum, MaxPool, \ FunctionCall, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, Max, Maximum, MaxPool, \
Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Pad, Parameter, Product, Power, Relu, \ Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Pad, Parameter, Product, Power, \
ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, Sqrt, Subtract, Sum, \ Reduce, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, Sqrt, \
Tan, Tanh Subtract, Sum, Tan, Tanh
from typing import Iterable, List from typing import Iterable, List
...@@ -34,7 +34,7 @@ from ngraph.utils.decorators import nameable_op, binary_op, unary_op ...@@ -34,7 +34,7 @@ from ngraph.utils.decorators import nameable_op, binary_op, unary_op
from ngraph.utils.input_validation import assert_list_of_ints from ngraph.utils.input_validation import assert_list_of_ints
from ngraph.utils.reduction import get_reduction_axes from ngraph.utils.reduction import get_reduction_axes
from ngraph.utils.types import NumericType, NumericData, TensorShape, make_constant_node, \ from ngraph.utils.types import NumericType, NumericData, TensorShape, make_constant_node, \
NodeInput NodeInput, ScalarData, CallableData
from ngraph.utils.types import get_element_type from ngraph.utils.types import get_element_type
...@@ -634,6 +634,39 @@ def prod(node, reduction_axes=None, name=None): ...@@ -634,6 +634,39 @@ def prod(node, reduction_axes=None, name=None):
return Product(node, AxisSet(get_reduction_axes(node, reduction_axes))) return Product(node, AxisSet(get_reduction_axes(node, reduction_axes)))
@nameable_op
def reduce(node, # type: Node
initial_value, # type: ScalarData
reduction_function, # type: CallableData
reduction_axes=None, # type: List[int]
name=None, # type: str
):
# type: (...) -> Node
"""Perform general tensor reduction operation.
:param node: The node providing data for reduction operation.
:param initial_value: The initial value for reduction operation.
:param reduction_function: The function performing binary reduction operation or a nGraph
Function object. The operation must accept two nodes providing scalar
operands and return a node which produces a scalar result.
:param reduction_axes: The list of axes indices to be reduced. Default to reduce all axes.
:param name: The new name for output node.
:return: The node performing reduction operation with provided reduction node.
"""
if reduction_axes is None:
reduction_axes = list(range(len(node.shape)))
init_val_node = constant(initial_value)
if not isinstance(reduction_function, Function):
# wrap reduction function into Function object
param1 = Parameter(node.get_element_type(), Shape([]))
param2 = Parameter(node.get_element_type(), Shape([]))
reduction_operation = Function(NodeVector([reduction_function(param1, param2)]),
[param1, param2], 'reduction_operation')
else:
reduction_operation = reduction_function
return Reduce(node, init_val_node, reduction_operation, AxisSet(set(reduction_axes)))
# reshape ops # reshape ops
@nameable_op @nameable_op
def slice(node, lower_bounds, upper_bounds, strides=None, name=None): def slice(node, lower_bounds, upper_bounds, strides=None, name=None):
......
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
"""Functions related to converting between Python and numpy types and ngraph types.""" """Functions related to converting between Python and numpy types and ngraph types."""
import logging import logging
from typing import Union, List from typing import Callable, Union, List
import numpy as np import numpy as np
from ngraph.impl import Type as NgraphType from ngraph.impl import Type as NgraphType
from ngraph.impl import Node, Shape from ngraph.impl import Function, Node, Shape
from ngraph.impl.op import Constant from ngraph.impl.op import Constant
from ngraph.exceptions import NgraphTypeError from ngraph.exceptions import NgraphTypeError
...@@ -32,7 +32,9 @@ log = logging.getLogger(__file__) ...@@ -32,7 +32,9 @@ log = logging.getLogger(__file__)
TensorShape = List[int] TensorShape = List[int]
NumericData = Union[int, float, np.ndarray] NumericData = Union[int, float, np.ndarray]
NumericType = Union[type, np.dtype] NumericType = Union[type, np.dtype]
ScalarData = Union[int, float]
NodeInput = Union[Node, NumericData] NodeInput = Union[Node, NumericData]
CallableData = Union[Callable, Function]
ngraph_to_numpy_types_map = [ ngraph_to_numpy_types_map = [
(NgraphType.boolean, np.bool), (NgraphType.boolean, np.bool),
......
...@@ -18,6 +18,8 @@ import pytest ...@@ -18,6 +18,8 @@ import pytest
import ngraph as ng import ngraph as ng
from test.ngraph.util import run_op_node from test.ngraph.util import run_op_node
from ngraph.impl import Function, NodeVector, Shape
from ngraph.utils.types import get_element_type
@pytest.mark.parametrize('ng_api_helper, numpy_function, reduction_axes', [ @pytest.mark.parametrize('ng_api_helper, numpy_function, reduction_axes', [
...@@ -43,3 +45,49 @@ def test_reduction_ops(ng_api_helper, numpy_function, reduction_axes): ...@@ -43,3 +45,49 @@ def test_reduction_ops(ng_api_helper, numpy_function, reduction_axes):
expected = numpy_function(input_data, axis=reduction_axes) expected = numpy_function(input_data, axis=reduction_axes)
result = run_op_node([input_data], ng_api_helper, reduction_axes) result = run_op_node([input_data], ng_api_helper, reduction_axes)
assert np.allclose(result, expected) assert np.allclose(result, expected)
def test_reduce():
from functools import reduce
np.random.seed(133391)
# default reduce all axes
init_val = np.float32(0.)
input_data = np.random.randn(3, 4, 5).astype(np.float32)
expected = np.sum(input_data)
reduction_function_args = [init_val, ng.impl.op.Add]
result = run_op_node([input_data], ng.reduce, *reduction_function_args)
assert np.allclose(result, expected)
reduction_axes = (0, 2)
init_val = np.float32(0.)
input_data = np.random.randn(3, 4, 5).astype(np.float32)
expected = np.sum(input_data, axis=reduction_axes)
reduction_function_args = [init_val, ng.impl.op.Add, list(reduction_axes)]
result = run_op_node([input_data], ng.reduce, *reduction_function_args)
assert np.allclose(result, expected)
reduction_axes = (0, )
input_data = np.random.randn(100).astype(np.float32)
expected = reduce(lambda x, y: x - y, input_data, np.float32(0.))
reduction_function_args = [init_val, ng.impl.op.Subtract, list(reduction_axes)]
result = run_op_node([input_data], ng.reduce, *reduction_function_args)
assert np.allclose(result, expected)
reduction_axes = (0, )
input_data = np.random.randn(100).astype(np.float32)
expected = reduce(lambda x, y: x + y * y, input_data, np.float32(0.))
reduction_function_args = [init_val, lambda x, y: x + y * y, list(reduction_axes)]
result = run_op_node([input_data], ng.reduce, *reduction_function_args)
assert np.allclose(result, expected)
def custom_reduction_function(a, b):
return a + b * b
param1 = ng.impl.op.Parameter(get_element_type(np.float32), Shape([]))
param2 = ng.impl.op.Parameter(get_element_type(np.float32), Shape([]))
reduction_operation = Function(NodeVector([custom_reduction_function(param1, param2)]),
[param1, param2], 'reduction_op')
reduction_function_args = [init_val, reduction_operation, list(reduction_axes)]
result = run_op_node([input_data], ng.reduce, *reduction_function_args)
assert np.allclose(result, expected)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment