Commit d861ba32 authored by Pruthvi's avatar Pruthvi Committed by Robert Kimball

- Added workspace for rnn fprop kernel (#1153)

- fixes segfault issue for GNMT model execution through ngraph-mxnet
parent aa36865c
......@@ -564,6 +564,8 @@ namespace ngraph
<< out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[6]) << ", "
<< out[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[7])
<< ", ctx->mkldnn_workspaces[" << deps[8] << "]);\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(lstm_index) << ");\n";
......@@ -665,7 +667,8 @@ namespace ngraph
<< out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[6]) << ", "
<< out[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[7])
<< ", ctx->mkldnn_workspaces[" << deps[8] << "]);\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, " << to_string(rnn_index)
<< ");\n";
}
......
......@@ -837,11 +837,8 @@ size_t MKLDNNEmitter::build_rnn_forward(const mkldnn::memory::desc& src_layer_de
size_t dst_layer_index = build_memory_primitive(dst_layer_desc);
size_t dst_iter_index = build_memory_primitive(dst_iter_desc);
//TODO: figure our the role of workspace
auto null_memory_ = mkldnn::null_memory(mkldnn_utils::global_cpu_engine);
mkldnn::rnn_cell::desc rnn_cell(mkldnn::algorithm::vanilla_lstm);
mkldnn::rnn_forward::desc rnn_layer_desc(mkldnn::prop_kind::forward_inference,
mkldnn::rnn_forward::desc rnn_layer_desc(mkldnn::prop_kind::forward_training,
rnn_cell,
mkldnn::rnn_direction::unidirectional_left2right,
src_layer_desc,
......@@ -853,8 +850,13 @@ size_t MKLDNNEmitter::build_rnn_forward(const mkldnn::memory::desc& src_layer_de
dst_iter_desc);
auto rnn_layer_prim_desc =
mkldnn::rnn_forward::primitive_desc(rnn_layer_desc, mkldnn_utils::global_cpu_engine);
size_t rnn_index = insert_primitive(
new mkldnn::rnn_forward(rnn_layer_prim_desc,
auto workspace_index =
build_memory_primitive(rnn_layer_prim_desc.workspace_primitive_desc().desc());
auto workspace = std::unique_ptr<MKLDNNWorkspace>(
new MKLDNNWorkspace(rnn_layer_prim_desc.workspace_primitive_desc().get_size()));
auto workspace_buf_index = insert_workspace(workspace);
size_t rnn_index = insert_primitive(new mkldnn::rnn_forward(
rnn_layer_prim_desc,
mkldnn::primitive::at(*m_mkldnn_primitives[src_layer_index]),
mkldnn::primitive::at(*m_mkldnn_primitives[src_iter_index]),
mkldnn::primitive::at(*m_mkldnn_primitives[weights_layer_index]),
......@@ -862,14 +864,16 @@ size_t MKLDNNEmitter::build_rnn_forward(const mkldnn::memory::desc& src_layer_de
mkldnn::primitive::at(*m_mkldnn_primitives[bias_index]),
static_cast<mkldnn::memory>(*m_mkldnn_primitives[dst_layer_index]),
static_cast<mkldnn::memory>(*m_mkldnn_primitives[dst_iter_index]),
static_cast<mkldnn::memory>(null_memory_)));
static_cast<mkldnn::memory>(*m_mkldnn_primitives[workspace_index])));
m_primitive_deps[rnn_index] = {src_layer_index,
src_iter_index,
weights_layer_index,
weights_iter_index,
bias_index,
dst_layer_index,
dst_iter_index};
dst_iter_index,
workspace_index,
workspace_buf_index};
return rnn_index;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment