Commit bb06c80b authored by Pruthvi, committed by Robert Kimball

MKLDNN Softmax (#1113)

* 1. Added mkldnn support for Softmax
2. layout assignment for mkldnn softmax

* added assert to check softmax axis for mkldnn
parent b3f0a474
......@@ -4361,186 +4361,217 @@ namespace ngraph
// Emitter for ngraph::op::Softmax: writes, through `writer`, the code that computes
// softmax of args[0] into out[0].
//
// NOTE(review): this span is a diff hunk (see the "@@" line just above) rendered
// without +/- markers and without indentation, so removed and added lines appear
// interleaved. Many statements show up twice in a row and the braces do not pair up
// as displayed. Read it as a diff, not as a compilable function body.
//
// Two code paths are visible:
//   * an MKLDNN path, taken when use_mkldnn_kernel(node) is true, which registers a
//     softmax primitive and emits calls that bind buffers and invoke it;
//   * a fallback path that emits explicit nested loops (max, exp, sum, scale).
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Softmax)
{
writer.block_begin();
const ngraph::op::Softmax* softmax = static_cast<const ngraph::op::Softmax*>(node);
auto type = out[0].get_type();
auto shape = out[0].get_shape();
auto dims = out[0].get_shape().size();
auto axes = softmax->get_axes();
// create arg/out if 1d
if (dims < 1)
// MKLDNN path: used when the node is flagged to run on an MKLDNN kernel.
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
writer << type << "* arg = " << args[0].get_name() << "\n";
writer << type << "* out = " << out[0].get_name() << "\n";
auto softmax = static_cast<const ngraph::op::Softmax*>(node);
// Exactly one softmax axis is accepted on this path; anything else is an error.
if (softmax->get_axes().size() != 1)
{
throw ngraph_error("MKLDNN supports softmax only across single axis");
}
// The single axis is taken from the first (only) element of the axis set.
int softmax_axis = static_cast<int>(*(softmax->get_axes().begin()));
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_memory_descriptor(
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
// NOTE(review): result_desc is built from args[0] and the *input* format, exactly
// like input_desc. Confirm this is intentional rather than out[0] with the
// output format.
auto result_desc = mkldnn_emitter->build_memory_descriptor(
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
size_t softmax_index = mkldnn_emitter->build_softmax_forward(
input_desc, result_desc, softmax_axis);
// deps[0] is bound to the input buffer and deps[1] to the output buffer below.
auto& deps = mkldnn_emitter->get_primitive_deps(softmax_index);
// Emit the calls that set the memory pointers and then invoke the primitive.
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(softmax_index) << ");\n";
}
// else cast arg/out to an Nd array
else
{
std::string shape1toN;
for (size_t d = 1; d < dims; ++d)
writer.block_begin();
const ngraph::op::Softmax* softmax =
static_cast<const ngraph::op::Softmax*>(node);
auto type = out[0].get_type();
auto shape = out[0].get_shape();
auto dims = out[0].get_shape().size();
auto axes = softmax->get_axes();
// create arg/out if 1d
// NOTE(review): the condition is `dims < 1`, while the comment above says "1d".
// The two declarations emitted in this branch also end in "\n" with no ';',
// unlike the Nd branch below. Confirm both.
if (dims < 1)
{
shape1toN += "[";
shape1toN += std::to_string(shape[d]);
shape1toN += "]";
writer << type << "* arg = " << args[0].get_name() << "\n";
writer << type << "* out = " << out[0].get_name() << "\n";
}
// else cast arg/out to an Nd array
else
{
// Builds the "[shape[1]]...[shape[dims-1]]" suffix for the pointer-to-array casts.
std::string shape1toN;
for (size_t d = 1; d < dims; ++d)
{
shape1toN += "[";
shape1toN += std::to_string(shape[d]);
shape1toN += "]";
}
writer << type << " (*arg)" << shape1toN << " = (" << type << " (*)"
<< shape1toN << ") " << args[0].get_name() << ";\n";
writer << type << " (*out)" << shape1toN << " = (" << type << " (*)"
<< shape1toN << ") " << out[0].get_name() << ";\n";
}
writer << type << " (*arg)" << shape1toN << " = (" << type << " (*)"
<< shape1toN << ") " << args[0].get_name() << ";\n";
writer << type << " (*out)" << shape1toN << " = (" << type << " (*)"
<< shape1toN << ") " << out[0].get_name() << ";\n";
}
// build arg/out index
std::string index;
for (size_t d = 0; d < dims; ++d)
{
index += "[i";
index += std::to_string(d);
index += "]";
}
// build arg/out index
// `index` becomes "[i0][i1]...": one subscript per dimension, used on arg and out.
std::string index;
for (size_t d = 0; d < dims; ++d)
{
index += "[i";
index += std::to_string(d);
index += "]";
}
// calculate e ^ (arg - max)
// outer loop(s) - for axis not in axes
for (size_t d = 0; d < dims; ++d)
{
if (axes.find(d) == axes.end())
// calculate e ^ (arg - max)
// outer loop(s) - for axis not in axes
for (size_t d = 0; d < dims; ++d)
{
writer << "#pragma omp parallel for\n";
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer.block_begin();
if (axes.find(d) == axes.end())
{
writer << "#pragma omp parallel for\n";
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer.block_begin();
}
}
}
// max inner loop(s)
writer << type << " m = 0;\n"; // TODO: needs to be minval for the type
// max inner loop(s)
// The running maximum `m` is seeded with 0 in the emitted code (see TODO).
writer << type << " m = 0;\n"; // TODO: needs to be minval for the type
for (size_t d = 0; d < dims; ++d)
{
if (axes.find(d) != axes.end())
for (size_t d = 0; d < dims; ++d)
{
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer.block_begin();
if (axes.find(d) != axes.end())
{
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer.block_begin();
}
}
}
writer << "if (arg" << index << " > m)\n";
writer.block_begin();
writer << "m = arg" << index << ";\n";
writer.block_end();
writer << "if (arg" << index << " > m)\n";
writer.block_begin();
writer << "m = arg" << index << ";\n";
writer.block_end();
// end max inner loop(s)
for (size_t d = 0; d < dims; ++d)
{
if (axes.find(d) != axes.end())
// end max inner loop(s)
for (size_t d = 0; d < dims; ++d)
{
writer.block_end();
if (axes.find(d) != axes.end())
{
writer.block_end();
}
}
}
// e ^ (arg - max) inner loop
for (size_t d = 0; d < dims; ++d)
{
if (axes.find(d) != axes.end())
// e ^ (arg - max) inner loop
for (size_t d = 0; d < dims; ++d)
{
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer.block_begin();
if (axes.find(d) != axes.end())
{
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer.block_begin();
}
}
}
writer << "out" << index << " = exp(arg" << index << " - m);\n";
writer << "out" << index << " = exp(arg" << index << " - m);\n";
// end e ^ (arg - max) inner loop
for (size_t d = 0; d < dims; ++d)
{
if (axes.find(d) != axes.end())
// end e ^ (arg - max) inner loop
for (size_t d = 0; d < dims; ++d)
{
writer.block_end();
if (axes.find(d) != axes.end())
{
writer.block_end();
}
}
}
// end e ^ (arg - max) outer loop(s)
for (size_t d = 0; d < dims; ++d)
{
if (axes.find(d) == axes.end())
// end e ^ (arg - max) outer loop(s)
for (size_t d = 0; d < dims; ++d)
{
writer.block_end();
if (axes.find(d) == axes.end())
{
writer.block_end();
}
}
}
// calculate softmax = e ^ (arg - max) / sum (e ^ (arg - max))
// outer loop(s) - for axis not in axes
for (size_t d = 0; d < dims; ++d)
{
if (axes.find(d) == axes.end())
// calculate softmax = e ^ (arg - max) / sum (e ^ (arg - max))
// outer loop(s) - for axis not in axes
for (size_t d = 0; d < dims; ++d)
{
writer << "#pragma omp parallel for\n";
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer.block_begin();
if (axes.find(d) == axes.end())
{
writer << "#pragma omp parallel for\n";
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer.block_begin();
}
}
}
// sum (e ^ (arg - max) inner loop(s)
writer << type << " d = 0;\n";
// sum (e ^ (arg - max) inner loop(s)
// `d` in the emitted code accumulates the sum of the exponentials written to out.
writer << type << " d = 0;\n";
for (size_t d = 0; d < dims; ++d)
{
if (axes.find(d) != axes.end())
for (size_t d = 0; d < dims; ++d)
{
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer.block_begin();
if (axes.find(d) != axes.end())
{
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer.block_begin();
}
}
}
writer << "d += out" << index << ";\n";
writer << "d += out" << index << ";\n";
// end sum (e ^ (arg - max) inner loop(s)
for (size_t d = 0; d < dims; ++d)
{
if (axes.find(d) != axes.end())
// end sum (e ^ (arg - max) inner loop(s)
for (size_t d = 0; d < dims; ++d)
{
writer.block_end();
if (axes.find(d) != axes.end())
{
writer.block_end();
}
}
}
writer << "d = 1 / d;\n";
// The sum is inverted once so the final pass multiplies instead of dividing.
writer << "d = 1 / d;\n";
// softmax inner loop(s)
for (size_t d = 0; d < dims; ++d)
{
if (axes.find(d) != axes.end())
// softmax inner loop(s)
for (size_t d = 0; d < dims; ++d)
{
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer.block_begin();
if (axes.find(d) != axes.end())
{
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer.block_begin();
}
}
}
writer << "out" << index << " *= d;\n";
writer << "out" << index << " *= d;\n";
// end softmax inner loop(s)
for (size_t d = 0; d < dims; ++d)
{
if (axes.find(d) != axes.end())
// end softmax inner loop(s)
for (size_t d = 0; d < dims; ++d)
{
writer.block_end();
if (axes.find(d) != axes.end())
{
writer.block_end();
}
}
}
// end softmax outer loop(s)
for (size_t d = 0; d < dims; ++d)
{
if (axes.find(d) == axes.end())
// end softmax outer loop(s)
for (size_t d = 0; d < dims; ++d)
{
writer.block_end();
if (axes.find(d) == axes.end())
{
writer.block_end();
}
}
writer.block_end();
}
writer.block_end();
}
template <>
......
......@@ -911,3 +911,20 @@ size_t MKLDNNEmitter::build_concat(const std::vector<mkldnn::memory::desc>& inpu
m_primitive_deps[concat_index] = in_out_index;
return concat_index;
}
// Registers an MKLDNN softmax forward primitive and returns its index.
//
// input_desc   - memory descriptor for the source tensor; also used to describe
//                the softmax operation itself.
// result_desc  - memory descriptor for the destination tensor.
// softmax_axis - axis along which softmax is computed.
//
// The dependency list recorded for the returned index is {source memory index,
// destination memory index}, in that order.
size_t MKLDNNEmitter::build_softmax_forward(const mkldnn::memory::desc& input_desc,
                                            const mkldnn::memory::desc& result_desc,
                                            int softmax_axis)
{
    // Memory primitives are created first, source before destination.
    const size_t src_index = build_memory_primitive(input_desc);
    const size_t dst_index = build_memory_primitive(result_desc);

    // Operation descriptor, then primitive descriptor bound to the global CPU engine.
    const mkldnn::softmax_forward::desc fwd_desc(
        mkldnn::prop_kind::forward_scoring, input_desc, softmax_axis);
    const mkldnn::softmax_forward::primitive_desc fwd_prim_desc(
        fwd_desc, mkldnn_utils::global_cpu_engine);

    const size_t softmax_index = insert_primitive(new mkldnn::softmax_forward(
        fwd_prim_desc, *m_mkldnn_primitives[src_index], *m_mkldnn_primitives[dst_index]));

    m_primitive_deps[softmax_index] = {src_index, dst_index};
    return softmax_index;
}
......@@ -235,6 +235,10 @@ namespace ngraph
const mkldnn::memory::desc& result_desc,
const size_t concat_dim);
/// Builds an MKLDNN softmax forward primitive.
/// @param input_desc   memory descriptor of the softmax input
/// @param result_desc  memory descriptor of the softmax result
/// @param softmax_axis axis along which softmax is computed
/// @return index of the created primitive
size_t build_softmax_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
int softmax_axis);
private:
std::vector<mkldnn::primitive*> m_mkldnn_primitives;
std::vector<mkldnn::stream> m_mkldnn_streams;
......
......@@ -32,6 +32,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
......@@ -622,6 +623,26 @@ namespace ngraph
rnn_node->set_op_annotations(op_annotations);
}
}
// Decides whether an ngraph::op::Softmax node is marked as an MKLDNN op.
//
// The node is annotated with set_mkldnn_op(true) only when all of these hold:
//   - the input tensor has rank 2 or rank 4,
//   - the input element type is f32,
//   - softmax is taken over exactly one axis.
// In every other case no annotation is attached here.
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Softmax)
{
    auto softmax = static_cast<op::Softmax*>(node);

    // Only the rank is needed, so the shape is queried in place rather than copied.
    const auto arg0_rank = node->get_input_shape(0).size();

    if ((arg0_rank == 4 || arg0_rank == 2) &&
        node->get_input_element_type(0) == element::f32 &&
        softmax->get_axes().size() == 1)
    {
        auto op_annotations = std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
        op_annotations->set_mkldnn_op(true);
        softmax->set_op_annotations(op_annotations);
    }
}
}
}
}
......@@ -673,6 +694,7 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::SigmoidBackprop>},
{TI(ngraph::op::Lstm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Lstm>},
{TI(ngraph::op::Rnn), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Rnn>},
{TI(ngraph::op::Softmax), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Softmax>},
};
bool runtime::cpu::pass::CPUAssignment::run_on_call_graph(
......
......@@ -36,6 +36,7 @@
#include "ngraph/op/op.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
......@@ -1462,6 +1463,23 @@ namespace ngraph
throw ngraph_error("RNN fused op is only supported in MKLDNN for now.");
}
}
// Layout assignment for ngraph::op::Softmax.
//
// When the node runs on an MKLDNN kernel, its single output is given the same
// MKLDNN format as input 0. Otherwise the default layouts are applied.
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Softmax)
{
    // Non-MKLDNN nodes take the default layouts and are done.
    if (!runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
    {
        set_default_layouts(external_function, node);
        return;
    }

    // The output format mirrors the format of input 0.
    auto src_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);
    vector<memory::format> output_formats{src_format};
    set_output_layouts(node, output_formats);
}
}
}
}
......@@ -1515,6 +1533,7 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
&runtime::cpu::pass::CPULayout::layout<ngraph::op::SigmoidBackprop>},
{TI(ngraph::op::Lstm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Lstm>},
{TI(ngraph::op::Rnn), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Rnn>},
{TI(ngraph::op::Softmax), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Softmax>},
};
bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment