Commit 0f9119c8 authored by Jayaram Bobba's avatar Jayaram Bobba

Add layout propagation to MKLDNN batchnorm and GetOutputElement

parent 9db548c6
......@@ -166,7 +166,7 @@ mkldnn::memory::format runtime::cpu::mkldnn_utils::get_input_mkldnn_format(const
mkldnn::memory::format runtime::cpu::mkldnn_utils::get_output_mkldnn_format(const Node* node,
int index)
{
auto tvl = node->get_output_tensor_view(0)->get_tensor_view_layout();
auto tvl = node->get_output_tensor_view(index)->get_tensor_view_layout();
return dynamic_cast<runtime::cpu::LayoutDescriptor&>(*tvl).get_mkldnn_format();
}
......
......@@ -27,6 +27,7 @@
#include "ngraph/descriptor/output.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/relu.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
......@@ -225,6 +226,16 @@ namespace ngraph
avg_pool->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNorm)
{
auto batchnorm = static_cast<op::ReluBackprop*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
batchnorm->set_op_annotations(op_annotations);
}
}
}
}
......@@ -234,6 +245,7 @@ namespace ngraph
static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::Add), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Add>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNorm>},
{TI(ngraph::op::Convolution),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionBackpropData),
......
......@@ -28,7 +28,9 @@
#include "ngraph/log.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/get_output_element.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/ops/relu.hpp"
#include "ngraph/ops/result.hpp"
......@@ -640,6 +642,17 @@ namespace ngraph
set_output_layouts(node, prim_output_formats);
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::GetOutputElement)
{
auto goe = static_cast<const ngraph::op::GetOutputElement*>(node.get());
auto input_layout = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(
node.get(), goe->get_n());
vector<memory::format> prim_output_formats;
prim_output_formats.push_back(input_layout);
set_output_layouts(node, prim_output_formats);
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Relu)
{
......@@ -680,6 +693,32 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::BatchNorm)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 2);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(input_layout);
prim_output_formats.push_back(input_layout);
prim_output_formats.push_back(memory::format::x);
prim_output_formats.push_back(memory::format::x);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
throw ngraph_error("Batchnorm only supported in MKLDNN for now");
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Add)
{
......@@ -719,6 +758,9 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNorm>},
{TI(ngraph::op::GetOutputElement),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::GetOutputElement>},
{TI(ngraph::op::Relu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Relu>},
{TI(ngraph::op::Result), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Result>},
{TI(ngraph::op::ReluBackprop),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment