Commit 9a4125ef authored by Jaikrishnan Menon's avatar Jaikrishnan Menon Committed by Scott Cyphers

Add Softmax variant for rank-3 with 2 reduction axes (#1360)

parent 59a2d4dd
...@@ -98,6 +98,19 @@ namespace ngraph ...@@ -98,6 +98,19 @@ namespace ngraph
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
else if (arg_shape.size() == 3 && axes.size() == 2)
{
std::function<decltype(runtime::cpu::kernel::softmax_3d_2rd<float>)> kernel;
SELECT_KERNEL(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::softmax_3d_2rd);
auto functor = [&, kernel, arg_shape, axes](CPURuntimeContext* ctx) {
kernel(arg_tensor, out_tensor, arg_shape, axes);
};
functors.emplace_back(functor);
}
else else
{ {
throw ngraph_error("Unsupported Softmax"); throw ngraph_error("Unsupported Softmax");
......
...@@ -106,6 +106,15 @@ namespace ngraph ...@@ -106,6 +106,15 @@ namespace ngraph
{ {
softmax<ElementType, Rank, 1>(input, output, input_shape, softmax_axes); softmax<ElementType, Rank, 1>(input, output, input_shape, softmax_axes);
} }
template <typename ElementType>
void softmax_3d_2rd(void* input,
void* output,
const Shape& input_shape,
const AxisSet& softmax_axes)
{
softmax<ElementType, 3, 2>(input, output, input_shape, softmax_axes);
}
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment