Commit 04e56c64 authored by Jaikrishnan Menon, committed by Scott Cyphers

Packetized Softmax variant with innermost reducers (#1438)

parent 09c1d3b1
@@ -85,17 +85,38 @@ namespace ngraph
             }
             else if (axes.size() == 1)
             {
-                std::function<decltype(runtime::cpu::kernel::softmax_1rd<float, 1>)> kernel;
-
-                PARTIAL_SELECT_KERNEL_BY_RANK(kernel,
-                                              args[0].get_element_type(),
-                                              args[0].get_shape().size(),
-                                              runtime::cpu::kernel::softmax_1rd);
-
-                auto functor = [&, kernel, arg_shape, axes](CPURuntimeContext* ctx) {
-                    kernel(arg_tensor, out_tensor, arg_shape, axes);
-                };
-                functors.emplace_back(functor);
+                if (*axes.begin() == (arg_shape.size() - 1))
+                {
+                    std::function<decltype(
+                        runtime::cpu::kernel::softmax_innermost_1rd<float, 1>)>
+                        kernel;
+
+                    PARTIAL_SELECT_KERNEL_BY_RANK(
+                        kernel,
+                        args[0].get_element_type(),
+                        args[0].get_shape().size(),
+                        runtime::cpu::kernel::softmax_innermost_1rd);
+
+                    auto functor = [&, kernel, arg_shape](CPURuntimeContext* ctx) {
+                        kernel(arg_tensor, out_tensor, arg_shape);
+                    };
+                    functors.emplace_back(functor);
+                }
+                else
+                {
+                    std::function<decltype(runtime::cpu::kernel::softmax_1rd<float, 1>)>
+                        kernel;
+
+                    PARTIAL_SELECT_KERNEL_BY_RANK(kernel,
+                                                  args[0].get_element_type(),
+                                                  args[0].get_shape().size(),
+                                                  runtime::cpu::kernel::softmax_1rd);
+
+                    auto functor = [&, kernel, arg_shape, axes](CPURuntimeContext* ctx) {
+                        kernel(arg_tensor, out_tensor, arg_shape, axes);
+                    };
+                    functors.emplace_back(functor);
+                }
             }
             else if (arg_shape.size() == 3 && axes.size() == 2)
             {
......
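Dispatch note: the builder change above routes a single-axis softmax to the new softmax_innermost_1rd kernel only when the reduced axis is the last (innermost, fastest-varying in row-major layout) dimension; any other single axis still falls back to the generic softmax_1rd. The innermost functor also stops capturing and passing axes, since the axis is implied by the kernel itself. A minimal sketch of the selection predicate (the helper name is hypothetical; the commit inlines this check):

#include <cstddef>
#include <set>

// Hypothetical helper mirroring the inlined check in the builder: the
// specialized kernel applies only when exactly one axis is reduced and
// that axis is the innermost dimension (rank - 1).
bool use_innermost_softmax(const std::set<std::size_t>& axes, std::size_t rank)
{
    return axes.size() == 1 && *axes.begin() == rank - 1;
}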
@@ -98,6 +98,38 @@ namespace ngraph
                         out * out.sum(axes).inverse().eval().reshape(rdims).broadcast(bcast);
                 }
 
+                template <typename ElementType, unsigned int Rank>
+                void softmax_innermost_1rd(void* input, void* output, const Shape& input_shape)
+                {
+                    Eigen::array<Eigen::Index, Rank> in_dims, rdims, bcast;
+                    Eigen::IndexList<Eigen::type2index<Rank - 1>> axis;
+
+                    rdims.fill(1);
+                    for (int i = 0; i < Rank; i++)
+                    {
+                        in_dims[i] = input_shape[i];
+                    }
+                    for (int i = 0; i < Rank - 1; i++)
+                    {
+                        rdims[i] = in_dims[i];
+                    }
+                    for (int i = 0; i < Rank; i++)
+                    {
+                        bcast[i] = in_dims[i] / rdims[i];
+                    }
+
+                    Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> out(
+                        static_cast<ElementType*>(output), in_dims),
+                        in(static_cast<ElementType*>(input), in_dims);
+
+                    out.device(eigen::global_thread_pool_device) =
+                        (in - in.maximum(axis).eval().reshape(rdims).broadcast(bcast)).exp();
+                    out.device(eigen::global_thread_pool_device) =
+                        out * out.sum(axis).inverse().eval().reshape(rdims).broadcast(bcast);
+                }
+
                 template <typename ElementType, unsigned int Rank>
                 void softmax_1rd(void* input,
                                  void* output,
......
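Kernel note: the "packetized" in the commit title comes from how the reduction axis is specified. The generic softmax_1rd receives its axis at run time, while softmax_innermost_1rd fixes it at compile time via Eigen::IndexList<Eigen::type2index<Rank - 1>>; with the axis statically known to be the contiguous innermost dimension, Eigen's reduction evaluators can take their vectorized (packet) path instead of the scalar one. A standalone sketch of the same pattern, runnable against Eigen's unsupported Tensor module (names are illustrative, not the nGraph API, and a default single-threaded device stands in for nGraph's global thread pool):

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

// Same structure as softmax_innermost_1rd above, minus the nGraph plumbing:
// the reduction axis (Rank - 1) is a compile-time constant, so Eigen can
// vectorize ("packetize") the max and sum reductions over the contiguous
// innermost dimension.
template <typename ElementType, int Rank>
void softmax_innermost(const ElementType* input,
                       ElementType* output,
                       const Eigen::array<Eigen::Index, Rank>& dims)
{
    Eigen::IndexList<Eigen::type2index<Rank - 1>> axis; // static reduction axis

    // Reduced dims: the innermost extent collapses to 1; broadcast restores it.
    Eigen::array<Eigen::Index, Rank> rdims, bcast;
    for (int i = 0; i < Rank; i++)
    {
        rdims[i] = (i == Rank - 1) ? 1 : dims[i];
        bcast[i] = dims[i] / rdims[i];
    }

    Eigen::TensorMap<Eigen::Tensor<const ElementType, Rank, Eigen::RowMajor>> in(input, dims);
    Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> out(output, dims);

    // Numerically stable softmax: shift by the row max, exponentiate,
    // then scale by the reciprocal of the row sum.
    out = (in - in.maximum(axis).eval().reshape(rdims).broadcast(bcast)).exp();
    out = out * out.sum(axis).inverse().eval().reshape(rdims).broadcast(bcast);
}

int main()
{
    float in[6] = {1.f, 2.f, 3.f, 1.f, 2.f, 3.f};
    float out[6];
    softmax_innermost<float, 2>(in, out, {2, 3});
    for (float v : out)
    {
        std::cout << v << ' '; // each row of three sums to 1
    }
    std::cout << '\n';
}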