Commit 5dcd835f authored by Chris Sullivan, committed by Robert Kimball

cuDNN Softmax implementation for all axis activation (#1045)

* cuDNN softmax impl. for all axis activation.

* Added catch for per-axis activations.
parent bff65fe3
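Editorial note on the cuDNN call this commit wraps: with `CUDNN_SOFTMAX_MODE_INSTANCE`, cuDNN normalizes over the C, H, and W dimensions of each of the N instances described by the tensor descriptor, so an all-axis softmax can be expressed as a single library call when the whole tensor is packed into one instance. Below is a minimal, self-contained sketch of `cudnnSoftmaxForward`; the sizes, the `CHECK_CUDNN` macro, and the buffer names are illustrative only and are not part of this commit.

```cpp
#include <cudnn.h>
#include <cuda_runtime.h>
#include <cstdio>

#define CHECK_CUDNN(call)                                                      \
    do {                                                                       \
        cudnnStatus_t s = (call);                                              \
        if (s != CUDNN_STATUS_SUCCESS)                                         \
            std::printf("cuDNN error: %s\n", cudnnGetErrorString(s));          \
    } while (0)

int main()
{
    cudnnHandle_t handle;
    CHECK_CUDNN(cudnnCreate(&handle));

    // Describe a 4D float tensor with N = 1. Under MODE_INSTANCE the softmax
    // then runs over all C*H*W = 24 elements, i.e. over every axis at once.
    cudnnTensorDescriptor_t desc;
    CHECK_CUDNN(cudnnCreateTensorDescriptor(&desc));
    CHECK_CUDNN(cudnnSetTensor4dDescriptor(
        desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, 2, 3, 4));

    float* x;
    float* y;
    cudaMalloc((void**)&x, 24 * sizeof(float));
    cudaMalloc((void**)&y, 24 * sizeof(float));

    // y = alpha * softmax(x) + beta * y
    float alpha = 1.0f, beta = 0.0f;
    CHECK_CUDNN(cudnnSoftmaxForward(handle,
                                    CUDNN_SOFTMAX_FAST,
                                    CUDNN_SOFTMAX_MODE_INSTANCE,
                                    &alpha, desc, x,
                                    &beta, desc, y));

    cudaFree(x);
    cudaFree(y);
    cudnnDestroyTensorDescriptor(desc);
    cudnnDestroy(handle);
    return 0;
}
```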
@@ -432,3 +432,69 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const runtime::gpu::GPURuntimeContext
    m_primitive_emitter->cache(hash, primitive_index);
    return primitive_index;
}
size_t runtime::gpu::CUDNNEmitter::build_softmax(const runtime::gpu::GPURuntimeContext* ctx,
                                                 const cudnnSoftmaxAlgorithm_t& algorithm,
                                                 const cudnnSoftmaxMode_t& mode,
                                                 const Prop& direction,
                                                 const Shape& tensor_shape)
{
    // construct hash to determine if kernel needs to be emitted
    // or if it already exists in the primitive list
    std::stringstream ss;
    ss << "softmax_op" << mode << "_alg" << algorithm << "_dir" << static_cast<int>(direction)
       << "_s" << join(tensor_shape, "_");
    std::string hash = ss.str();

    // check if the requested kernel is already an inserted primitive
    size_t primitive_index = m_primitive_emitter->lookup(hash);
    if (primitive_index != std::numeric_limits<size_t>::max())
    {
        return primitive_index;
    }

    auto& tensor_desc = tensor_descriptor_from_shape(tensor_shape);
    float alpha = 1.0, beta = 0.0;

    std::unique_ptr<runtime::gpu::primitive> softmax;
    switch (direction)
    {
    case Prop::Forward:
    case Prop::Inference:
    {
        softmax.reset(new gpu::primitive{[=, &tensor_desc](void** inputs, void** outputs) {
            CUDNN_SAFE_CALL(cudnnSoftmaxForward(*ctx->cudnn_handle,
                                                algorithm,
                                                mode,
                                                &alpha,
                                                tensor_desc,
                                                inputs[0],
                                                &beta,
                                                tensor_desc,
                                                outputs[0]));
        }});
        break;
    }
    case Prop::Backward:
    {
        softmax.reset(new gpu::primitive{[=, &tensor_desc](void** inputs, void** outputs) {
            CUDNN_SAFE_CALL(cudnnSoftmaxBackward(*ctx->cudnn_handle,
                                                 algorithm,
                                                 mode,
                                                 &alpha,
                                                 tensor_desc,
                                                 inputs[0],
                                                 tensor_desc,
                                                 inputs[1],
                                                 &beta,
                                                 tensor_desc,
                                                 outputs[0]));
        }});
        break;
    }
    }

    primitive_index = this->m_primitive_emitter->insert(std::move(softmax));
    m_primitive_emitter->cache(hash, primitive_index);
    return primitive_index;
}
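The `lookup`/`insert`/`cache` sequence above is the memoization idiom the GPU backend uses for its cuDNN primitives: the recipe (mode, algorithm, direction, shape) is serialized into a hash string, and the primitive is only built the first time that hash is seen. Here is a toy, self-contained re-creation of the pattern; the `PrimitiveEmitter` below is illustrative and does not match the real `GPUPrimitiveEmitter` API.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <limits>
#include <string>
#include <unordered_map>
#include <vector>

// Toy re-creation of the primitive-emitter caching pattern: primitives are
// stored once and re-found later by a string hash of their build recipe.
struct PrimitiveEmitter
{
    std::vector<std::function<void()>> primitives;
    std::unordered_map<std::string, std::size_t> index_by_hash;

    std::size_t lookup(const std::string& hash) const
    {
        auto it = index_by_hash.find(hash);
        return it == index_by_hash.end() ? std::numeric_limits<std::size_t>::max()
                                         : it->second;
    }
    std::size_t insert(std::function<void()> p)
    {
        primitives.push_back(std::move(p));
        return primitives.size() - 1;
    }
    void cache(const std::string& hash, std::size_t index)
    {
        index_by_hash[hash] = index;
    }
};

int main()
{
    PrimitiveEmitter emitter;
    auto build = [&](const std::string& recipe) {
        std::size_t idx = emitter.lookup(recipe);
        if (idx != std::numeric_limits<std::size_t>::max())
        {
            return idx; // already built: reuse the cached primitive
        }
        idx = emitter.insert([] { /* launch the kernel here */ });
        emitter.cache(recipe, idx);
        return idx;
    };
    // Identical recipes hash to the same index, so nothing is built twice.
    assert(build("softmax_op1_alg2_dir0_s4_10") ==
           build("softmax_op1_alg2_dir0_s4_10"));
}
```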
@@ -77,6 +77,12 @@ namespace ngraph
                                       const Shape& param_shape,
                                       double epsilon);

                size_t build_softmax(const runtime::gpu::GPURuntimeContext* ctx,
                                     const cudnnSoftmaxAlgorithm_t& algorithm,
                                     const cudnnSoftmaxMode_t& mode,
                                     const Prop& direction,
                                     const Shape& tensor_shape);

                cudnnTensorDescriptor_t& tensor_descriptor_from_shape(const Shape& shape);

            private:
@@ -1859,6 +1859,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
                }
                writer.block_end();
            }

            template <>
            void GPU_Emitter::EMITTER_DECL(ngraph::op::AvgPoolBackprop)
            {
@@ -1901,6 +1902,38 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
                }
                writer.block_end();
            }
            template <>
            void GPU_Emitter::EMITTER_DECL(ngraph::op::Softmax)
            {
                writer.block_begin(" // " + node->get_name());
                {
                    auto softmax = static_cast<const ngraph::op::Softmax*>(node);
                    auto tensor_shape = args[0].get_shape();
                    auto axes = softmax->get_axes();

                    // the cuDNN path only covers softmax over the full tensor;
                    // per-axis activations are caught here
                    if (axes.size() != tensor_shape.size())
                    {
                        throw std::runtime_error(
                            "Softmax implementation currently only supports all axis activation.");
                    }

                    auto& cudnn_emitter =
                        external_function->get_primitive_emitter()->get_cudnn_emitter();
                    size_t softmax_index =
                        cudnn_emitter->build_softmax(external_function->ctx().get(),
                                                     CUDNN_SOFTMAX_FAST,
                                                     CUDNN_SOFTMAX_MODE_INSTANCE,
                                                     CUDNNEmitter::Prop::Forward,
                                                     tensor_shape);

                    writer << "gpu::invoke_primitive(ctx, " << softmax_index << ", ";
                    writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), ";
                    writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
                    writer << ");\n";
                }
                writer.block_end();
            }
        }
    }
}
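For reference, the writer statements above emit roughly the following into the generated source for a node named `Softmax_3`; the primitive index and the buffer names are illustrative, since both are assigned per compiled function.

```cpp
{ // Softmax_3
    gpu::invoke_primitive(ctx, 7,
                          std::vector<void*>{Softmax_3_input}.data(),
                          std::vector<void*>{Softmax_3_output}.data());
}
```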
@@ -4,9 +4,7 @@ backwards_reverse_sequence_n4d2c3h2w2
backwards_reverse_sequence_n3_c2_h3
backwards_slice
backwards_softmax_3d
backwards_softmax_all
backwards_softmax_axis
backwards_softmax_underflow
batch_norm_one_output
batch_norm_three_outputs
broadcast_vector_rowwise_int64
@@ -68,7 +66,6 @@ scalar_constant_int64
select_and_scatter_3d_without_overlap
select_and_scatter_with_overlap
select_and_scatter_without_overlap
softmax_all
softmax_axis
softmax_underflow
tensor_constant