Commit 124d48ba authored by Louis Feng's avatar Louis Feng

added conv+bias to cpu layout pass.

parent 63bcf1a2
......@@ -35,6 +35,7 @@
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/ops/conv_bias.hpp"
#include "ngraph/runtime/cpu/ops/convert_layout.hpp"
using namespace std;
......@@ -219,12 +220,12 @@ namespace ngraph
{
namespace pass
{
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Convolution)
template <typename T, bool use_bias>
void ConvolutionLayout(std::shared_ptr<ngraph::Node> node,
vector<memory::format>& prim_input_formats,
vector<memory::format>& prim_output_formats)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto convolution = static_cast<const ngraph::op::Convolution*>(node.get());
auto convolution = static_cast<const T*>(node.get());
auto arg0_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1);
......@@ -249,38 +250,88 @@ namespace ngraph
memory::dims mkldnn_result_shape(result_shape.begin(), result_shape.end());
memory::dims mkldnn_filter_strides(filter_strides.begin(),
filter_strides.end());
memory::dims mkldnn_dilated_strides(
window_dilation_strides_adjusted.begin(),
memory::dims mkldnn_dilated_strides(window_dilation_strides_adjusted.begin(),
window_dilation_strides_adjusted.end());
memory::dims mkldnn_padding_below(padding_below.begin(),
padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(),
padding_above.end());
const memory::desc input_data_desc(
mkldnn_arg0_shape, et, memory::format::any);
memory::dims mkldnn_padding_below(padding_below.begin(), padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(), padding_above.end());
const memory::desc input_data_desc(mkldnn_arg0_shape, et, memory::format::any);
const memory::desc weights_desc(mkldnn_arg1_shape, et, memory::format::any);
const memory::desc result_desc(
mkldnn_result_shape, et, memory::format::any);
convolution_forward::desc fwd_desc(prop_kind::forward,
const memory::desc result_desc(mkldnn_result_shape, et, memory::format::any);
std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr};
if (use_bias)
{
auto arg2_shape = node->get_input_shape(2);
memory::dims mkldnn_arg2_shape(arg2_shape.begin(), arg2_shape.end());
const memory::desc bias_desc(mkldnn_arg2_shape, et, memory::format::any);
fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
algorithm::convolution_direct,
input_data_desc,
weights_desc,
bias_desc, // with bias
result_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero);
convolution_forward::primitive_desc prim_desc(fwd_desc, cpu_engine);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
padding_kind::zero));
}
else
{
fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
algorithm::convolution_direct,
input_data_desc,
weights_desc,
result_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero));
}
convolution_forward::primitive_desc prim_desc(*fwd_desc, cpu_engine);
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.src_primitive_desc().desc().data.format));
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.weights_primitive_desc().desc().data.format));
if (use_bias)
{
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.bias_primitive_desc().desc().data.format));
}
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.dst_primitive_desc().desc().data.format));
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Convolution)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
ConvolutionLayout<ngraph::op::Convolution, false>(
node, prim_input_formats, prim_output_formats);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBias)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
ConvolutionLayout<ngraph::op::ConvolutionBias, true>(
node, prim_input_formats, prim_output_formats);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
......@@ -379,17 +430,16 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBackpropFilters)
template <typename T, bool use_bias>
void ConvolutionBackpropFiltersLayout(std::shared_ptr<ngraph::Node> node,
vector<memory::format>& prim_input_formats,
vector<memory::format>& prim_output_formats)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto convolution =
static_cast<const ngraph::op::ConvolutionBackpropFilters*>(node.get());
auto convolution = static_cast<const T*>(node.get());
auto arg0_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1);
auto result_shape = node->get_output_shape(0);
auto data_shape = node->get_input_shape(0);
auto delta_shape = node->get_input_shape(1);
auto filters_shape = node->get_output_shape(0);
auto filter_strides = convolution->get_window_movement_strides_forward();
auto padding_below = convolution->get_padding_below_forward();
auto padding_above = convolution->get_padding_above_forward();
......@@ -405,57 +455,123 @@ namespace ngraph
node->get_input_element_type(0));
engine cpu_engine(engine::cpu, 0);
memory::dims mkldnn_arg0_shape(arg0_shape.begin(), arg0_shape.end());
memory::dims mkldnn_arg1_shape(arg1_shape.begin(), arg1_shape.end());
memory::dims mkldnn_result_shape(result_shape.begin(), result_shape.end());
memory::dims mkldnn_data_shape(data_shape.begin(), data_shape.end());
memory::dims mkldnn_delta_shape(delta_shape.begin(), delta_shape.end());
memory::dims mkldnn_filters_shape(filters_shape.begin(), filters_shape.end());
memory::dims mkldnn_filter_strides(filter_strides.begin(),
filter_strides.end());
memory::dims mkldnn_dilated_strides(
window_dilation_strides_adjusted.begin(),
memory::dims mkldnn_dilated_strides(window_dilation_strides_adjusted.begin(),
window_dilation_strides_adjusted.end());
memory::dims mkldnn_padding_below(padding_below.begin(),
padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(),
padding_above.end());
const memory::desc data_desc(mkldnn_arg0_shape, et, memory::format::any);
const memory::desc delta_desc(mkldnn_arg1_shape, et, memory::format::any);
const memory::desc result_desc(
mkldnn_result_shape, et, memory::format::any);
convolution_backward_weights::desc bwd_desc(algorithm::convolution_direct,
memory::dims mkldnn_padding_below(padding_below.begin(), padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(), padding_above.end());
const memory::desc data_desc(mkldnn_data_shape, et, memory::format::any);
const memory::desc delta_desc(mkldnn_delta_shape, et, memory::format::any);
const memory::desc filters_desc(mkldnn_filters_shape, et, memory::format::any);
std::unique_ptr<convolution_backward_weights::desc> bwd_desc{nullptr};
std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr};
if (use_bias)
{
auto bias_shape = node->get_output_shape(1);
memory::dims mkldnn_bias_shape(bias_shape.begin(), filters_shape.end());
const memory::desc bias_desc(mkldnn_bias_shape, et, memory::format::any);
bwd_desc.reset(
new convolution_backward_weights::desc(algorithm::convolution_direct,
data_desc,
result_desc,
filters_desc,
bias_desc,
delta_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero);
padding_kind::zero));
convolution_forward::desc fwd_desc(prop_kind::forward,
fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
algorithm::convolution_direct,
data_desc,
result_desc,
filters_desc,
bias_desc,
delta_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero);
convolution_forward::primitive_desc fwd_prim_desc(fwd_desc, cpu_engine);
padding_kind::zero));
}
else
{
bwd_desc.reset(
new convolution_backward_weights::desc(algorithm::convolution_direct,
data_desc,
filters_desc,
delta_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero));
fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
algorithm::convolution_direct,
data_desc,
filters_desc,
delta_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero));
}
convolution_forward::primitive_desc fwd_prim_desc(*fwd_desc, cpu_engine);
convolution_backward_weights::primitive_desc prim_desc(
bwd_desc, cpu_engine, fwd_prim_desc);
*bwd_desc, cpu_engine, fwd_prim_desc);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.src_primitive_desc().desc().data.format));
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.diff_dst_primitive_desc().desc().data.format));
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.diff_weights_primitive_desc().desc().data.format));
if (use_bias)
{
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.diff_bias_primitive_desc().desc().data.format));
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBackpropFilters)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
ConvolutionBackpropFiltersLayout<ngraph::op::ConvolutionBackpropFilters,
false>(
node, prim_input_formats, prim_output_formats);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBiasBackpropFiltersBias)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
ConvolutionBackpropFiltersLayout<
ngraph::op::ConvolutionBiasBackpropFiltersBias,
false>(node, prim_input_formats, prim_output_formats);
node =
insert_input_conversions(external_function, node, prim_input_formats);
......@@ -715,6 +831,10 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropData>},
{TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropFilters>},
{TI(ngraph::op::ConvolutionBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBias>},
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBiasBackpropFiltersBias>},
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment