Commit 6f547fdb authored by arogowie-intel's avatar arogowie-intel Committed by Scott Cyphers

Add Slice and Concat operations to nGraph Python API. (#697)

* Remove unnecessary Optional type annotations.
parent c94dd26c
......@@ -22,6 +22,7 @@ from ngraph.ops import avg_pool
from ngraph.ops import broadcast
from ngraph.ops import ceiling
from ngraph.ops import ceiling as ceil
from ngraph.ops import concat
from ngraph.ops import constant
from ngraph.ops import convert
from ngraph.ops import convolution
......@@ -47,6 +48,7 @@ from ngraph.ops import not_equal
from ngraph.ops import parameter
from ngraph.ops import prod
from ngraph.ops import reshape
from ngraph.ops import slice
from ngraph.ops import sqrt
from ngraph.ops import subtract
from ngraph.ops import sum
......
......@@ -17,12 +17,13 @@
"""Factory functions for all ngraph ops."""
import numpy as np
from ngraph.impl import AxisSet, AxisVector, CoordinateDiff, Node, Shape, Strides
from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, NodeVector, \
Shape, Strides
from ngraph.impl.op import Abs, Add, AvgPool, Broadcast, Ceiling, Constant, Convert, Convolution, \
Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq, Less, LessEq, Log, Max, Maximum, MaxPool, \
Min, Minimum, Multiply, Negative, Not, NotEqual, Parameter, Product, Reshape, Sqrt, Subtract, \
Sum, Tanh
from ngraph.impl.op import Abs, Add, AvgPool, Broadcast, Ceiling, Concat, Constant, Convert, \
Convolution, Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq, Less, LessEq, Log, Max, \
Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, Parameter, Product, \
Reshape, Slice, Sqrt, Subtract, Sum, Tanh
from typing import Iterable, List, Optional
......@@ -305,7 +306,7 @@ def max_pool(x, # type: Node
# reduction ops
@nameable_op
def sum(node, reduction_axes=None, name=None):
# type: (Node, Optional[Iterable[int]], Optional[str]) -> Node
# type: (Node, Iterable[int], str) -> Node
"""Element-wise sums the input tensor, eliminating the specified reduction axes.
:param reduction_axes: The axes to eliminate through summation.
......@@ -316,7 +317,7 @@ def sum(node, reduction_axes=None, name=None):
@nameable_op
def max(node, reduction_axes=None, name=None):
# type: (Node, Optional[Iterable[int]], Optional[str]) -> Node
# type: (Node, Iterable[int], str) -> Node
"""Max-reduction operation on input tensor, eliminating the specified reduction axes.
:param node: The tensor we want to max-reduce.
......@@ -329,7 +330,7 @@ def max(node, reduction_axes=None, name=None):
@nameable_op
def min(node, reduction_axes=None, name=None):
# type: (Node, Optional[Iterable[int]], Optional[str]) -> Node
# type: (Node, Iterable[int], str) -> Node
"""Min-reduction operation on input tensor, eliminating the specified reduction axes.
:param node: The tensor we want to max-reduce.
......@@ -342,7 +343,7 @@ def min(node, reduction_axes=None, name=None):
@nameable_op
def prod(node, reduction_axes=None, name=None):
# type: (Node, Optional[Iterable[int]], Optional[str]) -> Node
# type: (Node, Iterable[int], str) -> Node
"""Product-reduction operation on input tensor, eliminating the specified reduction axes.
:param node: The tensor we want to product-reduce.
......@@ -351,3 +352,35 @@ def prod(node, reduction_axes=None, name=None):
"""
reduction_axes = get_reduction_axes(node, reduction_axes)
return Product(node, AxisSet(reduction_axes))
# reshape ops
@nameable_op
def slice(node, lower_bounds, upper_bounds, strides=None, name=None):
# type: (Node, List[int], List[int], List[int], str) -> Node
"""Take a slice of an input tensor, (sub-tensor) that resides within a bounding box.
Optionally this function may be provided with stride along each axis.
:param node: The tensor we want to slice.
:param lower_bounds: The (inclusive) lower-bound coordinates for the tensor slice.
:param upper_bounds: The (exclusive) upper-bound coordinates for the tensor slice.
:param strides: The strides for the tensor slice.
:param name: Optional name for the output node.
:return: Return node that represents a slice of input nodes data.
"""
if strides is None:
return Slice(node, Coordinate(lower_bounds), Coordinate(upper_bounds))
else:
return Slice(node, Coordinate(lower_bounds), Coordinate(upper_bounds), Strides(strides))
@nameable_op
def concat(nodes, axis): # type: (List[Node], int) -> Node
"""Concatenate input nodes into single new node along specified axis.
:param nodes: The nodes we want concatenate into single new node.
:param axis: The axis along which we want to concatenate input nodes.
:return: Return new node that is a concatenation of input nodes.
"""
return Concat(NodeVector(nodes), axis)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment