Commit 56bd183a authored by arogowie-intel's avatar arogowie-intel Committed by Michał Karzyński

Refactor type annotation for reduce parameter. (#870)

parent c6d1af4f
...@@ -27,14 +27,14 @@ from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, BatchNorm, Broad ...@@ -27,14 +27,14 @@ from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, BatchNorm, Broad
Reduce, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, Sqrt, \ Reduce, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, Sqrt, \
Subtract, Sum, Tan, Tanh Subtract, Sum, Tan, Tanh
from typing import Iterable, List from typing import Callable, Iterable, List, Union
from ngraph.utils.broadcasting import get_broadcast_axes from ngraph.utils.broadcasting import get_broadcast_axes
from ngraph.utils.decorators import nameable_op, binary_op, unary_op from ngraph.utils.decorators import nameable_op, binary_op, unary_op
from ngraph.utils.input_validation import assert_list_of_ints from ngraph.utils.input_validation import assert_list_of_ints
from ngraph.utils.reduction import get_reduction_axes from ngraph.utils.reduction import get_reduction_axes
from ngraph.utils.types import NumericType, NumericData, TensorShape, make_constant_node, \ from ngraph.utils.types import NumericType, NumericData, TensorShape, make_constant_node, \
NodeInput, ScalarData, CallableData NodeInput, ScalarData
from ngraph.utils.types import get_element_type from ngraph.utils.types import get_element_type
...@@ -637,7 +637,7 @@ def prod(node, reduction_axes=None, name=None): ...@@ -637,7 +637,7 @@ def prod(node, reduction_axes=None, name=None):
@nameable_op @nameable_op
def reduce(node, # type: Node def reduce(node, # type: Node
initial_value, # type: ScalarData initial_value, # type: ScalarData
reduction_function, # type: CallableData reduction_function, # type: Union[Callable, Function]
reduction_axes=None, # type: List[int] reduction_axes=None, # type: List[int]
name=None, # type: str name=None, # type: str
): ):
......
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
"""Functions related to converting between Python and numpy types and ngraph types.""" """Functions related to converting between Python and numpy types and ngraph types."""
import logging import logging
from typing import Callable, Union, List from typing import Union, List
import numpy as np import numpy as np
from ngraph.impl import Type as NgraphType from ngraph.impl import Type as NgraphType
from ngraph.impl import Function, Node, Shape from ngraph.impl import Node, Shape
from ngraph.impl.op import Constant from ngraph.impl.op import Constant
from ngraph.exceptions import NgraphTypeError from ngraph.exceptions import NgraphTypeError
...@@ -34,7 +34,6 @@ NumericData = Union[int, float, np.ndarray] ...@@ -34,7 +34,6 @@ NumericData = Union[int, float, np.ndarray]
NumericType = Union[type, np.dtype] NumericType = Union[type, np.dtype]
ScalarData = Union[int, float] ScalarData = Union[int, float]
NodeInput = Union[Node, NumericData] NodeInput = Union[Node, NumericData]
CallableData = Union[Callable, Function]
ngraph_to_numpy_types_map = [ ngraph_to_numpy_types_map = [
(NgraphType.boolean, np.bool), (NgraphType.boolean, np.bool),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment