Commit ab810bb5 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Port ConvolutionBackpropData

parent 48e4157a
...@@ -2192,9 +2192,6 @@ namespace ngraph ...@@ -2192,9 +2192,6 @@ namespace ngraph
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
const string& elem_type =
runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
args[0].get_element_type());
Strides window_dilation_strides_adjusted; Strides window_dilation_strides_adjusted;
for (size_t s : convolution->get_window_dilation_strides_forward()) for (size_t s : convolution->get_window_dilation_strides_forward())
...@@ -2202,82 +2199,33 @@ namespace ngraph ...@@ -2202,82 +2199,33 @@ namespace ngraph
window_dilation_strides_adjusted.push_back(s - 1); window_dilation_strides_adjusted.push_back(s - 1);
} }
auto weight_format = auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0); auto weights_desc = mkldnn_emitter->build_memory_descriptor(
auto delta_format = args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1); auto delta_desc = mkldnn_emitter->build_memory_descriptor(
auto result_format = args[1], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1));
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0); auto result_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
auto emit_memory_desc = [&writer](const std::string& var, size_t conv_bwd_data_index = mkldnn_emitter->build_convolution_backward_data(
const std::string& shape, weights_desc,
const std::string& type, delta_desc,
const std::string& layout) { result_desc,
writer << "memory::desc " << var << " = memory::desc({" << shape << "}, " convolution->get_window_movement_strides_forward(),
<< type << ", " << layout << ");\n"; window_dilation_strides_adjusted,
}; convolution->get_padding_below_forward(),
convolution->get_padding_above_forward());
auto emit_memory = [&writer](
const std::string& var, const std::string& desc, const std::string& data) {
writer << "memory " << var << " = memory({" << desc << ", cpu_engine}, "
<< data << ");\n";
};
auto emit_memory_dims = [&writer](const std::string& var,
const std::string& dims) {
writer << "memory::dims " << var << "{" << dims << "};\n";
};
writer.block_begin();
writer << "try\n";
writer.block_begin();
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
emit_memory_desc(
"weight_desc",
join(arg0_shape),
elem_type,
runtime::cpu::mkldnn_utils::get_mkldnn_format_string(weight_format));
emit_memory_desc(
"delta_desc",
join(arg1_shape),
elem_type,
runtime::cpu::mkldnn_utils::get_mkldnn_format_string(delta_format));
emit_memory_desc(
"result_desc",
join(result_shape),
elem_type,
runtime::cpu::mkldnn_utils::get_mkldnn_format_string(result_format));
emit_memory("weight", "weight_desc", args[0].get_name());
emit_memory("delta", "delta_desc", args[1].get_name());
emit_memory("result", "result_desc", out[0].get_name());
emit_memory_dims("dilates", join(window_dilation_strides_adjusted));
emit_memory_dims("strides",
join(convolution->get_window_movement_strides_forward()));
emit_memory_dims("padding_l", join(convolution->get_padding_below_forward()));
emit_memory_dims("padding_r", join(convolution->get_padding_above_forward()));
writer auto& deps = mkldnn_emitter->get_primitive_deps(conv_bwd_data_index);
<< "convolution_backward_data::desc " writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
"bwd_data_desc(algorithm::convolution_direct, " << ", " << args[0].get_name() << ");\n";
"result_desc, weight_desc, delta_desc, strides, dilates, " writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
"padding_l, padding_r, padding_kind::zero);\n" << ", " << args[1].get_name() << ");\n";
"convolution_forward::primitive_desc fwd_pd({prop_kind::forward, " writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
"algorithm::convolution_direct, result_desc, weight_desc, delta_desc, " << ", " << out[0].get_name() << ");\n";
"strides, dilates, padding_l, padding_r, padding_kind::zero}, "
"cpu_engine);\n" writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
"convolution_backward_data::primitive_desc bwd_data_pd(bwd_data_desc, " << to_string(conv_bwd_data_index) << ");\n";
"cpu_engine, fwd_pd);\n"
"convolution_backward_data bwd_data(bwd_data_pd, delta, weight, "
"result);\n"
"stream s = stream(stream::kind::eager);\n"
"s.submit({bwd_data}).wait();\n";
writer.block_end();
writer << "catch (const mkldnn::error& e)\n";
writer.block_begin();
writer << "throw ngraph::ngraph_error(\"MKLDNN ERROR (\" + std::to_string("
"e.status) + \"): \" + e.message);\n";
writer.block_end();
writer.block_end();
} }
else else
{ {
......
...@@ -79,7 +79,6 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu ...@@ -79,7 +79,6 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
const ngraph::Strides& strides, const ngraph::Strides& strides,
const ngraph::CoordinateDiff& padding_below, const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above) const ngraph::CoordinateDiff& padding_above)
{ {
size_t input_data_index = build_memory_primitive(input_data_desc); size_t input_data_index = build_memory_primitive(input_data_desc);
size_t weights_index = build_memory_primitive(weights_desc); size_t weights_index = build_memory_primitive(weights_desc);
...@@ -111,7 +110,6 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu ...@@ -111,7 +110,6 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
const ngraph::Strides& dilation_strides, const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below, const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above) const ngraph::CoordinateDiff& padding_above)
{ {
size_t input_data_index = build_memory_primitive(input_data_desc); size_t input_data_index = build_memory_primitive(input_data_desc);
size_t weights_index = build_memory_primitive(weights_desc); size_t weights_index = build_memory_primitive(weights_desc);
...@@ -145,7 +143,6 @@ size_t ...@@ -145,7 +143,6 @@ size_t
const ngraph::Strides& dilation_strides, const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below, const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above) const ngraph::CoordinateDiff& padding_above)
{ {
size_t input_index = build_memory_primitive(input_desc); size_t input_index = build_memory_primitive(input_desc);
size_t delta_index = build_memory_primitive(delta_desc); size_t delta_index = build_memory_primitive(delta_desc);
...@@ -182,6 +179,49 @@ size_t ...@@ -182,6 +179,49 @@ size_t
return primitive_index; return primitive_index;
} }
size_t MKLDNNEmitter::build_convolution_backward_data(const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& delta_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& strides,
const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above)
{
size_t weights_index = build_memory_primitive(weights_desc);
size_t delta_index = build_memory_primitive(delta_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t primitive_index = insert_primitive(new mkldnn::convolution_backward_data(
{{mkldnn::algorithm::convolution_direct,
result_desc,
weights_desc,
delta_desc,
mkldnn::memory::dims(strides.begin(), strides.end()),
mkldnn::memory::dims(dilation_strides.begin(), dilation_strides.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine,
// Forward primitive descriptor corresponding to this backward data descriptor
{{mkldnn::prop_kind::forward,
mkldnn::algorithm::convolution_direct,
result_desc,
weights_desc,
delta_desc,
mkldnn::memory::dims(strides.begin(), strides.end()),
mkldnn::memory::dims(dilation_strides.begin(), dilation_strides.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine}},
*mkldnn_primitives[delta_index],
*mkldnn_primitives[weights_index],
*mkldnn_primitives[result_index]));
primitive_deps[primitive_index] = {weights_index, delta_index, result_index};
return primitive_index;
}
size_t MKLDNNEmitter::build_elementwise_add( size_t MKLDNNEmitter::build_elementwise_add(
const mkldnn::memory::desc& input0_data_desc, const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc, const mkldnn::memory::desc& input1_data_desc,
......
...@@ -77,6 +77,14 @@ namespace ngraph ...@@ -77,6 +77,14 @@ namespace ngraph
const ngraph::CoordinateDiff& padding_below, const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above); const ngraph::CoordinateDiff& padding_above);
size_t build_convolution_backward_data(const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& delta_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& strides,
const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above);
size_t build_elementwise_add( size_t build_elementwise_add(
const mkldnn::memory::desc& input0_data_desc, const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc, const mkldnn::memory::desc& input1_data_desc,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment