Commit b5e030d9 authored by Nishant Patel's avatar Nishant Patel Committed by Scott Cyphers

Fix maxpoolbprop layout pass (#4122)

* Fix maxpoolbprop layout pass

* Use default format for max pooling bprop.

* Fix MaxPoolWithIndicesBackprop unit test.

* Fix CODEGEN.

* Modify MaxPoolWithIndicesBackprop unit test.
Co-authored-by: 's avatarAmy Zhuang <amyzhuang97@gmail.com>
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 4ba7a8fa
......@@ -76,8 +76,8 @@ namespace ngraph
// Lstm needs 9 primitives: src_layer, src_iter, weights_layer, weights_iter, bias,
// dst_layer, dst_iter, workspace, and rnn_forward.
// It needs a new workspace.
auto lstm_index =
mkldnn_emitter->reserve_primitive_space(9, true /* new workspace */);
auto lstm_index = mkldnn_emitter->reserve_primitive_space(
9, false /* fwd and bwd */, true /* new workspace */);
auto& deps = mkldnn_emitter->get_primitive_deps(lstm_index);
auto functor = [&,
......@@ -139,8 +139,8 @@ namespace ngraph
// weights_iter, bias,
// dst_layer, dst_iter, dst_iter_c, workspace, and lstm_forward.
// It needs a new workspace.
auto lstm_index =
mkldnn_emitter->reserve_primitive_space(11, true /* new workspace */);
auto lstm_index = mkldnn_emitter->reserve_primitive_space(
11, false /* fwd and bwd */, true /* new workspace */);
auto& deps = mkldnn_emitter->get_primitive_deps(lstm_index);
auto functor = [&,
......
......@@ -155,8 +155,8 @@ namespace ngraph
// MaxPoolBackprop forward needs 4 primitives: fprop_src, diff_src, workspace,
// and pooling_forward.
// It needs a new workspace.
size_t fwd_pool_index =
mkldnn_emitter->reserve_primitive_space(4, true /* new workspace */);
size_t fwd_pool_index = mkldnn_emitter->reserve_primitive_space(
4, false /* fwd and bwd */, true /* new workspace */);
auto& fdeps = mkldnn_emitter->get_primitive_deps(fwd_pool_index);
auto functor_fprop = [&,
......@@ -182,8 +182,8 @@ namespace ngraph
// MaxPoolBackprop backward needs 4 primitives: diff_dst, workspace, diff_src,
// and pooling_backward.
// It needs a new workspace.
size_t bwd_pool_index =
mkldnn_emitter->reserve_primitive_space(4, true /* new workspace */);
size_t bwd_pool_index = mkldnn_emitter->reserve_primitive_space(
4, false /* fwd and bwd */, true /* new workspace */);
auto& bdeps = mkldnn_emitter->get_primitive_deps(bwd_pool_index);
auto functor_bprop = [&, bwd_pool_index, delta_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* /* ectx */) {
......
......@@ -62,8 +62,8 @@ namespace ngraph
// Rnn needs 9 primitives: src_layer, src_iter, weights_layer, weights_iter, bias,
// dst_layer, dst_iter, workspace, and rnn_forward.
// It needs a new workspace.
auto rnn_index =
mkldnn_emitter->reserve_primitive_space(9, true /* new workspace */);
auto rnn_index = mkldnn_emitter->reserve_primitive_space(
9, false /* fwd and bwd */, true /* new workspace */);
auto& deps = mkldnn_emitter->get_primitive_deps(rnn_index);
auto functor = [&,
......@@ -125,8 +125,8 @@ namespace ngraph
// weights_iter, bias,
// dst_layer, dst_iter, dst_iter_c, workspace, and lstm_forward.
// It needs a new workspace.
auto rnn_index =
mkldnn_emitter->reserve_primitive_space(11, true /* new workspace */);
auto rnn_index = mkldnn_emitter->reserve_primitive_space(
11, false /* fwd and bwd */, true /* new workspace */);
auto& deps = mkldnn_emitter->get_primitive_deps(rnn_index);
auto functor = [&,
......
......@@ -3374,17 +3374,17 @@ namespace ngraph
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[2]) << ", "
<< out[0].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[3])
<< ", cg_ctx->mkldnn_workspaces[" << deps[5] << "]);\n";
<< ", cg_ctx->mkldnn_workspaces[" << deps[4] << "]);\n";
writer << "std::vector<size_t> deps{" << join(deps) << "};\n";
writer << "cg_ctx->mkldnn_invoke_primitive(" << to_string(deps[4])
<< ",deps, OpType::MAXPOOLBACKPROPFORWARD, "
writer << "cg_ctx->mkldnn_invoke_primitive(" << to_string(max_pool_index - 1)
<< ", deps, OpType::MAXPOOLBACKPROPFORWARD, "
<< to_string(scratchpad_size) << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[1]) << ", "
<< args[1].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[3])
<< ", cg_ctx->mkldnn_workspaces[" << deps[5] << "]);\n";
<< ", cg_ctx->mkldnn_workspaces[" << deps[4] << "]);\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[2]) << ", "
<< out[0].get_name() << ");\n";
......
......@@ -478,17 +478,31 @@ size_t MKLDNNEmitter::inner_product_forward_init(bool with_bias)
return m_mkldnn_primitives.size() - 1;
}
size_t MKLDNNEmitter::reserve_primitive_space(size_t count, bool new_workspace)
size_t MKLDNNEmitter::reserve_primitive_space(size_t count, bool fwd_bwd, bool new_workspace)
{
size_t size = m_mkldnn_primitives.size();
#if MKLDNN_VERSION_MAJOR >= 1
size_t mem_size = m_mkldnn_memories.size();
m_mkldnn_primitives.resize(size + 1, nullptr);
m_mkldnn_scratchpad_mds.resize(size + 1, nullptr);
m_mkldnn_memories.resize(mem_size + count - 1, nullptr);
for (auto i = 0; i < count - 1; i++)
if (fwd_bwd)
{
m_primitive_deps[m_mkldnn_primitives.size() - 1].push_back(mem_size + i);
m_mkldnn_primitives.resize(size + 2, nullptr);
m_mkldnn_memories.resize(mem_size + count - 2, nullptr);
m_mkldnn_scratchpad_mds.resize(size + 2, nullptr);
for (auto i = 0; i < count - 2; i++)
{
m_primitive_deps[m_mkldnn_primitives.size() - 2].push_back(mem_size + i);
m_primitive_deps[m_mkldnn_primitives.size() - 1].push_back(mem_size + i);
}
}
else
{
m_mkldnn_primitives.resize(size + 1, nullptr);
m_mkldnn_memories.resize(mem_size + count - 1, nullptr);
m_mkldnn_scratchpad_mds.resize(size + 1, nullptr);
for (auto i = 0; i < count - 1; i++)
{
m_primitive_deps[m_mkldnn_primitives.size() - 1].push_back(mem_size + i);
}
}
#else
m_mkldnn_primitives.resize(size + count, nullptr);
......@@ -501,6 +515,10 @@ size_t MKLDNNEmitter::reserve_primitive_space(size_t count, bool new_workspace)
if (new_workspace)
{
m_primitive_deps[m_mkldnn_primitives.size() - 1].push_back(0);
if (fwd_bwd)
{
m_primitive_deps[m_mkldnn_primitives.size() - 2].push_back(0);
}
}
return m_mkldnn_primitives.size() - 1;
}
......
......@@ -139,7 +139,9 @@ namespace ngraph
// reserve the space for primitives for each op, different op requires different
// number of primitives.
// some ops require a new workspace.
size_t reserve_primitive_space(size_t count, bool new_workspace = false);
size_t reserve_primitive_space(size_t count,
bool fwd_bwd = false,
bool new_workspace = false);
size_t insert_primitive(mkldnn::primitive* primitive);
size_t insert_memory(mkldnn::memory* memory);
size_t insert_workspace(std::unique_ptr<MKLDNNWorkspace>& workspace);
......
......@@ -1743,22 +1743,14 @@ namespace ngraph
memory::dims mkldnn_padding_below(padding_below.begin(), padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(), padding_above.end());
auto fprop_input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
#if MKLDNN_VERSION_MAJOR < 1
auto fprop_input_layout =
static_cast<memory::format>(fprop_input_md.data.format);
auto diff_dst_desc = memory::desc(mkldnn_arg1_shape, et, fprop_input_layout);
#else
auto strides = fprop_input_md.data.format_desc.blocking.strides;
memory::dims strides_arg;
for (auto i = 0; i < fprop_input_md.data.ndims; i++)
if (arg0_shape.size() != 4 && arg0_shape.size() != 5)
{
strides_arg.push_back(strides[i]);
throw ngraph_error("MKLDNN Unsupported pooling layout");
}
auto diff_dst_desc = memory::desc(mkldnn_arg1_shape, et, strides_arg);
#endif
auto diff_src_desc = memory::desc(mkldnn_arg0_shape, et, memory::FORMAT::any);
auto default_format = arg0_shape.size() == 4 ? mkldnn::memory::FORMAT::nchw
: mkldnn::memory::FORMAT::ncdhw;
auto diff_dst_desc = memory::desc(mkldnn_arg1_shape, et, default_format);
auto diff_src_desc = memory::desc(mkldnn_arg0_shape, et, default_format);
try
{
......@@ -1784,7 +1776,8 @@ namespace ngraph
auto prim_desc = pooling_backward::primitive_desc(
bwd_desc, executor::global_cpu_engine, fwd_prim_desc);
i_mds.push_back(fprop_input_md);
i_mds.push_back(diff_src_desc);
i_mds.push_back(diff_dst_desc);
if (with_indices)
......@@ -1799,11 +1792,7 @@ namespace ngraph
{
i_mds.push_back(diff_dst_desc);
}
#if MKLDNN_VERSION_MAJOR < 1
o_mds.push_back(prim_desc.diff_src_primitive_desc().desc());
#else
o_mds.push_back(prim_desc.diff_src_desc());
#endif
o_mds.push_back(diff_src_desc);
}
catch (const mkldnn::error& e)
{
......
......@@ -256,7 +256,8 @@ namespace ngraph
// weights_iter, bias,
// dst_layer, dst_iter, dst_iter_c, workspace, and rnn_forward.
// It needs a new workspace.
index = mkldnn_emitter.reserve_primitive_space(11, true /* new workspace */);
index = mkldnn_emitter.reserve_primitive_space(
11, false /* fwd and bwd */, true /* new workspace */);
deps = mkldnn_emitter.get_primitive_deps(index);
CodeWriter writer;
......@@ -1773,7 +1774,8 @@ namespace ngraph
// MaxPoolBackprop needs 6 primitives: fprop_src, diff_dst, diff_src, workspace
// pooling forward, and pooling_backward.
// It needs a new workspace.
index = mkldnn_emitter.reserve_primitive_space(6, true /* new workspace */);
index = mkldnn_emitter.reserve_primitive_space(
6, true /* fwd and bwd */, true /* new workspace */);
deps = mkldnn_emitter.get_primitive_deps(index);
CodeWriter writer;
......@@ -1832,13 +1834,13 @@ namespace ngraph
writer.block_end();
writer << "cg_ctx->mkldnn_workspaces.push_back(workspace);\n";
deps[5] = mkldnn_emitter.reserve_workspace();
deps[4] = mkldnn_emitter.reserve_workspace();
writer << "\n// build primitive\n";
writer << "cg_ctx->mkldnn_primitives[" << std::to_string(deps[4])
writer << "cg_ctx->mkldnn_primitives[" << std::to_string(index - 1)
<< "] = new mkldnn::pooling_forward(fwd_pd);\n";
writer << "cg_ctx->mkldnn_scratchpad_mds[" << std::to_string(deps[4])
writer << "cg_ctx->mkldnn_scratchpad_mds[" << std::to_string(index - 1)
<< "] = new mkldnn::memory::desc(fwd_pd.scratchpad_desc());\n";
writer << "cg_ctx->mkldnn_primitives[" << std::to_string(index)
......@@ -2816,7 +2818,6 @@ bool MKLDNNPrimitiveBuildPass::run_on_call_graph(const std::list<std::shared_ptr
construct_string, deps, index, scratchpad_size);
}
}
return false;
}
......
......@@ -339,10 +339,7 @@ private:
{
free(w);
}
}
inline void cleanup_mkldnn_descriptors()
{
for (auto d : mkldnn_descriptors)
{
free(d);
......@@ -362,14 +359,14 @@ extern "C" void destroy_cg_ctx(CPURuntimeContextCG* cg_ctx)
static void
deserialize_memory_descs_and_build_memory(std::ifstream& desc_file,
CPURuntimeContextCG* cg_ctx,
size_t descs_count)
CPURuntimeContextCG* cg_ctx,
size_t descs_count)
{
cg_ctx->mkldnn_descriptors = std::vector<mkldnn::memory::desc*>(descs_count);
for (auto i = 0; i < descs_count; i++)
{
size_t index;
desc_file >> index;
size_t index;
desc_file >> index;
auto desc = (mkldnn::memory::desc*)malloc(sizeof(mkldnn::memory::desc));
if (!desc)
{
......@@ -377,8 +374,8 @@ static void
}
desc_file.read(reinterpret_cast<char*>(desc), sizeof(mkldnn::memory::desc));
cg_ctx->mkldnn_descriptors[i] = desc;
cg_ctx->mkldnn_memories[index] = new mkldnn::memory(*cg_ctx->mkldnn_descriptors[i], cg_ctx->global_cpu_engine, nullptr);
cg_ctx->mkldnn_descriptors[i] = desc;
cg_ctx->mkldnn_memories[index] = new mkldnn::memory(*cg_ctx->mkldnn_descriptors[i], cg_ctx->global_cpu_engine, nullptr);
}
};
)"
......@@ -1551,14 +1551,16 @@ TEST(cpu_test, max_pool_with_indices_bprop_2d_2channel_2image)
Shape padding_below{0, 0};
Shape padding_above{0, 0};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto max_pool = make_shared<op::MaxPoolWithIndices>(
A, window_shape, window_movement_strides, padding_below, padding_above);
auto indices = make_shared<op::GetOutputElement>(max_pool, 1);
Shape shape_i{2, 2, 4, 3};
auto indices = make_shared<op::Parameter>(element::i32, shape_i);
auto delta = make_shared<op::Parameter>(element::f32, shape_i);
auto max_pool_bprop = make_shared<op::MaxPoolWithIndicesBackprop>(
A, delta, indices, window_shape, window_movement_strides, padding_below, padding_above);
auto f = make_shared<Function>(max_pool_bprop, ParameterVector{A, delta, indices});
auto f = make_shared<Function>(max_pool_bprop, ParameterVector{A, delta});
auto backend = runtime::Backend::create("CPU");
......@@ -1590,29 +1592,6 @@ TEST(cpu_test, max_pool_with_indices_bprop_2d_2channel_2image)
{1, 0, 0, 0, 2}}}})
.get_vector());
auto i = backend->create_tensor(element::i32, shape_i);
copy_data(i,
test::NDArray<int, 4>({{{{4, 3, 1}, // img 0 chan 0
{1, 0, 0},
{0, 4, 5},
{0, 3, 2}},
{{5, 4, 3}, // img 0 chan 1
{2, 1, 0},
{3, 1, 2},
{0, 0, 0}}},
{{{1, 0, 3}, // img 1 chan 0
{2, 1, 5},
{3, 5, 2},
{0, 2, 1}},
{{0, 3, 2}, // img 1 chan 1
{1, 0, 3},
{2, 1, 0},
{0, 0, 5}}}})
.get_vector());
auto d = backend->create_tensor(element::f32, shape_i);
copy_data(d,
test::NDArray<float, 4>({{{{0.3f, 0.3f, 0.2f}, // img 0 chan 0
......@@ -1639,7 +1618,7 @@ TEST(cpu_test, max_pool_with_indices_bprop_2d_2channel_2image)
auto result = backend->create_tensor(element::f32, shape_a);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, d, i});
handle->call_with_validate({result}, {a, d});
EXPECT_TRUE(test::all_close_f((test::NDArray<float, 4>({{{{0, 0, 0, 0.2, 0}, // img 0 chan 0
{0, 1.2, 0.2, 0, 0},
{0.2, 0, 0, 0, 0},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment