Unverified Commit 8520e846 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by GitHub

Merge pull request #606 from NervanaSystems/jbobba/maxpool-layouts

Add mkldnn layouts to Maxpool and Maxpoolbackprop
parents 89da71d3 f521db20
...@@ -146,6 +146,7 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context() ...@@ -146,6 +146,7 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context()
} }
const auto& mkldnn_emitter = m_external_function->get_mkldnn_emitter(); const auto& mkldnn_emitter = m_external_function->get_mkldnn_emitter();
ctx->mkldnn_primitives = mkldnn_emitter->get_mkldnn_primitives().data(); ctx->mkldnn_primitives = mkldnn_emitter->get_mkldnn_primitives().data();
ctx->mkldnn_workspaces = mkldnn_emitter->get_mkldnn_workspaces().data();
} }
void runtime::cpu::CPU_CallFrame::cleanup_runtime_context() void runtime::cpu::CPU_CallFrame::cleanup_runtime_context()
......
...@@ -2883,59 +2883,31 @@ namespace ngraph ...@@ -2883,59 +2883,31 @@ namespace ngraph
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string( auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
args[0].get_element_type()); auto diff_dst_desc = mkldnn_emitter->build_memory_descriptor(
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
auto diff_src_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
auto input_format = size_t avg_pool_index = mkldnn_emitter->build_pooling_backward(
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0); (apb->get_include_padding_in_avg_computation()
auto result_format = ? mkldnn::algorithm::pooling_avg_include_padding
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0); : mkldnn::algorithm::pooling_avg_exclude_padding),
diff_dst_desc,
diff_src_desc,
apb->get_window_movement_strides(),
apb->get_window_shape(),
apb->get_padding_below(),
apb->get_padding_above());
writer << "{\n"; auto& deps = mkldnn_emitter->get_primitive_deps(avg_pool_index);
writer.indent++; writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << out[0].get_name() << ");\n";
writer << "engine cpu_engine = engine(engine::cpu, 0);\n"; writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
writer << "memory::desc input_data_desc = memory::desc({" << join(delta_shape) << to_string(avg_pool_index) << ");\n";
<< "}, " << et << ", "
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(input_format)
<< ");\n";
writer << "memory::desc result_desc = memory::desc({" << join(out_shape)
<< "}, " << et << ", "
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(result_format)
<< ");\n";
writer << "memory input_data = memory({input_data_desc, cpu_engine}, "
<< args[0].get_name() << ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, "
<< out[0].get_name() << ");\n";
// Dummy forward primitive descriptor to keep MKLDNN happy
const char* algorithm_enumerator =
apb->get_include_padding_in_avg_computation()
? "algorithm::pooling_avg_include_padding"
: "algorithm::pooling_avg_exclude_padding";
writer << "pooling_forward::primitive_desc fwd_pd = "
"pooling_forward::primitive_desc("
<< "{prop_kind::forward, " << algorithm_enumerator << ", "
<< "result_desc, input_data_desc, {"
<< join(apb->get_window_movement_strides()) << "}, {"
<< join(apb->get_window_shape()) << "}, "
<< "{" << join(apb->get_padding_below()) << "}, "
<< "{" << join(apb->get_padding_above()) << "}, "
<< "padding_kind::zero}, cpu_engine);\n";
writer
<< "auto avg_pooling = pooling_backward(pooling_backward::primitive_desc("
<< "pooling_backward::desc(" << algorithm_enumerator << ", "
<< "result_desc, input_data_desc, {"
<< join(apb->get_window_movement_strides()) << "}, {"
<< join(apb->get_window_shape()) << "}, "
<< "{" << join(apb->get_padding_below()) << "}, "
<< "{" << join(apb->get_padding_above()) << "}, "
<< "padding_kind::zero), cpu_engine, fwd_pd), "
<< "input_data, result);\n";
writer << "auto s = stream(stream::kind::eager);\n"
<< "s.submit({avg_pooling}).wait();\n";
writer.indent--;
writer << "}\n";
} }
else else
{ {
...@@ -2963,79 +2935,48 @@ namespace ngraph ...@@ -2963,79 +2935,48 @@ namespace ngraph
auto mpb = static_cast<const ngraph::op::MaxPoolBackprop*>(node); auto mpb = static_cast<const ngraph::op::MaxPoolBackprop*>(node);
auto delta_shape = args[1].get_shape(); auto delta_shape = args[1].get_shape();
auto delta_rank = delta_shape.size();
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
if (delta_rank == 4 && mpb->get_window_shape().size() == 2 && if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
args[0].get_element_type() == element::f32)
{ {
const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string( auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
args[1].get_element_type()); auto fprop_src_desc = mkldnn_emitter->build_memory_descriptor(
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
auto diff_dst_desc = mkldnn_emitter->build_memory_descriptor(
args[1], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1));
auto diff_src_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
writer << "{\n"; size_t max_pool_index = mkldnn_emitter->build_max_pooling_backward(
writer.indent++; mkldnn::algorithm::pooling_max,
writer << "engine cpu_engine = engine(engine::cpu, 0);\n"; fprop_src_desc,
writer << "memory::desc input_data_desc = memory::desc({" << join(delta_shape) diff_dst_desc,
<< "}, " << et << ", memory::format::nchw);\n"; diff_src_desc,
writer << "memory::desc result_desc = memory::desc({" << join(out_shape) mpb->get_window_movement_strides(),
<< "}, " << et << ", memory::format::nchw);\n"; mpb->get_window_shape(),
writer << "memory input_data = memory({input_data_desc, cpu_engine}, " mpb->get_padding_below(),
<< args[1].get_name() << ");\n"; mpb->get_padding_above());
writer << "memory result = memory({result_desc, cpu_engine}, "
<< out[0].get_name() << ");\n"; auto& fdeps = mkldnn_emitter->get_primitive_deps(max_pool_index - 1);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(fdeps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(fdeps[1])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(fdeps[2])
<< ", ctx->mkldnn_workspaces[" << fdeps[3] << "]);\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(max_pool_index - 1) << ");\n";
//---------------------------------------------------------------------------------------------- auto& bdeps = mkldnn_emitter->get_primitive_deps(max_pool_index);
// create a forward primitive_desc, use this to query the workspace writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(bdeps[0])
// TODO: (pruthvi) this is a workaround, till we maintain a global context to refer to the corrosponding << ", " << args[1].get_name() << ");\n";
// MKLDNN fprop kernel. this impacts performance writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(bdeps[1])
writer << "memory::desc max_pool_input_desc = memory::desc({" << ", ctx->mkldnn_workspaces[" << bdeps[3] << "]);\n";
<< join(args[0].get_shape()) << "}, " << et writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(bdeps[2])
<< ", memory::format::nchw);\n"; << ", " << out[0].get_name() << ");\n";
writer << "memory::desc max_pool_result_desc = memory::desc({"
<< join(args[1].get_shape()) << "}, " << et
<< ", memory::format::nchw);\n";
writer
<< "memory maxpool_input_data = memory({max_pool_input_desc, cpu_engine}, "
<< args[0].get_name() << ");\n";
writer << "memory maxpool_result = memory({max_pool_result_desc, cpu_engine}, "
<< out[0].get_name() << ");\n";
writer << "pooling_forward::primitive_desc pool_fwd_pd = "
"pooling_forward::primitive_desc("
<< "{prop_kind::forward, algorithm::pooling_max, "
<< "max_pool_input_desc, max_pool_result_desc, {"
<< join(mpb->get_window_movement_strides()) << "}, {"
<< join(mpb->get_window_shape()) << "}, "
<< "{" << join(mpb->get_padding_below()) << "}, "
<< "{" << join(mpb->get_padding_above()) << "}, "
<< "padding_kind::zero}, cpu_engine);\n";
// query the workspace from the forward primitive desc and allocates memory
writer << "auto max_pool_workspace_memory = "
"memory(pool_fwd_pd.workspace_primitive_desc());\n";
//run fprop with this workspace attached
writer << "pooling_forward max_pooling_fwd = pooling_forward("
<< "pool_fwd_pd, maxpool_input_data, maxpool_result, "
"max_pool_workspace_memory);\n";
writer << "stream s_fprop = stream(stream::kind::eager);\n"
<< "s_fprop.submit({max_pooling_fwd}).wait();\n";
//---------------------------------------------------------------------------------------------
writer << "auto max_pooling_bwd = "
"pooling_backward(pooling_backward::primitive_desc("
<< "pooling_backward::desc(algorithm::pooling_max, "
<< "result_desc, input_data_desc, {"
<< join(mpb->get_window_movement_strides()) << "}, {"
<< join(mpb->get_window_shape()) << "}, "
<< "{" << join(mpb->get_padding_below()) << "}, "
<< "{" << join(mpb->get_padding_above()) << "}, "
<< "padding_kind::zero), cpu_engine, pool_fwd_pd), "
<< "input_data, max_pool_workspace_memory, result);\n";
writer << "auto s_bwd = stream(stream::kind::eager);\n"
<< "s_bwd.submit({max_pooling_bwd}).wait();\n";
writer.indent--; writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
writer << "}\n"; << to_string(max_pool_index) << ");\n";
} }
else else
{ {
......
...@@ -37,6 +37,7 @@ namespace ngraph ...@@ -37,6 +37,7 @@ namespace ngraph
{ {
int64_t* op_durations; int64_t* op_durations;
mkldnn::primitive* const* mkldnn_primitives; mkldnn::primitive* const* mkldnn_primitives;
char* const* mkldnn_workspaces;
}; };
} }
} }
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp" #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp" #include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace ngraph::runtime::cpu; using namespace ngraph::runtime::cpu;
...@@ -36,12 +37,24 @@ const std::vector<mkldnn::primitive*>& MKLDNNEmitter::get_mkldnn_primitives() co ...@@ -36,12 +37,24 @@ const std::vector<mkldnn::primitive*>& MKLDNNEmitter::get_mkldnn_primitives() co
return m_mkldnn_primitives; return m_mkldnn_primitives;
} }
const std::vector<char*>& MKLDNNEmitter::get_mkldnn_workspaces()
{
return m_workspace_bufs;
}
size_t MKLDNNEmitter::insert_primitive(mkldnn::primitive* primitive) size_t MKLDNNEmitter::insert_primitive(mkldnn::primitive* primitive)
{ {
m_mkldnn_primitives.emplace_back(primitive); m_mkldnn_primitives.emplace_back(primitive);
return (m_mkldnn_primitives.size() - 1); return (m_mkldnn_primitives.size() - 1);
} }
size_t MKLDNNEmitter::insert_workspace(std::unique_ptr<MKLDNNWorkspace>& workspace)
{
m_workspace_bufs.push_back(workspace.get()->buf);
m_workspaces.push_back(std::move(workspace));
return (m_workspaces.size() - 1);
}
const std::vector<size_t>& MKLDNNEmitter::get_primitive_deps(size_t index) const const std::vector<size_t>& MKLDNNEmitter::get_primitive_deps(size_t index) const
{ {
return m_primitive_deps.at(index); return m_primitive_deps.at(index);
...@@ -321,6 +334,105 @@ size_t MKLDNNEmitter::build_pooling_forward(mkldnn::algorithm pooling_algorithm, ...@@ -321,6 +334,105 @@ size_t MKLDNNEmitter::build_pooling_forward(mkldnn::algorithm pooling_algorithm,
return primitive_index; return primitive_index;
} }
size_t MKLDNNEmitter::build_pooling_backward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& diff_dst_desc,
const mkldnn::memory::desc& diff_src_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above)
{
size_t input_index = build_memory_primitive(diff_dst_desc);
size_t result_index = build_memory_primitive(diff_src_desc);
size_t primitive_index = insert_primitive(new mkldnn::pooling_backward(
{{pooling_algorithm,
diff_src_desc,
diff_dst_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
mkldnn::memory::dims(window_shape.begin(), window_shape.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine,
{{mkldnn::prop_kind::forward_training,
pooling_algorithm,
diff_src_desc,
diff_dst_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
mkldnn::memory::dims(window_shape.begin(), window_shape.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine}},
*m_mkldnn_primitives[input_index],
*m_mkldnn_primitives[result_index]));
m_primitive_deps[primitive_index] = {input_index, result_index};
return primitive_index;
}
size_t MKLDNNEmitter::build_max_pooling_backward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& fprop_src_desc,
const mkldnn::memory::desc& diff_dst_desc,
const mkldnn::memory::desc& diff_src_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above)
{
size_t fprop_src_index = build_memory_primitive(fprop_src_desc);
size_t diff_dst_index = build_memory_primitive(diff_dst_desc);
size_t diff_src_index = build_memory_primitive(diff_src_desc);
mkldnn::pooling_forward::primitive_desc fwd_pd{
{mkldnn::prop_kind::forward_training,
pooling_algorithm,
diff_src_desc,
diff_dst_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
mkldnn::memory::dims(window_shape.begin(), window_shape.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine};
auto ws_index = build_memory_primitive(fwd_pd.workspace_primitive_desc().desc());
// Allocate workspace
// TODO (jbobba): Might need to align memory
auto ws = std::unique_ptr<MKLDNNWorkspace>(
new MKLDNNWorkspace(fwd_pd.workspace_primitive_desc().get_size()));
auto ws_buf_index = insert_workspace(ws);
size_t fwd_primitive_index = insert_primitive(new mkldnn::pooling_forward(
fwd_pd,
*m_mkldnn_primitives[fprop_src_index],
*m_mkldnn_primitives
[diff_src_index], // HACK - Uses diff_src buffer. Safe since diff_src > fprop_dst
*m_mkldnn_primitives[ws_index]));
size_t bwd_primitive_index = insert_primitive(new mkldnn::pooling_backward(
{{pooling_algorithm,
diff_src_desc,
diff_dst_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
mkldnn::memory::dims(window_shape.begin(), window_shape.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine,
fwd_pd},
*m_mkldnn_primitives[diff_dst_index],
*m_mkldnn_primitives[ws_index],
*m_mkldnn_primitives[diff_src_index]));
m_primitive_deps[fwd_primitive_index] = {
fprop_src_index, diff_src_index, ws_index, ws_buf_index};
m_primitive_deps[bwd_primitive_index] = {
diff_dst_index, ws_index, diff_src_index, ws_buf_index};
return bwd_primitive_index;
}
size_t MKLDNNEmitter::build_reorder(const mkldnn::memory::desc& input_desc, size_t MKLDNNEmitter::build_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc) const mkldnn::memory::desc& result_desc)
{ {
......
...@@ -35,6 +35,14 @@ namespace ngraph ...@@ -35,6 +35,14 @@ namespace ngraph
class CPU_ExternalFunction; class CPU_ExternalFunction;
class TensorViewWrapper; class TensorViewWrapper;
class MKLDNNWorkspace
{
public:
MKLDNNWorkspace(size_t size) { buf = reinterpret_cast<char*>(malloc(size)); }
~MKLDNNWorkspace() { free(buf); }
char* buf;
};
class MKLDNNEmitter class MKLDNNEmitter
{ {
public: public:
...@@ -42,8 +50,10 @@ namespace ngraph ...@@ -42,8 +50,10 @@ namespace ngraph
~MKLDNNEmitter(); ~MKLDNNEmitter();
const std::vector<mkldnn::primitive*>& get_mkldnn_primitives() const; const std::vector<mkldnn::primitive*>& get_mkldnn_primitives() const;
const std::vector<char*>& get_mkldnn_workspaces();
size_t insert_primitive(mkldnn::primitive* primitive); size_t insert_primitive(mkldnn::primitive* primitive);
size_t insert_workspace(std::unique_ptr<MKLDNNWorkspace>& workspace);
const std::vector<size_t>& get_primitive_deps(size_t index) const; const std::vector<size_t>& get_primitive_deps(size_t index) const;
// TODO(jmenon): Get rid of TensorViewWrappers at some point // TODO(jmenon): Get rid of TensorViewWrappers at some point
...@@ -109,6 +119,23 @@ namespace ngraph ...@@ -109,6 +119,23 @@ namespace ngraph
const ngraph::Shape& padding_below, const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above); const ngraph::Shape& padding_above);
size_t build_pooling_backward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& diff_dst_desc,
const mkldnn::memory::desc& diff_src_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
size_t build_max_pooling_backward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& fprop_src_desc,
const mkldnn::memory::desc& diff_dst_desc,
const mkldnn::memory::desc& diff_src_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
size_t build_reorder(const mkldnn::memory::desc& input_desc, size_t build_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc); const mkldnn::memory::desc& result_desc);
...@@ -129,6 +156,8 @@ namespace ngraph ...@@ -129,6 +156,8 @@ namespace ngraph
std::vector<mkldnn::primitive*> m_mkldnn_primitives; std::vector<mkldnn::primitive*> m_mkldnn_primitives;
std::vector<mkldnn::stream> m_mkldnn_streams; std::vector<mkldnn::stream> m_mkldnn_streams;
std::unordered_map<size_t, std::vector<size_t>> m_primitive_deps; std::unordered_map<size_t, std::vector<size_t>> m_primitive_deps;
std::vector<std::unique_ptr<MKLDNNWorkspace>> m_workspaces;
std::vector<char*> m_workspace_bufs;
}; };
} }
} }
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "ngraph/ops/avg_pool.hpp" #include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp" #include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp" #include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/max_pool.hpp"
#include "ngraph/ops/relu.hpp" #include "ngraph/ops/relu.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp" #include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
...@@ -245,10 +246,48 @@ namespace ngraph ...@@ -245,10 +246,48 @@ namespace ngraph
} }
} }
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPool)
{
auto max_pool = static_cast<op::MaxPool*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
auto result_shape = node->get_output_shape(0);
if (arg0_rank == 4 && max_pool->get_window_shape().size() == 2 &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
max_pool->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPoolBackprop)
{
auto max_pool = static_cast<op::MaxPoolBackprop*>(node);
auto arg1_shape = node->get_input_shape(1);
auto arg1_rank = arg1_shape.size();
auto result_shape = node->get_output_shape(0);
if (arg1_rank == 4 && max_pool->get_window_shape().size() == 2 &&
node->get_input_element_type(1) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
max_pool->set_op_annotations(op_annotations);
}
}
template <> template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Relu) void CPUAssignment::ASSIGN_DECL(ngraph::op::Relu)
{ {
auto avg_pool = static_cast<op::Relu*>(node); auto relu = static_cast<op::Relu*>(node);
auto arg0_shape = node->get_input_shape(0); auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size(); auto arg0_rank = arg0_shape.size();
...@@ -260,7 +299,7 @@ namespace ngraph ...@@ -260,7 +299,7 @@ namespace ngraph
auto op_annotations = auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>(); std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true); op_annotations->set_mkldnn_op(true);
avg_pool->set_op_annotations(op_annotations); relu->set_op_annotations(op_annotations);
} }
} }
...@@ -280,18 +319,19 @@ namespace ngraph ...@@ -280,18 +319,19 @@ namespace ngraph
template <> template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ReluBackprop) void CPUAssignment::ASSIGN_DECL(ngraph::op::ReluBackprop)
{ {
auto avg_pool = static_cast<op::ReluBackprop*>(node); auto relu_bprop = static_cast<op::ReluBackprop*>(node);
auto arg0_shape = node->get_input_shape(0); auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size(); auto arg0_rank = arg0_shape.size();
auto result_shape = node->get_output_shape(0); auto result_shape = node->get_output_shape(0);
if (arg0_rank == 4 && node->get_input_element_type(0) == element::f32) if ((arg0_rank == 4 || arg0_rank == 2) &&
node->get_input_element_type(0) == element::f32)
{ {
auto op_annotations = auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>(); std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true); op_annotations->set_mkldnn_op(true);
avg_pool->set_op_annotations(op_annotations); relu_bprop->set_op_annotations(op_annotations);
} }
} }
...@@ -313,6 +353,9 @@ namespace ngraph ...@@ -313,6 +353,9 @@ namespace ngraph
static const runtime::cpu::pass::AssignOpMap s_dispatcher{ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::Add), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Add>}, {TI(ngraph::op::Add), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Add>},
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNorm>}, {TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNorm>},
{TI(ngraph::op::Convolution), {TI(ngraph::op::Convolution),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>}, &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>},
...@@ -320,13 +363,13 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{ ...@@ -320,13 +363,13 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBackpropData>}, &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBackpropData>},
{TI(ngraph::op::ConvolutionBackpropFilters), {TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBackpropFilters>}, &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBackpropFilters>},
{TI(ngraph::op::MaxPool), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::MaxPool>},
{TI(ngraph::op::MaxPoolBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::MaxPoolBackprop>},
{TI(ngraph::op::ConvolutionBias), {TI(ngraph::op::ConvolutionBias),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBias>}, &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBias>},
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias), {TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBiasBackpropFiltersBias>}, &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBiasBackpropFiltersBias>},
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::Relu), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Relu>}, {TI(ngraph::op::Relu), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Relu>},
{TI(ngraph::op::ReluBackprop), {TI(ngraph::op::ReluBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ReluBackprop>}, &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ReluBackprop>},
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "ngraph/ops/batch_norm.hpp" #include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp" #include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/get_output_element.hpp" #include "ngraph/ops/get_output_element.hpp"
#include "ngraph/ops/max_pool.hpp"
#include "ngraph/ops/op.hpp" #include "ngraph/ops/op.hpp"
#include "ngraph/ops/relu.hpp" #include "ngraph/ops/relu.hpp"
#include "ngraph/ops/result.hpp" #include "ngraph/ops/result.hpp"
...@@ -645,11 +646,8 @@ namespace ngraph ...@@ -645,11 +646,8 @@ namespace ngraph
} }
catch (const mkldnn::error& e) catch (const mkldnn::error& e)
{ {
// TODO (jbobba): Check with MKLDNN folks if this is necessary
throw ngraph_error("MKLDNN Unsupported pooling layout" + throw ngraph_error("MKLDNN Unsupported pooling layout" +
to_string(input_layout) + e.message); to_string(input_layout) + e.message);
// prim_input_formats.push_back(memory::format::nchw);
// prim_output_formats.push_back(memory::format::nchw);
} }
node = node =
...@@ -732,11 +730,169 @@ namespace ngraph ...@@ -732,11 +730,169 @@ namespace ngraph
} }
catch (const mkldnn::error& e) catch (const mkldnn::error& e)
{ {
// TODO (jbobba): Check with MKLDNN folks if this is necessary
throw ngraph_error("MKLDNN Unsupported pooling layout" + throw ngraph_error("MKLDNN Unsupported pooling layout" +
to_string(input_layout) + e.message); to_string(input_layout) + e.message);
// prim_input_formats.push_back(memory::format::nchw); }
// prim_output_formats.push_back(memory::format::nchw);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::MaxPool)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto max_pool = static_cast<const ngraph::op::MaxPool*>(node.get());
auto arg0_shape = node->get_input_shape(0);
auto result_shape = node->get_output_shape(0);
auto filter_shape = max_pool->get_window_shape();
auto filter_strides = max_pool->get_window_movement_strides();
auto padding_below = max_pool->get_padding_below();
auto padding_above = max_pool->get_padding_above();
memory::data_type et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
node->get_input_element_type(0));
algorithm algorithm_enumerator = algorithm::pooling_max;
memory::dims mkldnn_arg0_shape(arg0_shape.begin(), arg0_shape.end());
memory::dims mkldnn_result_shape(result_shape.begin(), result_shape.end());
memory::dims mkldnn_filter_shape(filter_shape.begin(), filter_shape.end());
memory::dims mkldnn_filter_strides(filter_strides.begin(),
filter_strides.end());
memory::dims mkldnn_padding_below(padding_below.begin(),
padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(),
padding_above.end());
auto input_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);
auto input_desc = memory::desc(mkldnn_arg0_shape, et, input_layout);
auto result_desc =
memory::desc(mkldnn_result_shape, et, memory::format::any);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
try
{
auto prim_desc = pooling_forward::primitive_desc(
{prop_kind::forward_inference,
algorithm_enumerator,
input_desc,
result_desc,
mkldnn_filter_strides,
mkldnn_filter_shape,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero},
runtime::cpu::mkldnn_utils::global_cpu_engine);
prim_input_formats.push_back(input_layout);
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.dst_primitive_desc().desc().data.format));
// TODO (jbobba): Add workspace layouts here
}
catch (const mkldnn::error& e)
{
throw ngraph_error("MKLDNN Unsupported pooling fwd layout" +
to_string(input_layout) + e.message);
}
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::MaxPoolBackprop)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto max_pool = static_cast<const ngraph::op::MaxPoolBackprop*>(node.get());
// arg 0 - fprop input
// arg 1 - delta
// Propagate fprop's input layout
auto arg0_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1);
auto result_shape = node->get_output_shape(0);
auto filter_shape = max_pool->get_window_shape();
auto filter_strides = max_pool->get_window_movement_strides();
auto padding_below = max_pool->get_padding_below();
auto padding_above = max_pool->get_padding_above();
memory::data_type et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
node->get_input_element_type(1));
algorithm algorithm_enumerator = algorithm::pooling_max;
memory::dims mkldnn_arg0_shape(arg0_shape.begin(), arg0_shape.end());
memory::dims mkldnn_arg1_shape(arg1_shape.begin(), arg1_shape.end());
memory::dims mkldnn_result_shape(result_shape.begin(), result_shape.end());
memory::dims mkldnn_filter_shape(filter_shape.begin(), filter_shape.end());
memory::dims mkldnn_filter_strides(filter_strides.begin(),
filter_strides.end());
memory::dims mkldnn_padding_below(padding_below.begin(),
padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(),
padding_above.end());
auto fprop_input_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);
auto diff_dst_desc =
memory::desc(mkldnn_arg1_shape, et, fprop_input_layout);
auto diff_src_desc =
memory::desc(mkldnn_arg0_shape, et, memory::format::any);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
try
{
auto fwd_prim_desc = pooling_forward::primitive_desc(
{prop_kind::forward_training,
algorithm_enumerator,
diff_src_desc,
diff_dst_desc,
mkldnn_filter_strides,
mkldnn_filter_shape,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero},
runtime::cpu::mkldnn_utils::global_cpu_engine);
auto prim_desc = pooling_backward::primitive_desc(
{algorithm_enumerator,
diff_src_desc,
diff_dst_desc,
mkldnn_filter_strides,
mkldnn_filter_shape,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero},
runtime::cpu::mkldnn_utils::global_cpu_engine,
fwd_prim_desc);
prim_input_formats.push_back(fprop_input_layout);
prim_input_formats.push_back(fprop_input_layout);
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.diff_src_primitive_desc().desc().data.format));
}
catch (const mkldnn::error& e)
{
throw ngraph_error("MKLDNN Unsupported pooling layout" +
to_string(fprop_input_layout) + e.message);
} }
node = node =
...@@ -884,18 +1040,21 @@ namespace ngraph ...@@ -884,18 +1040,21 @@ namespace ngraph
static const runtime::cpu::pass::LayoutOpMap s_dispatcher{ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::Add), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Add>}, {TI(ngraph::op::Add), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Add>},
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::Convolution), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Convolution>}, {TI(ngraph::op::Convolution), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionBackpropData), {TI(ngraph::op::ConvolutionBackpropData),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropData>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropData>},
{TI(ngraph::op::ConvolutionBackpropFilters), {TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropFilters>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropFilters>},
{TI(ngraph::op::MaxPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::MaxPool>},
{TI(ngraph::op::MaxPoolBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::MaxPoolBackprop>},
{TI(ngraph::op::ConvolutionBias), {TI(ngraph::op::ConvolutionBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBias>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBias>},
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias), {TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBiasBackpropFiltersBias>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBiasBackpropFiltersBias>},
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNorm>}, {TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNorm>},
{TI(ngraph::op::GetOutputElement), {TI(ngraph::op::GetOutputElement),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::GetOutputElement>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::GetOutputElement>},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment