Commit 83b3a352 authored by tsocha's avatar tsocha Committed by Scott Cyphers

[Py] Enable Softmax op (#749)

parent c60a0e4b
......@@ -52,6 +52,7 @@ from ngraph.ops import parameter
from ngraph.ops import prod
from ngraph.ops import reshape
from ngraph.ops import slice
from ngraph.ops import softmax
from ngraph.ops import sqrt
from ngraph.ops import subtract
from ngraph.ops import sum
......
......@@ -23,7 +23,7 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, N
from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, Broadcast, Ceiling, Concat, \
Constant, Convert, Convolution, Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq, Less, \
LessEq, Log, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, Parameter,\
Product, Reshape, Slice, Sqrt, Subtract, Sum, Tanh
Product, Reshape, Slice, Softmax, Sqrt, Subtract, Sum, Tanh
from typing import Iterable, List
......@@ -421,3 +421,11 @@ def concat(nodes, axis): # type: (List[Node], int) -> Node
:return: Return new node that is a concatenation of input nodes.
"""
return Concat(NodeVector(nodes), axis)
@nameable_op
def softmax(node, axes): # type: (Node, Iterable[int]) -> Node
"""Softmax operation on input tensor."""
if type(axes) is not set:
axes = set(axes)
return Softmax(node, AxisSet(axes))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment