Commit b1239af4 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

dex group convolution (#1297)

parent 1011f6c7
......@@ -21,6 +21,7 @@
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp"
using namespace std;
using namespace ngraph;
......@@ -404,6 +405,125 @@ namespace ngraph
}
}
template <>
void Builder::BUILDER_DECL(ngraph::op::GroupConvolution)
{
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& arg1_tensor = tensor_data[args[1].get_name()];
auto& out_tensor = tensor_data[out[0].get_name()];
auto convolution = static_cast<const ngraph::op::GroupConvolution*>(node);
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
Strides window_dilation_strides_adjusted;
for (size_t s : convolution->get_window_dilation_strides())
{
window_dilation_strides_adjusted.push_back(s - 1);
}
auto input_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto output_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_data_desc =
mkldnn_emitter->build_memory_descriptor(args[0], input_format);
Shape weights_shape_groups = convolution->get_weights_dimensions();
auto weights_desc_any = mkldnn::memory::desc(
mkldnn::memory::dims(weights_shape_groups.begin(),
weights_shape_groups.end()),
mkldnn_utils::get_mkldnn_data_type(args[1].get_element_type()),
mkldnn::memory::format::any);
auto padding_below = convolution->get_padding_below();
auto padding_above = convolution->get_padding_above();
auto filter_strides = convolution->get_window_movement_strides();
auto result_desc =
mkldnn_emitter->build_memory_descriptor(out[0], output_format);
auto weights_optimized_format =
mkldnn_emitter->query_convolution_forward_weight_format(
input_data_desc,
weights_desc_any,
result_desc,
filter_strides,
window_dilation_strides_adjusted,
padding_below,
padding_above);
//create workspace for holding the result of converting weights layouts
auto ws = std::unique_ptr<MKLDNNWorkspace>(new MKLDNNWorkspace(
shape_size(args[1].get_shape()) * args[1].get_element_type().size()));
auto ws_buf_index = mkldnn_emitter->insert_workspace(ws);
//descriptors for reorder operation
auto input_reorder_desc =
mkldnn_emitter->build_memory_descriptor(weights_shape_groups,
args[1].get_element_type(),
mkldnn::memory::format::goihw);
auto result_reorder_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape_groups, args[1].get_element_type(), weights_optimized_format);
auto weights_desc = mkldnn::memory::desc(
mkldnn::memory::dims(weights_shape_groups.begin(),
weights_shape_groups.end()),
mkldnn_utils::get_mkldnn_data_type(args[1].get_element_type()),
weights_optimized_format);
auto prim_indices = mkldnn_emitter->build_group_convolution_forward(
input_reorder_desc, //weights
input_data_desc,
weights_desc,
result_reorder_desc,
result_desc,
convolution->get_window_movement_strides(),
window_dilation_strides_adjusted,
padding_below,
padding_above);
size_t reorder_index = prim_indices.first;
auto& reorder_deps = mkldnn_emitter->get_primitive_deps(reorder_index);
size_t conv_index = prim_indices.second;
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
auto functor =
[&, conv_index, reorder_index, ws_buf_index](CPURuntimeContext* ctx) {
//reorder
cpu::mkldnn_utils::set_memory_ptr(ctx, reorder_deps[0], arg1_tensor);
cpu::mkldnn_utils::set_memory_ptr(
ctx, reorder_deps[1], ctx->mkldnn_workspaces[ws_buf_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, reorder_index);
//group convolution
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg0_tensor);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->mkldnn_workspaces[ws_buf_index]);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[2], out_tensor);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, conv_index);
};
functors.emplace_back(functor);
}
else
{
throw ngraph_error("unsupported parameters for GroupConvolution");
}
}
REGISTER_OP_BUILDER(Convolution);
REGISTER_OP_BUILDER(ConvolutionRelu);
REGISTER_OP_BUILDER(ConvolutionBias);
......@@ -411,6 +531,7 @@ namespace ngraph
REGISTER_OP_BUILDER(ConvolutionBackpropData);
REGISTER_OP_BUILDER(ConvolutionBackpropFilters);
REGISTER_OP_BUILDER(ConvolutionBiasBackpropFiltersBias);
REGISTER_OP_BUILDER(GroupConvolution);
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment