Commit c80a1076 authored by arogowie-intel's avatar arogowie-intel Committed by Nick Korovaiko

[Py] Add python wrapper for nGraph Reduce operation. (#827)

* Add python wrapper for nGraph Reduce operation.

- Add UT.

* Refactoring.

- Add UT case with default reduction on all axes.

* Extend `reduce` operation signature to also accept `Function` object.

- Add UT case.

* Fix formatting errors.
parent e7cf2662
......@@ -60,6 +60,7 @@ from ngraph.ops import power
from ngraph.ops import prod
from ngraph.ops import relu
from ngraph.ops import replace_slice
from ngraph.ops import reduce
from ngraph.ops import reshape
from ngraph.ops import reverse
from ngraph.ops import select
......
......@@ -17,15 +17,15 @@
"""Factory functions for all ngraph ops."""
import numpy as np
from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, NodeVector, \
Shape, Strides
from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Function, Node, \
NodeVector, Shape, Strides
from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, BatchNorm, Broadcast, Ceiling,\
from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, BatchNorm, Broadcast, Ceiling, \
Concat, Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, \
FunctionCall, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, Max, Maximum, MaxPool, \
Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Pad, Parameter, Product, Power, Relu, \
ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, Sqrt, Subtract, Sum, \
Tan, Tanh
Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Pad, Parameter, Product, Power, \
Reduce, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, Sqrt, \
Subtract, Sum, Tan, Tanh
from typing import Iterable, List
......@@ -34,7 +34,7 @@ from ngraph.utils.decorators import nameable_op, binary_op, unary_op
from ngraph.utils.input_validation import assert_list_of_ints
from ngraph.utils.reduction import get_reduction_axes
from ngraph.utils.types import NumericType, NumericData, TensorShape, make_constant_node, \
NodeInput
NodeInput, ScalarData, CallableData
from ngraph.utils.types import get_element_type
......@@ -634,6 +634,39 @@ def prod(node, reduction_axes=None, name=None):
return Product(node, AxisSet(get_reduction_axes(node, reduction_axes)))
@nameable_op
def reduce(node, # type: Node
initial_value, # type: ScalarData
reduction_function, # type: CallableData
reduction_axes=None, # type: List[int]
name=None, # type: str
):
# type: (...) -> Node
"""Perform general tensor reduction operation.
:param node: The node providing data for reduction operation.
:param initial_value: The initial value for reduction operation.
:param reduction_function: The function performing binary reduction operation or a nGraph
Function object. The operation must accept two nodes providing scalar
operands and return a node which produces a scalar result.
:param reduction_axes: The list of axes indices to be reduced. Default to reduce all axes.
:param name: The new name for output node.
:return: The node performing reduction operation with provided reduction node.
"""
if reduction_axes is None:
reduction_axes = list(range(len(node.shape)))
init_val_node = constant(initial_value)
if not isinstance(reduction_function, Function):
# wrap reduction function into Function object
param1 = Parameter(node.get_element_type(), Shape([]))
param2 = Parameter(node.get_element_type(), Shape([]))
reduction_operation = Function(NodeVector([reduction_function(param1, param2)]),
[param1, param2], 'reduction_operation')
else:
reduction_operation = reduction_function
return Reduce(node, init_val_node, reduction_operation, AxisSet(set(reduction_axes)))
# reshape ops
@nameable_op
def slice(node, lower_bounds, upper_bounds, strides=None, name=None):
......
......@@ -16,12 +16,12 @@
"""Functions related to converting between Python and numpy types and ngraph types."""
import logging
from typing import Union, List
from typing import Callable, Union, List
import numpy as np
from ngraph.impl import Type as NgraphType
from ngraph.impl import Node, Shape
from ngraph.impl import Function, Node, Shape
from ngraph.impl.op import Constant
from ngraph.exceptions import NgraphTypeError
......@@ -32,7 +32,9 @@ log = logging.getLogger(__file__)
TensorShape = List[int]
NumericData = Union[int, float, np.ndarray]
NumericType = Union[type, np.dtype]
ScalarData = Union[int, float]
NodeInput = Union[Node, NumericData]
CallableData = Union[Callable, Function]
ngraph_to_numpy_types_map = [
(NgraphType.boolean, np.bool),
......
......@@ -18,6 +18,8 @@ import pytest
import ngraph as ng
from test.ngraph.util import run_op_node
from ngraph.impl import Function, NodeVector, Shape
from ngraph.utils.types import get_element_type
@pytest.mark.parametrize('ng_api_helper, numpy_function, reduction_axes', [
......@@ -43,3 +45,49 @@ def test_reduction_ops(ng_api_helper, numpy_function, reduction_axes):
expected = numpy_function(input_data, axis=reduction_axes)
result = run_op_node([input_data], ng_api_helper, reduction_axes)
assert np.allclose(result, expected)
def test_reduce():
from functools import reduce
np.random.seed(133391)
# default reduce all axes
init_val = np.float32(0.)
input_data = np.random.randn(3, 4, 5).astype(np.float32)
expected = np.sum(input_data)
reduction_function_args = [init_val, ng.impl.op.Add]
result = run_op_node([input_data], ng.reduce, *reduction_function_args)
assert np.allclose(result, expected)
reduction_axes = (0, 2)
init_val = np.float32(0.)
input_data = np.random.randn(3, 4, 5).astype(np.float32)
expected = np.sum(input_data, axis=reduction_axes)
reduction_function_args = [init_val, ng.impl.op.Add, list(reduction_axes)]
result = run_op_node([input_data], ng.reduce, *reduction_function_args)
assert np.allclose(result, expected)
reduction_axes = (0, )
input_data = np.random.randn(100).astype(np.float32)
expected = reduce(lambda x, y: x - y, input_data, np.float32(0.))
reduction_function_args = [init_val, ng.impl.op.Subtract, list(reduction_axes)]
result = run_op_node([input_data], ng.reduce, *reduction_function_args)
assert np.allclose(result, expected)
reduction_axes = (0, )
input_data = np.random.randn(100).astype(np.float32)
expected = reduce(lambda x, y: x + y * y, input_data, np.float32(0.))
reduction_function_args = [init_val, lambda x, y: x + y * y, list(reduction_axes)]
result = run_op_node([input_data], ng.reduce, *reduction_function_args)
assert np.allclose(result, expected)
def custom_reduction_function(a, b):
return a + b * b
param1 = ng.impl.op.Parameter(get_element_type(np.float32), Shape([]))
param2 = ng.impl.op.Parameter(get_element_type(np.float32), Shape([]))
reduction_operation = Function(NodeVector([custom_reduction_function(param1, param2)]),
[param1, param2], 'reduction_op')
reduction_function_args = [init_val, reduction_operation, list(reduction_axes)]
result = run_op_node([input_data], ng.reduce, *reduction_function_args)
assert np.allclose(result, expected)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment