Commit b5e030d9 authored by Nishant Patel, committed by Scott Cyphers

Fix maxpoolbprop layout pass (#4122)

* Fix maxpoolbprop layout pass

* Use default format for max pooling bprop.

* Fix MaxPoolWithIndicesBackprop unit test.

* Fix CODEGEN.

* Modify MaxPoolWithIndicesBackprop unit test.
Co-authored-by: Amy Zhuang <amyzhuang97@gmail.com>
Co-authored-by: Scott Cyphers <diyessi@users.noreply.github.com>
parent 4ba7a8fa
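
This commit adds a fwd_bwd flag to MKLDNNEmitter::reserve_primitive_space. An op that builds a forward and a backward MKLDNN primitive in one reservation (MaxPoolWithIndicesBackprop in the CODEGEN path) now takes two primitive slots that share one dependency list; every other call site passes false and keeps the old single-slot behavior. Below is a minimal, self-contained model of the new reservation scheme; the Emitter struct and its containers are illustrative stand-ins for MKLDNNEmitter, not the real class.

#include <cstddef>
#include <iostream>
#include <map>
#include <vector>

struct Emitter
{
    std::vector<void*> primitives;              // one slot per mkldnn primitive
    std::vector<void*> memories;                // one slot per mkldnn memory
    std::map<size_t, std::vector<size_t>> deps; // primitive index -> dependency indices

    // Reserve `count` slots for one op. With fwd_bwd == true, two primitive
    // slots are taken (forward at index - 1, backward at index) and both
    // record the same memory dependencies; new_workspace appends a
    // placeholder dependency that is patched later.
    size_t reserve(size_t count, bool fwd_bwd, bool new_workspace)
    {
        size_t n_prims = fwd_bwd ? 2 : 1;
        size_t mem_base = memories.size();
        primitives.resize(primitives.size() + n_prims, nullptr);
        memories.resize(mem_base + count - n_prims, nullptr);
        for (size_t i = 0; i < count - n_prims; i++)
        {
            for (size_t p = 0; p < n_prims; p++)
            {
                deps[primitives.size() - 1 - p].push_back(mem_base + i);
            }
        }
        if (new_workspace)
        {
            for (size_t p = 0; p < n_prims; p++)
            {
                deps[primitives.size() - 1 - p].push_back(0);
            }
        }
        return primitives.size() - 1;
    }
};

int main()
{
    Emitter e;
    // CODEGEN MaxPoolBackprop: 6 slots, forward + backward primitive pair.
    size_t index = e.reserve(6, /*fwd_bwd=*/true, /*new_workspace=*/true);
    std::cout << "backward primitive at " << index << ", forward at " << index - 1
              << ", shared dependencies: " << e.deps[index].size() << "\n";
    return 0;
}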
@@ -76,8 +76,8 @@ namespace ngraph
     // Lstm needs 9 primitives: src_layer, src_iter, weights_layer, weights_iter, bias,
     // dst_layer, dst_iter, workspace, and rnn_forward.
     // It needs a new workspace.
-    auto lstm_index =
-        mkldnn_emitter->reserve_primitive_space(9, true /* new workspace */);
+    auto lstm_index = mkldnn_emitter->reserve_primitive_space(
+        9, false /* fwd and bwd */, true /* new workspace */);
     auto& deps = mkldnn_emitter->get_primitive_deps(lstm_index);
     auto functor = [&,
@@ -139,8 +139,8 @@ namespace ngraph
     // weights_iter, bias,
     // dst_layer, dst_iter, dst_iter_c, workspace, and lstm_forward.
     // It needs a new workspace.
-    auto lstm_index =
-        mkldnn_emitter->reserve_primitive_space(11, true /* new workspace */);
+    auto lstm_index = mkldnn_emitter->reserve_primitive_space(
+        11, false /* fwd and bwd */, true /* new workspace */);
     auto& deps = mkldnn_emitter->get_primitive_deps(lstm_index);
     auto functor = [&,
...
@@ -155,8 +155,8 @@ namespace ngraph
     // MaxPoolBackprop forward needs 4 primitives: fprop_src, diff_src, workspace,
     // and pooling_forward.
     // It needs a new workspace.
-    size_t fwd_pool_index =
-        mkldnn_emitter->reserve_primitive_space(4, true /* new workspace */);
+    size_t fwd_pool_index = mkldnn_emitter->reserve_primitive_space(
+        4, false /* fwd and bwd */, true /* new workspace */);
     auto& fdeps = mkldnn_emitter->get_primitive_deps(fwd_pool_index);
     auto functor_fprop = [&,
@@ -182,8 +182,8 @@ namespace ngraph
     // MaxPoolBackprop backward needs 4 primitives: diff_dst, workspace, diff_src,
     // and pooling_backward.
     // It needs a new workspace.
-    size_t bwd_pool_index =
-        mkldnn_emitter->reserve_primitive_space(4, true /* new workspace */);
+    size_t bwd_pool_index = mkldnn_emitter->reserve_primitive_space(
+        4, false /* fwd and bwd */, true /* new workspace */);
     auto& bdeps = mkldnn_emitter->get_primitive_deps(bwd_pool_index);
     auto functor_bprop = [&, bwd_pool_index, delta_buffer_index, out_buffer_index](
         CPURuntimeContext* ctx, CPUExecutionContext* /* ectx */) {
...
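
Note that in this direct-execution path the fprop and bprop halves of MaxPoolBackprop still reserve space separately (4 slots each, fwd_bwd = false); only the CODEGEN path further down uses one combined 6-slot reservation with fwd_bwd = true. A small sketch of the slot accounting, assuming the count convention from the comments in these hunks (the helper functions model the new API's arithmetic, they are not nGraph code):

#include <cassert>
#include <cstddef>

// A reservation of `count` slots spends one primitive slot normally,
// two when fwd_bwd is set; the rest become memory slots.
static size_t primitive_slots(bool fwd_bwd) { return fwd_bwd ? 2 : 1; }
static size_t memory_slots(size_t count, bool fwd_bwd)
{
    return count - primitive_slots(fwd_bwd);
}

int main()
{
    // DEX MaxPoolBackprop: fprop and bprop reserve 4 slots each -> 3 memories apiece.
    assert(memory_slots(4, false) == 3);
    // CODEGEN MaxPoolBackprop: one 6-slot reservation -> 4 memories
    // (fprop_src, diff_dst, diff_src, workspace) shared by both primitives.
    assert(memory_slots(6, true) == 4);
    // Lstm/Rnn: 9 or 11 slots, single forward primitive.
    assert(memory_slots(9, false) == 8 && memory_slots(11, false) == 10);
    return 0;
}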
@@ -62,8 +62,8 @@ namespace ngraph
     // Rnn needs 9 primitives: src_layer, src_iter, weights_layer, weights_iter, bias,
     // dst_layer, dst_iter, workspace, and rnn_forward.
     // It needs a new workspace.
-    auto rnn_index =
-        mkldnn_emitter->reserve_primitive_space(9, true /* new workspace */);
+    auto rnn_index = mkldnn_emitter->reserve_primitive_space(
+        9, false /* fwd and bwd */, true /* new workspace */);
     auto& deps = mkldnn_emitter->get_primitive_deps(rnn_index);
     auto functor = [&,
@@ -125,8 +125,8 @@ namespace ngraph
     // weights_iter, bias,
     // dst_layer, dst_iter, dst_iter_c, workspace, and lstm_forward.
     // It needs a new workspace.
-    auto rnn_index =
-        mkldnn_emitter->reserve_primitive_space(11, true /* new workspace */);
+    auto rnn_index = mkldnn_emitter->reserve_primitive_space(
+        11, false /* fwd and bwd */, true /* new workspace */);
     auto& deps = mkldnn_emitter->get_primitive_deps(rnn_index);
     auto functor = [&,
...
@@ -3374,17 +3374,17 @@ namespace ngraph
     writer << "cg_ctx->set_memory_ptr(" << to_string(deps[2]) << ", "
            << out[0].get_name() << ");\n";
     writer << "cg_ctx->set_memory_ptr(" << to_string(deps[3])
-           << ", cg_ctx->mkldnn_workspaces[" << deps[5] << "]);\n";
+           << ", cg_ctx->mkldnn_workspaces[" << deps[4] << "]);\n";
     writer << "std::vector<size_t> deps{" << join(deps) << "};\n";
-    writer << "cg_ctx->mkldnn_invoke_primitive(" << to_string(deps[4])
-           << ",deps, OpType::MAXPOOLBACKPROPFORWARD, "
+    writer << "cg_ctx->mkldnn_invoke_primitive(" << to_string(max_pool_index - 1)
+           << ", deps, OpType::MAXPOOLBACKPROPFORWARD, "
            << to_string(scratchpad_size) << ");\n";
     writer << "cg_ctx->set_memory_ptr(" << to_string(deps[1]) << ", "
            << args[1].get_name() << ");\n";
     writer << "cg_ctx->set_memory_ptr(" << to_string(deps[3])
-           << ", cg_ctx->mkldnn_workspaces[" << deps[5] << "]);\n";
+           << ", cg_ctx->mkldnn_workspaces[" << deps[4] << "]);\n";
     writer << "cg_ctx->set_memory_ptr(" << to_string(deps[2]) << ", "
            << out[0].get_name() << ");\n";
...
@@ -478,17 +478,31 @@ size_t MKLDNNEmitter::inner_product_forward_init(bool with_bias)
     return m_mkldnn_primitives.size() - 1;
 }
 
-size_t MKLDNNEmitter::reserve_primitive_space(size_t count, bool new_workspace)
+size_t MKLDNNEmitter::reserve_primitive_space(size_t count, bool fwd_bwd, bool new_workspace)
 {
     size_t size = m_mkldnn_primitives.size();
 #if MKLDNN_VERSION_MAJOR >= 1
     size_t mem_size = m_mkldnn_memories.size();
-    m_mkldnn_primitives.resize(size + 1, nullptr);
-    m_mkldnn_scratchpad_mds.resize(size + 1, nullptr);
-    m_mkldnn_memories.resize(mem_size + count - 1, nullptr);
-    for (auto i = 0; i < count - 1; i++)
+    if (fwd_bwd)
     {
-        m_primitive_deps[m_mkldnn_primitives.size() - 1].push_back(mem_size + i);
+        m_mkldnn_primitives.resize(size + 2, nullptr);
+        m_mkldnn_memories.resize(mem_size + count - 2, nullptr);
+        m_mkldnn_scratchpad_mds.resize(size + 2, nullptr);
+        for (auto i = 0; i < count - 2; i++)
+        {
+            m_primitive_deps[m_mkldnn_primitives.size() - 2].push_back(mem_size + i);
+            m_primitive_deps[m_mkldnn_primitives.size() - 1].push_back(mem_size + i);
+        }
+    }
+    else
+    {
+        m_mkldnn_primitives.resize(size + 1, nullptr);
+        m_mkldnn_memories.resize(mem_size + count - 1, nullptr);
+        m_mkldnn_scratchpad_mds.resize(size + 1, nullptr);
+        for (auto i = 0; i < count - 1; i++)
+        {
+            m_primitive_deps[m_mkldnn_primitives.size() - 1].push_back(mem_size + i);
+        }
     }
 #else
     m_mkldnn_primitives.resize(size + count, nullptr);
@@ -501,6 +515,10 @@ size_t MKLDNNEmitter::reserve_primitive_space(size_t count, bool new_workspace)
     if (new_workspace)
     {
         m_primitive_deps[m_mkldnn_primitives.size() - 1].push_back(0);
+        if (fwd_bwd)
+        {
+            m_primitive_deps[m_mkldnn_primitives.size() - 2].push_back(0);
+        }
     }
     return m_mkldnn_primitives.size() - 1;
 }
...
@@ -139,7 +139,9 @@ namespace ngraph
     // reserve the space for primitives for each op, different op requires different
     // number of primitives.
     // some ops require a new workspace.
-    size_t reserve_primitive_space(size_t count, bool new_workspace = false);
+    size_t reserve_primitive_space(size_t count,
+                                   bool fwd_bwd = false,
+                                   bool new_workspace = false);
     size_t insert_primitive(mkldnn::primitive* primitive);
     size_t insert_memory(mkldnn::memory* memory);
     size_t insert_workspace(std::unique_ptr<MKLDNNWorkspace>& workspace);
...
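
One reason every call site had to change: the new fwd_bwd parameter is declared before new_workspace, so a legacy two-argument call such as reserve_primitive_space(9, true) would now bind true to fwd_bwd and leave new_workspace false. A hypothetical free function mirroring the declaration makes the hazard concrete (demonstration only, not nGraph code):

#include <cassert>
#include <cstddef>

static size_t reserve_primitive_space(size_t count,
                                      bool fwd_bwd = false,
                                      bool new_workspace = false)
{
    // Encode which flags were seen, purely so the asserts can inspect them.
    return (fwd_bwd ? 2 : 0) + (new_workspace ? 1 : 0) + count * 4;
}

int main()
{
    // A legacy call like reserve_primitive_space(9, true) now sets
    // fwd_bwd = true and leaves new_workspace = false -- not the old meaning.
    assert(reserve_primitive_space(9, true) == 9 * 4 + 2);
    // The updated call sites are explicit about both flags:
    assert(reserve_primitive_space(9, false, true) == 9 * 4 + 1);
    return 0;
}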
@@ -1743,22 +1743,14 @@ namespace ngraph
     memory::dims mkldnn_padding_below(padding_below.begin(), padding_below.end());
     memory::dims mkldnn_padding_above(padding_above.begin(), padding_above.end());
 
-    auto fprop_input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
-#if MKLDNN_VERSION_MAJOR < 1
-    auto fprop_input_layout =
-        static_cast<memory::format>(fprop_input_md.data.format);
-    auto diff_dst_desc = memory::desc(mkldnn_arg1_shape, et, fprop_input_layout);
-#else
-    auto strides = fprop_input_md.data.format_desc.blocking.strides;
-    memory::dims strides_arg;
-    for (auto i = 0; i < fprop_input_md.data.ndims; i++)
-    {
-        strides_arg.push_back(strides[i]);
-    }
-    auto diff_dst_desc = memory::desc(mkldnn_arg1_shape, et, strides_arg);
-#endif
-    auto diff_src_desc = memory::desc(mkldnn_arg0_shape, et, memory::FORMAT::any);
+    if (arg0_shape.size() != 4 && arg0_shape.size() != 5)
+    {
+        throw ngraph_error("MKLDNN Unsupported pooling layout");
+    }
+    auto default_format = arg0_shape.size() == 4 ? mkldnn::memory::FORMAT::nchw
+                                                 : mkldnn::memory::FORMAT::ncdhw;
+    auto diff_dst_desc = memory::desc(mkldnn_arg1_shape, et, default_format);
+    auto diff_src_desc = memory::desc(mkldnn_arg0_shape, et, default_format);
 
     try
     {
@@ -1784,7 +1776,8 @@ namespace ngraph
     auto prim_desc = pooling_backward::primitive_desc(
         bwd_desc, executor::global_cpu_engine, fwd_prim_desc);
 
-    i_mds.push_back(fprop_input_md);
+    i_mds.push_back(diff_src_desc);
     i_mds.push_back(diff_dst_desc);
 
     if (with_indices)
@@ -1799,11 +1792,7 @@ namespace ngraph
     {
         i_mds.push_back(diff_dst_desc);
     }
 
-#if MKLDNN_VERSION_MAJOR < 1
-    o_mds.push_back(prim_desc.diff_src_primitive_desc().desc());
-#else
-    o_mds.push_back(prim_desc.diff_src_desc());
-#endif
+    o_mds.push_back(diff_src_desc);
 }
 catch (const mkldnn::error& e)
 {
...
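
This is the layout fix itself: rather than deriving the bprop descriptors from the fprop input's (possibly blocked) memory descriptor, the pass now pins diff_dst and diff_src to the plain row-major default for the tensor rank and rejects anything that is not 4-D or 5-D. A reduced sketch of that selection rule, with a local Fmt enum standing in for mkldnn::memory::FORMAT:

#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

enum class Fmt { nchw, ncdhw };

static Fmt default_pool_format(const std::vector<size_t>& arg0_shape)
{
    // Only 4-D (NCHW) and 5-D (NCDHW) pooling layouts are supported.
    if (arg0_shape.size() != 4 && arg0_shape.size() != 5)
    {
        throw std::runtime_error("MKLDNN Unsupported pooling layout");
    }
    return arg0_shape.size() == 4 ? Fmt::nchw : Fmt::ncdhw;
}

int main()
{
    std::cout << (default_pool_format({2, 2, 5, 5}) == Fmt::nchw) << "\n";      // 1
    std::cout << (default_pool_format({2, 2, 5, 5, 5}) == Fmt::ncdhw) << "\n";  // 1
    return 0;
}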
@@ -256,7 +256,8 @@ namespace ngraph
     // weights_iter, bias,
     // dst_layer, dst_iter, dst_iter_c, workspace, and rnn_forward.
     // It needs a new workspace.
-    index = mkldnn_emitter.reserve_primitive_space(11, true /* new workspace */);
+    index = mkldnn_emitter.reserve_primitive_space(
+        11, false /* fwd and bwd */, true /* new workspace */);
     deps = mkldnn_emitter.get_primitive_deps(index);
 
     CodeWriter writer;
@@ -1773,7 +1774,8 @@ namespace ngraph
     // MaxPoolBackprop needs 6 primitives: fprop_src, diff_dst, diff_src, workspace,
     // pooling forward, and pooling_backward.
     // It needs a new workspace.
-    index = mkldnn_emitter.reserve_primitive_space(6, true /* new workspace */);
+    index = mkldnn_emitter.reserve_primitive_space(
+        6, true /* fwd and bwd */, true /* new workspace */);
     deps = mkldnn_emitter.get_primitive_deps(index);
 
     CodeWriter writer;
@@ -1832,13 +1834,13 @@ namespace ngraph
     writer.block_end();
     writer << "cg_ctx->mkldnn_workspaces.push_back(workspace);\n";
-    deps[5] = mkldnn_emitter.reserve_workspace();
+    deps[4] = mkldnn_emitter.reserve_workspace();
 
     writer << "\n// build primitive\n";
-    writer << "cg_ctx->mkldnn_primitives[" << std::to_string(deps[4])
+    writer << "cg_ctx->mkldnn_primitives[" << std::to_string(index - 1)
            << "] = new mkldnn::pooling_forward(fwd_pd);\n";
-    writer << "cg_ctx->mkldnn_scratchpad_mds[" << std::to_string(deps[4])
+    writer << "cg_ctx->mkldnn_scratchpad_mds[" << std::to_string(index - 1)
            << "] = new mkldnn::memory::desc(fwd_pd.scratchpad_desc());\n";
     writer << "cg_ctx->mkldnn_primitives[" << std::to_string(index)
@@ -2816,7 +2818,6 @@ bool MKLDNNPrimitiveBuildPass::run_on_call_graph(const std::list<std::shared_ptr
                 construct_string, deps, index, scratchpad_size);
         }
     }
-
     return false;
 }
...
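
The index arithmetic behind the codegen edits above: with the combined 6-slot reservation, reserve_primitive_space returns the backward primitive's slot and the paired forward primitive sits at index - 1, while the dependency list holds four memory slots (deps[3] being the workspace memory) followed by the reserved-workspace id, which therefore moves from deps[5] to deps[4]. A sketch that just checks the arithmetic, with constants taken from the hunks above:

#include <cassert>
#include <cstddef>

int main()
{
    const size_t count = 6;        // fprop_src, diff_dst, diff_src, workspace + 2 primitives
    const size_t n_primitives = 2; // pooling_forward and pooling_backward
    const size_t n_memories = count - n_primitives;

    // Dependency list layout: deps[0..3] are memory slots, and the
    // reserved-workspace id lands right after them, at deps[4].
    const size_t workspace_id_dep = n_memories;
    assert(workspace_id_dep == 4);

    // reserve_primitive_space returns the backward primitive's slot, so the
    // paired forward primitive is addressed as index - 1 instead of deps[4].
    const size_t index = 7; // e.g. whatever slot the emitter handed back
    const size_t fwd_primitive = index - (n_primitives - 1);
    assert(fwd_primitive == index - 1);
    return 0;
}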
@@ -339,10 +339,7 @@ private:
     {
         free(w);
     }
-}
-
-inline void cleanup_mkldnn_descriptors()
-{
+
     for (auto d : mkldnn_descriptors)
     {
         free(d);
@@ -362,14 +359,14 @@ extern "C" void destroy_cg_ctx(CPURuntimeContextCG* cg_ctx)
 static void
     deserialize_memory_descs_and_build_memory(std::ifstream& desc_file,
                                               CPURuntimeContextCG* cg_ctx,
                                               size_t descs_count)
 {
     cg_ctx->mkldnn_descriptors = std::vector<mkldnn::memory::desc*>(descs_count);
     for (auto i = 0; i < descs_count; i++)
     {
         size_t index;
         desc_file >> index;
         auto desc = (mkldnn::memory::desc*)malloc(sizeof(mkldnn::memory::desc));
         if (!desc)
         {
@@ -377,8 +374,8 @@ static void
         }
         desc_file.read(reinterpret_cast<char*>(desc), sizeof(mkldnn::memory::desc));
         cg_ctx->mkldnn_descriptors[i] = desc;
         cg_ctx->mkldnn_memories[index] = new mkldnn::memory(*cg_ctx->mkldnn_descriptors[i], cg_ctx->global_cpu_engine, nullptr);
     }
 };
)"
@@ -1551,14 +1551,16 @@ TEST(cpu_test, max_pool_with_indices_bprop_2d_2channel_2image)
     Shape padding_below{0, 0};
     Shape padding_above{0, 0};
     auto A = make_shared<op::Parameter>(element::f32, shape_a);
+    auto max_pool = make_shared<op::MaxPoolWithIndices>(
+        A, window_shape, window_movement_strides, padding_below, padding_above);
+    auto indices = make_shared<op::GetOutputElement>(max_pool, 1);
     Shape shape_i{2, 2, 4, 3};
-    auto indices = make_shared<op::Parameter>(element::i32, shape_i);
     auto delta = make_shared<op::Parameter>(element::f32, shape_i);
     auto max_pool_bprop = make_shared<op::MaxPoolWithIndicesBackprop>(
         A, delta, indices, window_shape, window_movement_strides, padding_below, padding_above);
-    auto f = make_shared<Function>(max_pool_bprop, ParameterVector{A, delta, indices});
+    auto f = make_shared<Function>(max_pool_bprop, ParameterVector{A, delta});
 
     auto backend = runtime::Backend::create("CPU");
@@ -1590,29 +1592,6 @@ TEST(cpu_test, max_pool_with_indices_bprop_2d_2channel_2image)
                          {1, 0, 0, 0, 2}}}})
             .get_vector());
 
-    auto i = backend->create_tensor(element::i32, shape_i);
-    copy_data(i,
-              test::NDArray<int, 4>({{{{4, 3, 1}, // img 0 chan 0
-                                       {1, 0, 0},
-                                       {0, 4, 5},
-                                       {0, 3, 2}},
-                                      {{5, 4, 3}, // img 0 chan 1
-                                       {2, 1, 0},
-                                       {3, 1, 2},
-                                       {0, 0, 0}}},
-                                     {{{1, 0, 3}, // img 1 chan 0
-                                       {2, 1, 5},
-                                       {3, 5, 2},
-                                       {0, 2, 1}},
-                                      {{0, 3, 2}, // img 1 chan 1
-                                       {1, 0, 3},
-                                       {2, 1, 0},
-                                       {0, 0, 5}}}})
-                  .get_vector());
-
     auto d = backend->create_tensor(element::f32, shape_i);
     copy_data(d,
               test::NDArray<float, 4>({{{{0.3f, 0.3f, 0.2f}, // img 0 chan 0
@@ -1639,7 +1618,7 @@ TEST(cpu_test, max_pool_with_indices_bprop_2d_2channel_2image)
     auto result = backend->create_tensor(element::f32, shape_a);
 
     auto handle = backend->compile(f);
-    handle->call_with_validate({result}, {a, d, i});
+    handle->call_with_validate({result}, {a, d});
     EXPECT_TRUE(test::all_close_f((test::NDArray<float, 4>({{{{0, 0, 0, 0.2, 0}, // img 0 chan 0
                                                              {0, 1.2, 0.2, 0, 0},
                                                              {0.2, 0, 0, 0, 0},
...
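
The test no longer feeds hand-written indices: it derives them in-graph from MaxPoolWithIndices via GetOutputElement, so the compiled Function takes only {A, delta} and the hand-rolled index tensor disappears. The new wiring, consolidated from the hunks above (a sketch assuming the headers and using-declarations already present in this test file):

auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto max_pool = make_shared<op::MaxPoolWithIndices>(
    A, window_shape, window_movement_strides, padding_below, padding_above);
auto indices = make_shared<op::GetOutputElement>(max_pool, 1); // output 1 = indices
auto delta = make_shared<op::Parameter>(element::f32, shape_i);
auto max_pool_bprop = make_shared<op::MaxPoolWithIndicesBackprop>(
    A, delta, indices, window_shape, window_movement_strides, padding_below, padding_above);
auto f = make_shared<Function>(max_pool_bprop, ParameterVector{A, delta});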