Commit 9a1cbd9d authored by Louis Feng's avatar Louis Feng

added conv+bias backprop filter bias to cpu emitter.

parent 7d0e91be
......@@ -2368,29 +2368,27 @@ namespace ngraph
const vector<size_t>& weights_shape = weights.get_shape();
const vector<size_t>& bias_shape = bias.get_shape();
const vector<size_t>& result_shape = result.get_shape();
const size_t data_rank = data_shape.size();
const size_t weights_rank = weights_shape.size();
const element::Type& elem_type = data.get_element_type();
bool data_dilated = false;
for (size_t s : convolution->get_data_dilation_strides())
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
data_dilated = data_dilated || (s != 1);
}
auto data_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto weights_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1);
auto bias_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
auto result_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
if (!data_dilated && data_rank == 4 && weights_rank == 4 &&
elem_type == element::f32)
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto data_desc = mkldnn_emitter->build_memory_descriptor(
data, mkldnn::memory::format::nchw);
data, data_format);
auto weights_desc = mkldnn_emitter->build_memory_descriptor(
weights, mkldnn::memory::format::oihw);
weights, weights_format);
auto bias_desc = mkldnn_emitter->build_memory_descriptor(
bias, mkldnn::memory::format::x);
bias, bias_format);
auto result_desc = mkldnn_emitter->build_memory_descriptor(
result, mkldnn::memory::format::nchw);
size_t conv_index = 0;
result, result_format);
// For dilation, MKLDNN wants to know how many elements to insert between, not how far
// apart to space the elements like nGraph. So we have to subtract 1 from each pos.
......@@ -2401,7 +2399,7 @@ namespace ngraph
window_dilation_strides_adjusted.push_back(s - 1);
}
conv_index = mkldnn_emitter->build_convolution_forward(
size_t conv_index = mkldnn_emitter->build_convolution_forward(
data_desc,
weights_desc,
bias_desc,
......@@ -2426,7 +2424,70 @@ namespace ngraph
}
else
{
throw ngraph_error("ConvolutionBias does not yet support this layout rank: "+std::to_string(data_rank));
throw ngraph_error("ConvolutionBias is only supported with MKLDNN kernel.");
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::ConvolutionBiasBackpropFiltersBias)
{
auto convolution = static_cast<const ngraph::op::ConvolutionBiasBackpropFiltersBias*>(node);
const TensorViewWrapper& data = args[0];
const TensorViewWrapper& delta = args[1];
const TensorViewWrapper& weights_delta = out[0];
const TensorViewWrapper& bias_delta = out[1];
const vector<size_t>& data_shape = data.get_shape();
const vector<size_t>& delta_shape = delta.get_shape();
const vector<size_t>& weights_delta_shape = weights_delta.get_shape();
const vector<size_t>& bias_delta_shape = bias_delta.get_shape();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
const string& elem_type =
runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
data.get_element_type());
Strides window_dilation_strides_adjusted;
for (size_t s : convolution->get_window_dilation_strides_forward())
{
window_dilation_strides_adjusted.push_back(s - 1);
}
auto data_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto delta_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1);
auto weights_delta_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto bias_delta_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 1);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto data_desc = mkldnn_emitter->build_memory_descriptor(args[0], data_format);
auto delta_desc = mkldnn_emitter->build_memory_descriptor(args[1], delta_format);
auto weights_delta_desc = mkldnn_emitter->build_memory_descriptor(args[1], weights_delta_format);
auto bias_delta_desc = mkldnn_emitter->build_memory_descriptor(out[0], bias_delta_format);
size_t conv_index = mkldnn_emitter->build_convolution_backward_filters_bias(
data_desc,
delta_desc,
weights_delta_desc,
bias_delta_desc,
convolution->get_window_movement_strides_forward(),
window_dilation_strides_adjusted,
convolution->get_padding_below_forward(),
convolution->get_padding_above_forward());
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << data.get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << delta.get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
<< ", " << weights_delta.get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
<< ", " << bias_delta.get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(conv_index) << ");\n";
}
else
{
throw ngraph_error("ConvolutionBiasBackpropFiltersBias is only supported with MKLDNN kernel.");
}
}
......
......@@ -83,14 +83,14 @@ namespace ngraph
/**
* Convolution + bias backprop for filters and bias
*/
size_t build_convolution_backward_filters_bias(const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& out_weights_delta_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& strides,
const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above);
size_t build_convolution_backward_filters_bias(const mkldnn::memory::desc &in_data_desc,
const mkldnn::memory::desc &in_delta_desc,
const mkldnn::memory::desc &out_weights_delta_desc,
const mkldnn::memory::desc &out_bias_delta_desc,
const ngraph::Strides &ng_strides,
const ngraph::Strides &ng_dilation_strides,
const ngraph::CoordinateDiff &ng_padding_below,
const ngraph::CoordinateDiff &ng_padding_above);
size_t build_elementwise_add(
const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment