Commit bb06c80b authored by Pruthvi, committed by Robert Kimball

MKLDNN Softmax (#1113)

* 1. Added MKLDNN support for Softmax
  2. Added layout assignment for MKLDNN Softmax

* Added an assert to check the softmax axis for MKLDNN
parent b3f0a474
(One diff in this commit is collapsed and not shown below.)
@@ -911,3 +911,20 @@ size_t MKLDNNEmitter::build_concat(const std::vector<mkldnn::memory::desc>& inpu
    m_primitive_deps[concat_index] = in_out_index;
    return concat_index;
}

size_t MKLDNNEmitter::build_softmax_forward(const mkldnn::memory::desc& input_desc,
                                            const mkldnn::memory::desc& result_desc,
                                            int softmax_axis)
{
    size_t input_index = build_memory_primitive(input_desc);
    size_t result_index = build_memory_primitive(result_desc);

    size_t primitive_index = insert_primitive(
        new mkldnn::softmax_forward({{mkldnn::prop_kind::forward_scoring, input_desc, softmax_axis},
                                     mkldnn_utils::global_cpu_engine},
                                    *m_mkldnn_primitives[input_index],
                                    *m_mkldnn_primitives[result_index]));
    m_primitive_deps[primitive_index] = {input_index, result_index};
    return primitive_index;
}
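For reference, the primitive that build_softmax_forward wraps can also be built directly with the MKL-DNN v0.x C++ API. The standalone sketch below is illustrative only; the tensor shape, engine, and eager stream are assumptions and are not part of this commit.

// Minimal sketch (assumed MKL-DNN v0.x API): softmax over axis 1 of a 2x10 f32 tensor.
#include <mkldnn.hpp>

int main()
{
    using namespace mkldnn;

    engine cpu_engine(engine::cpu, 0);
    memory::desc md({2, 10}, memory::data_type::f32, memory::format::nc);

    // Allocate source and destination buffers on the CPU engine.
    memory src({md, cpu_engine});
    memory dst({md, cpu_engine});

    // Same construction pattern as build_softmax_forward:
    // op desc -> primitive desc (bound to the engine) -> primitive (bound to src/dst).
    softmax_forward::desc sm_desc(prop_kind::forward_scoring, md, /*softmax_axis=*/1);
    softmax_forward::primitive_desc sm_pd(sm_desc, cpu_engine);
    softmax_forward sm(sm_pd, src, dst);

    // Execute eagerly; in the nGraph CPU backend the emitted kernel drives execution instead.
    stream(stream::kind::eager).submit({sm}).wait();
    return 0;
}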
@@ -235,6 +235,10 @@ namespace ngraph
                                    const mkldnn::memory::desc& result_desc,
                                    const size_t concat_dim);

                size_t build_softmax_forward(const mkldnn::memory::desc& input_desc,
                                             const mkldnn::memory::desc& result_desc,
                                             int softmax_axis);

            private:
                std::vector<mkldnn::primitive*> m_mkldnn_primitives;
                std::vector<mkldnn::stream> m_mkldnn_streams;
@@ -32,6 +32,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
@@ -622,6 +623,26 @@ namespace ngraph
                        rnn_node->set_op_annotations(op_annotations);
                    }
                }

                template <>
                void CPUAssignment::ASSIGN_DECL(ngraph::op::Softmax)
                {
                    auto softmax = static_cast<op::Softmax*>(node);
                    auto arg0_shape = node->get_input_shape(0);
                    auto arg0_rank = arg0_shape.size();
                    auto result_shape = node->get_output_shape(0);

                    if ((arg0_rank == 4 || arg0_rank == 2) &&
                        node->get_input_element_type(0) == element::f32 &&
                        softmax->get_axes().size() == 1)
                    {
                        auto op_annotations =
                            std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
                        op_annotations->set_mkldnn_op(true);
                        softmax->set_op_annotations(op_annotations);
                    }
                }
            }
        }
    }
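As a hedged illustration of what this assignment predicate admits: a rank-2 or rank-4 f32 Softmax reducing over exactly one axis gets the MKLDNN annotation, while anything else falls through to the default kernel. The node-building calls below follow the nGraph op API of this period and are assumptions, not part of the diff.

// Illustrative only: which Softmax nodes the assignment pass would tag for MKLDNN.
#include "ngraph/ngraph.hpp"

using namespace ngraph;

int main()
{
    auto A = std::make_shared<op::Parameter>(element::f32, Shape{2, 10});
    // Rank 2, f32, one reduction axis -> eligible; op_annotations mark it as an MKLDNN op.
    auto softmax_mkldnn = std::make_shared<op::Softmax>(A, AxisSet{1});

    auto B = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5});
    // Rank 4 but two reduction axes -> get_axes().size() != 1, so the reference kernel is used.
    auto softmax_fallback = std::make_shared<op::Softmax>(B, AxisSet{2, 3});
    return 0;
}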
@@ -673,6 +694,7 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
     &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::SigmoidBackprop>},
    {TI(ngraph::op::Lstm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Lstm>},
    {TI(ngraph::op::Rnn), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Rnn>},
    {TI(ngraph::op::Softmax), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Softmax>},
};
bool runtime::cpu::pass::CPUAssignment::run_on_call_graph(
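The new dispatcher entry follows the existing pattern: s_dispatcher is a map keyed on the op's std::type_index (which is what the TI(...) entries produce) and routes each concrete op type to its templated handler at run time. Below is a minimal, self-contained sketch of that dispatch idiom; the type and map names are illustrative, not nGraph's.

// Sketch of the type_index-keyed dispatch table used for assignment and layout passes.
#include <functional>
#include <iostream>
#include <typeindex>
#include <unordered_map>

struct Node { virtual ~Node() = default; };
struct Softmax : Node {};
struct Relu : Node {};

using Handler = std::function<void(Node*)>;

static const std::unordered_map<std::type_index, Handler> s_example_dispatcher{
    {std::type_index(typeid(Softmax)), [](Node*) { std::cout << "assign Softmax\n"; }},
    {std::type_index(typeid(Relu)), [](Node*) { std::cout << "assign Relu\n"; }},
};

int main()
{
    Softmax softmax_node;
    Node* node = &softmax_node;

    // Look up the handler by the node's dynamic type, loosely mirroring how the
    // pass picks a handler for each node it visits.
    auto it = s_example_dispatcher.find(std::type_index(typeid(*node)));
    if (it != s_example_dispatcher.end())
    {
        it->second(node); // prints "assign Softmax"
    }
    return 0;
}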
@@ -36,6 +36,7 @@
#include "ngraph/op/op.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
@@ -1462,6 +1463,23 @@ namespace ngraph
                        throw ngraph_error("RNN fused op is only supported in MKLDNN for now.");
                    }
                }

                template <>
                void CPULayout::LAYOUT_DECL(ngraph::op::Softmax)
                {
                    if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
                    {
                        auto input_layout =
                            runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);

                        vector<memory::format> prim_output_formats;
                        prim_output_formats.push_back(input_layout);
                        set_output_layouts(node, prim_output_formats);
                    }
                    else
                    {
                        set_default_layouts(external_function, node);
                    }
                }
            }
        }
    }
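The layout rule is pure propagation: the MKLDNN format already chosen for the Softmax input becomes the required format of its output, so no reorder is inserted around the op. A hedged sketch of that idea in raw MKL-DNN v0.x terms follows; the shape and format are illustrative and the helper names are not nGraph's.

// Sketch only: reuse the input's memory format when describing the result tensor.
#include <mkldnn.hpp>

int main()
{
    using namespace mkldnn;

    memory::desc input_md({2, 10}, memory::data_type::f32, memory::format::nc);

    // Read back the input's format and describe the result with the same one,
    // which is the effect of pushing the input layout into set_output_layouts.
    auto propagated = static_cast<memory::format>(input_md.data.format);
    memory::desc result_md({2, 10}, memory::data_type::f32, propagated);
    return 0;
}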
@@ -1515,6 +1533,7 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
     &runtime::cpu::pass::CPULayout::layout<ngraph::op::SigmoidBackprop>},
    {TI(ngraph::op::Lstm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Lstm>},
    {TI(ngraph::op::Rnn), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Rnn>},
    {TI(ngraph::op::Softmax), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Softmax>},
};
bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)