Commit 83b3a352 authored by tsocha's avatar tsocha Committed by Scott Cyphers

[Py] Enable Softmax op (#749)

parent c60a0e4b
...@@ -52,6 +52,7 @@ from ngraph.ops import parameter ...@@ -52,6 +52,7 @@ from ngraph.ops import parameter
from ngraph.ops import prod from ngraph.ops import prod
from ngraph.ops import reshape from ngraph.ops import reshape
from ngraph.ops import slice from ngraph.ops import slice
from ngraph.ops import softmax
from ngraph.ops import sqrt from ngraph.ops import sqrt
from ngraph.ops import subtract from ngraph.ops import subtract
from ngraph.ops import sum from ngraph.ops import sum
......
...@@ -23,7 +23,7 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, N ...@@ -23,7 +23,7 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, N
from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, Broadcast, Ceiling, Concat, \ from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, Broadcast, Ceiling, Concat, \
Constant, Convert, Convolution, Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq, Less, \ Constant, Convert, Convolution, Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq, Less, \
LessEq, Log, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, Parameter,\ LessEq, Log, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, Parameter,\
Product, Reshape, Slice, Sqrt, Subtract, Sum, Tanh Product, Reshape, Slice, Softmax, Sqrt, Subtract, Sum, Tanh
from typing import Iterable, List from typing import Iterable, List
...@@ -421,3 +421,11 @@ def concat(nodes, axis): # type: (List[Node], int) -> Node ...@@ -421,3 +421,11 @@ def concat(nodes, axis): # type: (List[Node], int) -> Node
:return: Return new node that is a concatenation of input nodes. :return: Return new node that is a concatenation of input nodes.
""" """
return Concat(NodeVector(nodes), axis) return Concat(NodeVector(nodes), axis)
@nameable_op
def softmax(node, axes): # type: (Node, Iterable[int]) -> Node
"""Softmax operation on input tensor."""
if type(axes) is not set:
axes = set(axes)
return Softmax(node, AxisSet(axes))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment