Commit 124d48ba authored by Louis Feng's avatar Louis Feng

added conv+bias to cpu layout pass.

parent 63bcf1a2
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp" #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp" #include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/ops/conv_bias.hpp"
#include "ngraph/runtime/cpu/ops/convert_layout.hpp" #include "ngraph/runtime/cpu/ops/convert_layout.hpp"
using namespace std; using namespace std;
...@@ -219,12 +220,12 @@ namespace ngraph ...@@ -219,12 +220,12 @@ namespace ngraph
{ {
namespace pass namespace pass
{ {
template <> template <typename T, bool use_bias>
void CPULayout::LAYOUT_DECL(ngraph::op::Convolution) void ConvolutionLayout(std::shared_ptr<ngraph::Node> node,
vector<memory::format>& prim_input_formats,
vector<memory::format>& prim_output_formats)
{ {
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get())) auto convolution = static_cast<const T*>(node.get());
{
auto convolution = static_cast<const ngraph::op::Convolution*>(node.get());
auto arg0_shape = node->get_input_shape(0); auto arg0_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1); auto arg1_shape = node->get_input_shape(1);
...@@ -249,38 +250,88 @@ namespace ngraph ...@@ -249,38 +250,88 @@ namespace ngraph
memory::dims mkldnn_result_shape(result_shape.begin(), result_shape.end()); memory::dims mkldnn_result_shape(result_shape.begin(), result_shape.end());
memory::dims mkldnn_filter_strides(filter_strides.begin(), memory::dims mkldnn_filter_strides(filter_strides.begin(),
filter_strides.end()); filter_strides.end());
memory::dims mkldnn_dilated_strides( memory::dims mkldnn_dilated_strides(window_dilation_strides_adjusted.begin(),
window_dilation_strides_adjusted.begin(),
window_dilation_strides_adjusted.end()); window_dilation_strides_adjusted.end());
memory::dims mkldnn_padding_below(padding_below.begin(), memory::dims mkldnn_padding_below(padding_below.begin(), padding_below.end());
padding_below.end()); memory::dims mkldnn_padding_above(padding_above.begin(), padding_above.end());
memory::dims mkldnn_padding_above(padding_above.begin(), const memory::desc input_data_desc(mkldnn_arg0_shape, et, memory::format::any);
padding_above.end());
const memory::desc input_data_desc(
mkldnn_arg0_shape, et, memory::format::any);
const memory::desc weights_desc(mkldnn_arg1_shape, et, memory::format::any); const memory::desc weights_desc(mkldnn_arg1_shape, et, memory::format::any);
const memory::desc result_desc( const memory::desc result_desc(mkldnn_result_shape, et, memory::format::any);
mkldnn_result_shape, et, memory::format::any); std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr};
convolution_forward::desc fwd_desc(prop_kind::forward, if (use_bias)
{
auto arg2_shape = node->get_input_shape(2);
memory::dims mkldnn_arg2_shape(arg2_shape.begin(), arg2_shape.end());
const memory::desc bias_desc(mkldnn_arg2_shape, et, memory::format::any);
fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
algorithm::convolution_direct, algorithm::convolution_direct,
input_data_desc, input_data_desc,
weights_desc, weights_desc,
bias_desc, // with bias
result_desc, result_desc,
mkldnn_filter_strides, mkldnn_filter_strides,
mkldnn_dilated_strides, mkldnn_dilated_strides,
mkldnn_padding_below, mkldnn_padding_below,
mkldnn_padding_above, mkldnn_padding_above,
padding_kind::zero); padding_kind::zero));
convolution_forward::primitive_desc prim_desc(fwd_desc, cpu_engine); }
vector<memory::format> prim_input_formats; else
vector<memory::format> prim_output_formats; {
fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
algorithm::convolution_direct,
input_data_desc,
weights_desc,
result_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero));
}
convolution_forward::primitive_desc prim_desc(*fwd_desc, cpu_engine);
prim_input_formats.push_back(static_cast<memory::format>( prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.src_primitive_desc().desc().data.format)); prim_desc.src_primitive_desc().desc().data.format));
prim_input_formats.push_back(static_cast<memory::format>( prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.weights_primitive_desc().desc().data.format)); prim_desc.weights_primitive_desc().desc().data.format));
if (use_bias)
{
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.bias_primitive_desc().desc().data.format));
}
prim_output_formats.push_back(static_cast<memory::format>( prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.dst_primitive_desc().desc().data.format)); prim_desc.dst_primitive_desc().desc().data.format));
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Convolution)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
ConvolutionLayout<ngraph::op::Convolution, false>(
node, prim_input_formats, prim_output_formats);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBias)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
ConvolutionLayout<ngraph::op::ConvolutionBias, true>(
node, prim_input_formats, prim_output_formats);
node = node =
insert_input_conversions(external_function, node, prim_input_formats); insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats); set_output_layouts(node, prim_output_formats);
...@@ -379,17 +430,16 @@ namespace ngraph ...@@ -379,17 +430,16 @@ namespace ngraph
} }
} }
template <> template <typename T, bool use_bias>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBackpropFilters) void ConvolutionBackpropFiltersLayout(std::shared_ptr<ngraph::Node> node,
vector<memory::format>& prim_input_formats,
vector<memory::format>& prim_output_formats)
{ {
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get())) auto convolution = static_cast<const T*>(node.get());
{
auto convolution =
static_cast<const ngraph::op::ConvolutionBackpropFilters*>(node.get());
auto arg0_shape = node->get_input_shape(0); auto data_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1); auto delta_shape = node->get_input_shape(1);
auto result_shape = node->get_output_shape(0); auto filters_shape = node->get_output_shape(0);
auto filter_strides = convolution->get_window_movement_strides_forward(); auto filter_strides = convolution->get_window_movement_strides_forward();
auto padding_below = convolution->get_padding_below_forward(); auto padding_below = convolution->get_padding_below_forward();
auto padding_above = convolution->get_padding_above_forward(); auto padding_above = convolution->get_padding_above_forward();
...@@ -405,57 +455,123 @@ namespace ngraph ...@@ -405,57 +455,123 @@ namespace ngraph
node->get_input_element_type(0)); node->get_input_element_type(0));
engine cpu_engine(engine::cpu, 0); engine cpu_engine(engine::cpu, 0);
memory::dims mkldnn_arg0_shape(arg0_shape.begin(), arg0_shape.end()); memory::dims mkldnn_data_shape(data_shape.begin(), data_shape.end());
memory::dims mkldnn_arg1_shape(arg1_shape.begin(), arg1_shape.end()); memory::dims mkldnn_delta_shape(delta_shape.begin(), delta_shape.end());
memory::dims mkldnn_result_shape(result_shape.begin(), result_shape.end()); memory::dims mkldnn_filters_shape(filters_shape.begin(), filters_shape.end());
memory::dims mkldnn_filter_strides(filter_strides.begin(), memory::dims mkldnn_filter_strides(filter_strides.begin(),
filter_strides.end()); filter_strides.end());
memory::dims mkldnn_dilated_strides( memory::dims mkldnn_dilated_strides(window_dilation_strides_adjusted.begin(),
window_dilation_strides_adjusted.begin(),
window_dilation_strides_adjusted.end()); window_dilation_strides_adjusted.end());
memory::dims mkldnn_padding_below(padding_below.begin(), memory::dims mkldnn_padding_below(padding_below.begin(), padding_below.end());
padding_below.end()); memory::dims mkldnn_padding_above(padding_above.begin(), padding_above.end());
memory::dims mkldnn_padding_above(padding_above.begin(),
padding_above.end()); const memory::desc data_desc(mkldnn_data_shape, et, memory::format::any);
const memory::desc delta_desc(mkldnn_delta_shape, et, memory::format::any);
const memory::desc data_desc(mkldnn_arg0_shape, et, memory::format::any); const memory::desc filters_desc(mkldnn_filters_shape, et, memory::format::any);
const memory::desc delta_desc(mkldnn_arg1_shape, et, memory::format::any);
const memory::desc result_desc( std::unique_ptr<convolution_backward_weights::desc> bwd_desc{nullptr};
mkldnn_result_shape, et, memory::format::any); std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr};
if (use_bias)
convolution_backward_weights::desc bwd_desc(algorithm::convolution_direct, {
auto bias_shape = node->get_output_shape(1);
memory::dims mkldnn_bias_shape(bias_shape.begin(), filters_shape.end());
const memory::desc bias_desc(mkldnn_bias_shape, et, memory::format::any);
bwd_desc.reset(
new convolution_backward_weights::desc(algorithm::convolution_direct,
data_desc, data_desc,
result_desc, filters_desc,
bias_desc,
delta_desc, delta_desc,
mkldnn_filter_strides, mkldnn_filter_strides,
mkldnn_dilated_strides, mkldnn_dilated_strides,
mkldnn_padding_below, mkldnn_padding_below,
mkldnn_padding_above, mkldnn_padding_above,
padding_kind::zero); padding_kind::zero));
convolution_forward::desc fwd_desc(prop_kind::forward, fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
algorithm::convolution_direct, algorithm::convolution_direct,
data_desc, data_desc,
result_desc, filters_desc,
bias_desc,
delta_desc, delta_desc,
mkldnn_filter_strides, mkldnn_filter_strides,
mkldnn_dilated_strides, mkldnn_dilated_strides,
mkldnn_padding_below, mkldnn_padding_below,
mkldnn_padding_above, mkldnn_padding_above,
padding_kind::zero); padding_kind::zero));
convolution_forward::primitive_desc fwd_prim_desc(fwd_desc, cpu_engine); }
else
{
bwd_desc.reset(
new convolution_backward_weights::desc(algorithm::convolution_direct,
data_desc,
filters_desc,
delta_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero));
fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
algorithm::convolution_direct,
data_desc,
filters_desc,
delta_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero));
}
convolution_forward::primitive_desc fwd_prim_desc(*fwd_desc, cpu_engine);
convolution_backward_weights::primitive_desc prim_desc( convolution_backward_weights::primitive_desc prim_desc(
bwd_desc, cpu_engine, fwd_prim_desc); *bwd_desc, cpu_engine, fwd_prim_desc);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(static_cast<memory::format>( prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.src_primitive_desc().desc().data.format)); prim_desc.src_primitive_desc().desc().data.format));
prim_input_formats.push_back(static_cast<memory::format>( prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.diff_dst_primitive_desc().desc().data.format)); prim_desc.diff_dst_primitive_desc().desc().data.format));
prim_output_formats.push_back(static_cast<memory::format>( prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.diff_weights_primitive_desc().desc().data.format)); prim_desc.diff_weights_primitive_desc().desc().data.format));
if (use_bias)
{
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.diff_bias_primitive_desc().desc().data.format));
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBackpropFilters)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
ConvolutionBackpropFiltersLayout<ngraph::op::ConvolutionBackpropFilters,
false>(
node, prim_input_formats, prim_output_formats);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBiasBackpropFiltersBias)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
ConvolutionBackpropFiltersLayout<
ngraph::op::ConvolutionBiasBackpropFiltersBias,
true>(node, prim_input_formats, prim_output_formats);
node = node =
insert_input_conversions(external_function, node, prim_input_formats); insert_input_conversions(external_function, node, prim_input_formats);
...@@ -715,6 +831,10 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{ ...@@ -715,6 +831,10 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropData>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropData>},
{TI(ngraph::op::ConvolutionBackpropFilters), {TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropFilters>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropFilters>},
{TI(ngraph::op::ConvolutionBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBias>},
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBiasBackpropFiltersBias>},
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>}, {TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop), {TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment