Commit 0f9119c8 authored by Jayaram Bobba's avatar Jayaram Bobba

Add layout propagation to MKLDNN batchnorm and GetOutputElement

parent 9db548c6
...@@ -166,7 +166,7 @@ mkldnn::memory::format runtime::cpu::mkldnn_utils::get_input_mkldnn_format(const ...@@ -166,7 +166,7 @@ mkldnn::memory::format runtime::cpu::mkldnn_utils::get_input_mkldnn_format(const
mkldnn::memory::format runtime::cpu::mkldnn_utils::get_output_mkldnn_format(const Node* node, mkldnn::memory::format runtime::cpu::mkldnn_utils::get_output_mkldnn_format(const Node* node,
int index) int index)
{ {
auto tvl = node->get_output_tensor_view(0)->get_tensor_view_layout(); auto tvl = node->get_output_tensor_view(index)->get_tensor_view_layout();
return dynamic_cast<runtime::cpu::LayoutDescriptor&>(*tvl).get_mkldnn_format(); return dynamic_cast<runtime::cpu::LayoutDescriptor&>(*tvl).get_mkldnn_format();
} }
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/ops/add.hpp" #include "ngraph/ops/add.hpp"
#include "ngraph/ops/avg_pool.hpp" #include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp" #include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/relu.hpp" #include "ngraph/ops/relu.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp" #include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
...@@ -225,6 +226,16 @@ namespace ngraph ...@@ -225,6 +226,16 @@ namespace ngraph
avg_pool->set_op_annotations(op_annotations); avg_pool->set_op_annotations(op_annotations);
} }
} }
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNorm)
{
auto batchnorm = static_cast<op::ReluBackprop*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
batchnorm->set_op_annotations(op_annotations);
}
} }
} }
} }
...@@ -234,6 +245,7 @@ namespace ngraph ...@@ -234,6 +245,7 @@ namespace ngraph
static const runtime::cpu::pass::AssignOpMap s_dispatcher{ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::Add), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Add>}, {TI(ngraph::op::Add), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Add>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNorm>},
{TI(ngraph::op::Convolution), {TI(ngraph::op::Convolution),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>}, &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionBackpropData), {TI(ngraph::op::ConvolutionBackpropData),
......
...@@ -28,7 +28,9 @@ ...@@ -28,7 +28,9 @@
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ops/add.hpp" #include "ngraph/ops/add.hpp"
#include "ngraph/ops/avg_pool.hpp" #include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp" #include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/get_output_element.hpp"
#include "ngraph/ops/op.hpp" #include "ngraph/ops/op.hpp"
#include "ngraph/ops/relu.hpp" #include "ngraph/ops/relu.hpp"
#include "ngraph/ops/result.hpp" #include "ngraph/ops/result.hpp"
...@@ -640,6 +642,17 @@ namespace ngraph ...@@ -640,6 +642,17 @@ namespace ngraph
set_output_layouts(node, prim_output_formats); set_output_layouts(node, prim_output_formats);
} }
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::GetOutputElement)
{
auto goe = static_cast<const ngraph::op::GetOutputElement*>(node.get());
auto input_layout = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(
node.get(), goe->get_n());
vector<memory::format> prim_output_formats;
prim_output_formats.push_back(input_layout);
set_output_layouts(node, prim_output_formats);
}
template <> template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Relu) void CPULayout::LAYOUT_DECL(ngraph::op::Relu)
{ {
...@@ -680,6 +693,32 @@ namespace ngraph ...@@ -680,6 +693,32 @@ namespace ngraph
} }
} }
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::BatchNorm)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 2);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(input_layout);
prim_output_formats.push_back(input_layout);
prim_output_formats.push_back(memory::format::x);
prim_output_formats.push_back(memory::format::x);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
throw ngraph_error("Batchnorm only supported in MKLDNN for now");
}
}
template <> template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Add) void CPULayout::LAYOUT_DECL(ngraph::op::Add)
{ {
...@@ -719,6 +758,9 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{ ...@@ -719,6 +758,9 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>}, {TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop), {TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNorm>},
{TI(ngraph::op::GetOutputElement),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::GetOutputElement>},
{TI(ngraph::op::Relu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Relu>}, {TI(ngraph::op::Relu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Relu>},
{TI(ngraph::op::Result), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Result>}, {TI(ngraph::op::Result), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Result>},
{TI(ngraph::op::ReluBackprop), {TI(ngraph::op::ReluBackprop),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment