Commit 6f547fdb authored by arogowie-intel's avatar arogowie-intel Committed by Scott Cyphers

Add Slice and Concat operations to nGraph Python API. (#697)

* Remove unnecessary Optional type annotations.
parent c94dd26c
...@@ -22,6 +22,7 @@ from ngraph.ops import avg_pool ...@@ -22,6 +22,7 @@ from ngraph.ops import avg_pool
from ngraph.ops import broadcast from ngraph.ops import broadcast
from ngraph.ops import ceiling from ngraph.ops import ceiling
from ngraph.ops import ceiling as ceil from ngraph.ops import ceiling as ceil
from ngraph.ops import concat
from ngraph.ops import constant from ngraph.ops import constant
from ngraph.ops import convert from ngraph.ops import convert
from ngraph.ops import convolution from ngraph.ops import convolution
...@@ -47,6 +48,7 @@ from ngraph.ops import not_equal ...@@ -47,6 +48,7 @@ from ngraph.ops import not_equal
from ngraph.ops import parameter from ngraph.ops import parameter
from ngraph.ops import prod from ngraph.ops import prod
from ngraph.ops import reshape from ngraph.ops import reshape
from ngraph.ops import slice
from ngraph.ops import sqrt from ngraph.ops import sqrt
from ngraph.ops import subtract from ngraph.ops import subtract
from ngraph.ops import sum from ngraph.ops import sum
......
...@@ -17,12 +17,13 @@ ...@@ -17,12 +17,13 @@
"""Factory functions for all ngraph ops.""" """Factory functions for all ngraph ops."""
import numpy as np import numpy as np
from ngraph.impl import AxisSet, AxisVector, CoordinateDiff, Node, Shape, Strides from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, NodeVector, \
Shape, Strides
from ngraph.impl.op import Abs, Add, AvgPool, Broadcast, Ceiling, Constant, Convert, Convolution, \ from ngraph.impl.op import Abs, Add, AvgPool, Broadcast, Ceiling, Concat, Constant, Convert, \
Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq, Less, LessEq, Log, Max, Maximum, MaxPool, \ Convolution, Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq, Less, LessEq, Log, Max, \
Min, Minimum, Multiply, Negative, Not, NotEqual, Parameter, Product, Reshape, Sqrt, Subtract, \ Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, Parameter, Product, \
Sum, Tanh Reshape, Slice, Sqrt, Subtract, Sum, Tanh
from typing import Iterable, List, Optional from typing import Iterable, List, Optional
...@@ -305,7 +306,7 @@ def max_pool(x, # type: Node ...@@ -305,7 +306,7 @@ def max_pool(x, # type: Node
# reduction ops # reduction ops
@nameable_op @nameable_op
def sum(node, reduction_axes=None, name=None): def sum(node, reduction_axes=None, name=None):
# type: (Node, Optional[Iterable[int]], Optional[str]) -> Node # type: (Node, Iterable[int], str) -> Node
"""Element-wise sums the input tensor, eliminating the specified reduction axes. """Element-wise sums the input tensor, eliminating the specified reduction axes.
:param reduction_axes: The axes to eliminate through summation. :param reduction_axes: The axes to eliminate through summation.
...@@ -316,7 +317,7 @@ def sum(node, reduction_axes=None, name=None): ...@@ -316,7 +317,7 @@ def sum(node, reduction_axes=None, name=None):
@nameable_op @nameable_op
def max(node, reduction_axes=None, name=None): def max(node, reduction_axes=None, name=None):
# type: (Node, Optional[Iterable[int]], Optional[str]) -> Node # type: (Node, Iterable[int], str) -> Node
"""Max-reduction operation on input tensor, eliminating the specified reduction axes. """Max-reduction operation on input tensor, eliminating the specified reduction axes.
:param node: The tensor we want to max-reduce. :param node: The tensor we want to max-reduce.
...@@ -329,7 +330,7 @@ def max(node, reduction_axes=None, name=None): ...@@ -329,7 +330,7 @@ def max(node, reduction_axes=None, name=None):
@nameable_op @nameable_op
def min(node, reduction_axes=None, name=None): def min(node, reduction_axes=None, name=None):
# type: (Node, Optional[Iterable[int]], Optional[str]) -> Node # type: (Node, Iterable[int], str) -> Node
"""Min-reduction operation on input tensor, eliminating the specified reduction axes. """Min-reduction operation on input tensor, eliminating the specified reduction axes.
:param node: The tensor we want to max-reduce. :param node: The tensor we want to max-reduce.
...@@ -342,7 +343,7 @@ def min(node, reduction_axes=None, name=None): ...@@ -342,7 +343,7 @@ def min(node, reduction_axes=None, name=None):
@nameable_op @nameable_op
def prod(node, reduction_axes=None, name=None): def prod(node, reduction_axes=None, name=None):
# type: (Node, Optional[Iterable[int]], Optional[str]) -> Node # type: (Node, Iterable[int], str) -> Node
"""Product-reduction operation on input tensor, eliminating the specified reduction axes. """Product-reduction operation on input tensor, eliminating the specified reduction axes.
:param node: The tensor we want to product-reduce. :param node: The tensor we want to product-reduce.
...@@ -351,3 +352,35 @@ def prod(node, reduction_axes=None, name=None): ...@@ -351,3 +352,35 @@ def prod(node, reduction_axes=None, name=None):
""" """
reduction_axes = get_reduction_axes(node, reduction_axes) reduction_axes = get_reduction_axes(node, reduction_axes)
return Product(node, AxisSet(reduction_axes)) return Product(node, AxisSet(reduction_axes))
# reshape ops
@nameable_op
def slice(node, lower_bounds, upper_bounds, strides=None, name=None):
# type: (Node, List[int], List[int], List[int], str) -> Node
"""Take a slice of an input tensor, (sub-tensor) that resides within a bounding box.
Optionally this function may be provided with stride along each axis.
:param node: The tensor we want to slice.
:param lower_bounds: The (inclusive) lower-bound coordinates for the tensor slice.
:param upper_bounds: The (exclusive) upper-bound coordinates for the tensor slice.
:param strides: The strides for the tensor slice.
:param name: Optional name for the output node.
:return: Return node that represents a slice of input nodes data.
"""
if strides is None:
return Slice(node, Coordinate(lower_bounds), Coordinate(upper_bounds))
else:
return Slice(node, Coordinate(lower_bounds), Coordinate(upper_bounds), Strides(strides))
@nameable_op
def concat(nodes, axis): # type: (List[Node], int) -> Node
"""Concatenate input nodes into single new node along specified axis.
:param nodes: The nodes we want concatenate into single new node.
:param axis: The axis along which we want to concatenate input nodes.
:return: Return new node that is a concatenation of input nodes.
"""
return Concat(NodeVector(nodes), axis)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment