Commit 1404a20b authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Dilated convolution and refactor

parent 42e52e42
......@@ -2030,11 +2030,7 @@ namespace ngraph
data_dilated = data_dilated || (s != 1);
}
// TODO(jmenon): MKLDNN streams should be static so we need to either implement
// codegen for statics or move primitive and stream construction out
// of the generated function and only generate code to run/rerun the stream
if (!filter_dilated && !data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
if (!data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
args[0].get_element_type() == element::f32)
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
......@@ -2044,14 +2040,38 @@ namespace ngraph
args[1], mkldnn::memory::format::oihw);
auto result_desc = mkldnn_emitter->build_memory_descriptor(
out[0], mkldnn::memory::format::nchw);
size_t conv_index = 0;
if (!filter_dilated)
{
conv_index = mkldnn_emitter->build_convolution_forward(
input_data_desc,
weights_desc,
result_desc,
convolution->get_window_movement_strides(),
convolution->get_padding_below(),
convolution->get_padding_above());
}
else
{
// For dilation, MKLDNN wants to know how many elements to insert between, not how far
// apart to space the elements like nGraph. So we have to subtract 1 from each pos.
Strides window_dilation_strides_adjusted;
for (size_t s : convolution->get_window_dilation_strides())
{
window_dilation_strides_adjusted.push_back(s - 1);
}
size_t conv_index = mkldnn_emitter->build_convolution_forward(
input_data_desc,
weights_desc,
result_desc,
convolution->get_window_movement_strides(),
convolution->get_padding_below(),
convolution->get_padding_above());
conv_index = mkldnn_emitter->build_convolution_forward(
input_data_desc,
weights_desc,
result_desc,
convolution->get_window_movement_strides(),
window_dilation_strides_adjusted,
convolution->get_padding_below(),
convolution->get_padding_above());
}
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
......@@ -2064,54 +2084,6 @@ namespace ngraph
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(conv_index) << ");\n";
}
else if (filter_dilated && !data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
args[0].get_element_type() == element::f32)
{
// For dilation, MKLDNN wants to know how many elements to insert between, not how far
// apart to space the elements like nGraph. So we have to subtract 1 from each pos.
Strides window_dilation_strides_adjusted;
for (size_t s : convolution->get_window_dilation_strides())
{
window_dilation_strides_adjusted.push_back(s - 1);
}
const string& et =
get_mkldnn_data_type(args[0].get_element_type().c_type_string());
writer << "{\n";
writer.indent++;
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
writer << "memory::desc input_data_desc = memory::desc({" << join(arg0_shape)
<< "}, " << et << ", memory::format::nchw);\n";
writer << "memory::desc weights_desc = memory::desc({" << join(arg1_shape)
<< "}, " << et << ", memory::format::oihw);\n";
writer << "memory::desc result_desc = memory::desc({" << join(result_shape)
<< "}, " << et << ", memory::format::nchw);\n";
writer << "memory input_data = memory({input_data_desc, cpu_engine}, "
<< args[0].get_name() << ");\n";
writer << "memory weights = memory({weights_desc, cpu_engine}, "
<< args[1].get_name() << ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, "
<< out[0].get_name() << ");\n";
writer
<< "convolution_forward conv = convolution_forward({"
<< "{prop_kind::forward, algorithm::convolution_direct, input_data_desc, "
"weights_desc, result_desc, {"
<< join(convolution->get_window_movement_strides()) << "}, {"
<< join(window_dilation_strides_adjusted) << "}, {"
<< join(convolution->get_padding_below()) << "}, {"
<< join(convolution->get_padding_above())
<< "}, padding_kind::zero}, cpu_engine}, "
<< "input_data, weights, result);\n";
writer << "stream s = stream(stream::kind::eager);\n"
<< "s.submit({conv}).wait();\n";
writer.indent--;
writer << "}\n";
}
else
{
writer << "kernel::convolution<" << out[0].get_type() << ">("
......
......@@ -100,3 +100,36 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
primitive_deps[conv_index] = {input_data_index, weights_index, result_index};
return conv_index;
}
size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& strides,
const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above)
{
size_t input_data_index = build_memory_primitive(input_data_desc);
size_t weights_index = build_memory_primitive(weights_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t conv_index = insert_primitive(new mkldnn::convolution_forward(
{{mkldnn::prop_kind::forward,
mkldnn::algorithm::convolution_direct,
input_data_desc,
weights_desc,
result_desc,
mkldnn::memory::dims(strides.begin(), strides.end()),
mkldnn::memory::dims(dilation_strides.begin(), dilation_strides.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine},
*mkldnn_primitives[input_data_index],
*mkldnn_primitives[weights_index],
*mkldnn_primitives[result_index]));
primitive_deps[conv_index] = {input_data_index, weights_index, result_index};
return conv_index;
}
......@@ -58,6 +58,14 @@ namespace ngraph
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above);
size_t build_convolution_forward(const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& strides,
const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above);
private:
std::shared_ptr<CPU_ExternalFunction> external_function;
std::vector<mkldnn::primitive*> mkldnn_primitives;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment