Commit 124d48ba authored by Louis Feng's avatar Louis Feng

added conv+bias to cpu layout pass.

parent 63bcf1a2
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp" #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp" #include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/ops/conv_bias.hpp"
#include "ngraph/runtime/cpu/ops/convert_layout.hpp" #include "ngraph/runtime/cpu/ops/convert_layout.hpp"
using namespace std; using namespace std;
...@@ -219,67 +220,98 @@ namespace ngraph ...@@ -219,67 +220,98 @@ namespace ngraph
{ {
namespace pass namespace pass
{ {
template <> template <typename T, bool use_bias>
void CPULayout::LAYOUT_DECL(ngraph::op::Convolution) void ConvolutionLayout(std::shared_ptr<ngraph::Node> node,
vector<memory::format>& prim_input_formats,
vector<memory::format>& prim_output_formats)
{ {
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get())) auto convolution = static_cast<const T*>(node.get());
{
auto convolution = static_cast<const ngraph::op::Convolution*>(node.get());
auto arg0_shape = node->get_input_shape(0); auto arg0_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1); auto arg1_shape = node->get_input_shape(1);
auto result_shape = node->get_output_shape(0); auto result_shape = node->get_output_shape(0);
auto filter_strides = convolution->get_window_movement_strides(); auto filter_strides = convolution->get_window_movement_strides();
auto padding_below = convolution->get_padding_below(); auto padding_below = convolution->get_padding_below();
auto padding_above = convolution->get_padding_above(); auto padding_above = convolution->get_padding_above();
Strides window_dilation_strides_adjusted; Strides window_dilation_strides_adjusted;
for (size_t s : convolution->get_window_dilation_strides()) for (size_t s : convolution->get_window_dilation_strides())
{ {
window_dilation_strides_adjusted.push_back(s - 1); window_dilation_strides_adjusted.push_back(s - 1);
} }
memory::data_type et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type( memory::data_type et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
node->get_input_element_type(0)); node->get_input_element_type(0));
engine cpu_engine(engine::cpu, 0);
memory::dims mkldnn_arg0_shape(arg0_shape.begin(), arg0_shape.end());
memory::dims mkldnn_arg1_shape(arg1_shape.begin(), arg1_shape.end());
memory::dims mkldnn_result_shape(result_shape.begin(), result_shape.end());
memory::dims mkldnn_filter_strides(filter_strides.begin(),
filter_strides.end());
memory::dims mkldnn_dilated_strides(window_dilation_strides_adjusted.begin(),
window_dilation_strides_adjusted.end());
memory::dims mkldnn_padding_below(padding_below.begin(), padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(), padding_above.end());
const memory::desc input_data_desc(mkldnn_arg0_shape, et, memory::format::any);
const memory::desc weights_desc(mkldnn_arg1_shape, et, memory::format::any);
const memory::desc result_desc(mkldnn_result_shape, et, memory::format::any);
std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr};
if (use_bias)
{
auto arg2_shape = node->get_input_shape(2);
memory::dims mkldnn_arg2_shape(arg2_shape.begin(), arg2_shape.end());
const memory::desc bias_desc(mkldnn_arg2_shape, et, memory::format::any);
fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
algorithm::convolution_direct,
input_data_desc,
weights_desc,
bias_desc, // with bias
result_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero));
}
else
{
fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
algorithm::convolution_direct,
input_data_desc,
weights_desc,
result_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero));
}
convolution_forward::primitive_desc prim_desc(*fwd_desc, cpu_engine);
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.src_primitive_desc().desc().data.format));
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.weights_primitive_desc().desc().data.format));
if (use_bias)
{
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.bias_primitive_desc().desc().data.format));
}
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.dst_primitive_desc().desc().data.format));
}
engine cpu_engine(engine::cpu, 0); template <>
memory::dims mkldnn_arg0_shape(arg0_shape.begin(), arg0_shape.end()); void CPULayout::LAYOUT_DECL(ngraph::op::Convolution)
memory::dims mkldnn_arg1_shape(arg1_shape.begin(), arg1_shape.end()); {
memory::dims mkldnn_result_shape(result_shape.begin(), result_shape.end()); if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
memory::dims mkldnn_filter_strides(filter_strides.begin(), {
filter_strides.end());
memory::dims mkldnn_dilated_strides(
window_dilation_strides_adjusted.begin(),
window_dilation_strides_adjusted.end());
memory::dims mkldnn_padding_below(padding_below.begin(),
padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(),
padding_above.end());
const memory::desc input_data_desc(
mkldnn_arg0_shape, et, memory::format::any);
const memory::desc weights_desc(mkldnn_arg1_shape, et, memory::format::any);
const memory::desc result_desc(
mkldnn_result_shape, et, memory::format::any);
convolution_forward::desc fwd_desc(prop_kind::forward,
algorithm::convolution_direct,
input_data_desc,
weights_desc,
result_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero);
convolution_forward::primitive_desc prim_desc(fwd_desc, cpu_engine);
vector<memory::format> prim_input_formats; vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats; vector<memory::format> prim_output_formats;
prim_input_formats.push_back(static_cast<memory::format>( ConvolutionLayout<ngraph::op::Convolution, false>(
prim_desc.src_primitive_desc().desc().data.format)); node, prim_input_formats, prim_output_formats);
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.weights_primitive_desc().desc().data.format));
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.dst_primitive_desc().desc().data.format));
node = node =
insert_input_conversions(external_function, node, prim_input_formats); insert_input_conversions(external_function, node, prim_input_formats);
...@@ -291,6 +323,25 @@ namespace ngraph ...@@ -291,6 +323,25 @@ namespace ngraph
} }
} }
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBias)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
ConvolutionLayout<ngraph::op::ConvolutionBias, true>(
node, prim_input_formats, prim_output_formats);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <> template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBackpropData) void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBackpropData)
{ {
...@@ -379,83 +430,148 @@ namespace ngraph ...@@ -379,83 +430,148 @@ namespace ngraph
} }
} }
template <> template <typename T, bool use_bias>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBackpropFilters) void ConvolutionBackpropFiltersLayout(std::shared_ptr<ngraph::Node> node,
vector<memory::format>& prim_input_formats,
vector<memory::format>& prim_output_formats)
{ {
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get())) auto convolution = static_cast<const T*>(node.get());
{
auto convolution =
static_cast<const ngraph::op::ConvolutionBackpropFilters*>(node.get());
auto arg0_shape = node->get_input_shape(0); auto data_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1); auto delta_shape = node->get_input_shape(1);
auto result_shape = node->get_output_shape(0); auto filters_shape = node->get_output_shape(0);
auto filter_strides = convolution->get_window_movement_strides_forward(); auto filter_strides = convolution->get_window_movement_strides_forward();
auto padding_below = convolution->get_padding_below_forward(); auto padding_below = convolution->get_padding_below_forward();
auto padding_above = convolution->get_padding_above_forward(); auto padding_above = convolution->get_padding_above_forward();
Strides window_dilation_strides_adjusted;
for (size_t s : convolution->get_window_dilation_strides_forward())
{
window_dilation_strides_adjusted.push_back(s - 1);
}
memory::data_type et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type( Strides window_dilation_strides_adjusted;
node->get_input_element_type(0));
engine cpu_engine(engine::cpu, 0); for (size_t s : convolution->get_window_dilation_strides_forward())
memory::dims mkldnn_arg0_shape(arg0_shape.begin(), arg0_shape.end()); {
memory::dims mkldnn_arg1_shape(arg1_shape.begin(), arg1_shape.end()); window_dilation_strides_adjusted.push_back(s - 1);
memory::dims mkldnn_result_shape(result_shape.begin(), result_shape.end()); }
memory::dims mkldnn_filter_strides(filter_strides.begin(),
filter_strides.end());
memory::dims mkldnn_dilated_strides(
window_dilation_strides_adjusted.begin(),
window_dilation_strides_adjusted.end());
memory::dims mkldnn_padding_below(padding_below.begin(),
padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(),
padding_above.end());
const memory::desc data_desc(mkldnn_arg0_shape, et, memory::format::any);
const memory::desc delta_desc(mkldnn_arg1_shape, et, memory::format::any);
const memory::desc result_desc(
mkldnn_result_shape, et, memory::format::any);
convolution_backward_weights::desc bwd_desc(algorithm::convolution_direct, memory::data_type et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
data_desc, node->get_input_element_type(0));
result_desc,
delta_desc, engine cpu_engine(engine::cpu, 0);
mkldnn_filter_strides, memory::dims mkldnn_data_shape(data_shape.begin(), data_shape.end());
mkldnn_dilated_strides, memory::dims mkldnn_delta_shape(delta_shape.begin(), delta_shape.end());
mkldnn_padding_below, memory::dims mkldnn_filters_shape(filters_shape.begin(), filters_shape.end());
mkldnn_padding_above, memory::dims mkldnn_filter_strides(filter_strides.begin(),
padding_kind::zero); filter_strides.end());
memory::dims mkldnn_dilated_strides(window_dilation_strides_adjusted.begin(),
window_dilation_strides_adjusted.end());
memory::dims mkldnn_padding_below(padding_below.begin(), padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(), padding_above.end());
const memory::desc data_desc(mkldnn_data_shape, et, memory::format::any);
const memory::desc delta_desc(mkldnn_delta_shape, et, memory::format::any);
const memory::desc filters_desc(mkldnn_filters_shape, et, memory::format::any);
std::unique_ptr<convolution_backward_weights::desc> bwd_desc{nullptr};
std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr};
if (use_bias)
{
auto bias_shape = node->get_output_shape(1);
memory::dims mkldnn_bias_shape(bias_shape.begin(), filters_shape.end());
const memory::desc bias_desc(mkldnn_bias_shape, et, memory::format::any);
bwd_desc.reset(
new convolution_backward_weights::desc(algorithm::convolution_direct,
data_desc,
filters_desc,
bias_desc,
delta_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero));
fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
algorithm::convolution_direct,
data_desc,
filters_desc,
bias_desc,
delta_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero));
}
else
{
bwd_desc.reset(
new convolution_backward_weights::desc(algorithm::convolution_direct,
data_desc,
filters_desc,
delta_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero));
fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
algorithm::convolution_direct,
data_desc,
filters_desc,
delta_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero));
}
convolution_forward::primitive_desc fwd_prim_desc(*fwd_desc, cpu_engine);
convolution_backward_weights::primitive_desc prim_desc(
*bwd_desc, cpu_engine, fwd_prim_desc);
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.src_primitive_desc().desc().data.format));
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.diff_dst_primitive_desc().desc().data.format));
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.diff_weights_primitive_desc().desc().data.format));
if (use_bias)
{
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.diff_bias_primitive_desc().desc().data.format));
}
}
convolution_forward::desc fwd_desc(prop_kind::forward, template <>
algorithm::convolution_direct, void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBackpropFilters)
data_desc, {
result_desc, if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
delta_desc, {
mkldnn_filter_strides, vector<memory::format> prim_input_formats;
mkldnn_dilated_strides, vector<memory::format> prim_output_formats;
mkldnn_padding_below, ConvolutionBackpropFiltersLayout<ngraph::op::ConvolutionBackpropFilters,
mkldnn_padding_above, false>(
padding_kind::zero); node, prim_input_formats, prim_output_formats);
convolution_forward::primitive_desc fwd_prim_desc(fwd_desc, cpu_engine);
convolution_backward_weights::primitive_desc prim_desc( node =
bwd_desc, cpu_engine, fwd_prim_desc); insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBiasBackpropFiltersBias)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::format> prim_input_formats; vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats; vector<memory::format> prim_output_formats;
prim_input_formats.push_back(static_cast<memory::format>( ConvolutionBackpropFiltersLayout<
prim_desc.src_primitive_desc().desc().data.format)); ngraph::op::ConvolutionBiasBackpropFiltersBias,
prim_input_formats.push_back(static_cast<memory::format>( false>(node, prim_input_formats, prim_output_formats);
prim_desc.diff_dst_primitive_desc().desc().data.format));
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.diff_weights_primitive_desc().desc().data.format));
node = node =
insert_input_conversions(external_function, node, prim_input_formats); insert_input_conversions(external_function, node, prim_input_formats);
...@@ -715,6 +831,10 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{ ...@@ -715,6 +831,10 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropData>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropData>},
{TI(ngraph::op::ConvolutionBackpropFilters), {TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropFilters>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropFilters>},
{TI(ngraph::op::ConvolutionBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBias>},
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBiasBackpropFiltersBias>},
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>}, {TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop), {TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment