Commit fb6981b8 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by adstraw

Add layouts and a check for a single user (Relu) (#791)

* add layouts and users check
* add convrelu handler
parent 015e1da8
...@@ -806,6 +806,12 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu() ...@@ -806,6 +806,12 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu()
return false; return false;
} }
if (conv->get_users().size() > 1)
{
NGRAPH_DEBUG << "Convolution has more than one user";
return false;
}
auto conv_relu = std::shared_ptr<Node>(new op::ConvolutionRelu(conv)); auto conv_relu = std::shared_ptr<Node>(new op::ConvolutionRelu(conv));
ngraph::replace_node(m.match_root(), conv_relu); ngraph::replace_node(m.match_root(), conv_relu);
return true; return true;
......
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp" #include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp" #include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp" #include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp" #include "ngraph/runtime/cpu/op/sigmoid.hpp"
...@@ -347,6 +348,25 @@ namespace ngraph ...@@ -347,6 +348,25 @@ namespace ngraph
} }
} }
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionRelu)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
ConvolutionLayout<ngraph::op::ConvolutionRelu, false>(
node, prim_input_formats, prim_output_formats);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <> template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBackpropData) void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBackpropData)
{ {
...@@ -1170,6 +1190,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{ ...@@ -1170,6 +1190,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
&runtime::cpu::pass::CPULayout::layout<ngraph::op::MaxPoolBackprop>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::MaxPoolBackprop>},
{TI(ngraph::op::ConvolutionBias), {TI(ngraph::op::ConvolutionBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBias>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBias>},
{TI(ngraph::op::ConvolutionRelu),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionRelu>},
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias), {TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBiasBackpropFiltersBias>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBiasBackpropFiltersBias>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNorm>}, {TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNorm>},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment