Unverified Commit e46184a1 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by GitHub

Merge pull request #613 from NervanaSystems/jbobba/batchnorm-layouts

Jbobba/batchnorm layouts
parents 9cca4073 a94d46d4
......@@ -3169,6 +3169,19 @@ namespace ngraph
auto output_format =
dynamic_cast<runtime::cpu::LayoutDescriptor&>(*output_tvl).get_mkldnn_format();
// MKLDNN relies on format names for selecting optimized kernel implementations
// Hacky way to deal with this until they move to using canonicalized layouts
if (input_format == mkldnn::memory::format::nchw &&
runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(output_format))
{
input_format = mkldnn::memory::format::oihw;
}
if (output_format == mkldnn::memory::format::nchw &&
runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(input_format))
{
output_format = mkldnn::memory::format::oihw;
}
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_memory_descriptor(args[0], input_format);
auto result_desc = mkldnn_emitter->build_memory_descriptor(out[0], output_format);
......
......@@ -110,6 +110,23 @@ static const std::map<memory::format, const std::string> s_mkldnn_format_string_
{memory::format::OhIw16o4i, "memory::format::OhIw16o4i"},
};
static const std::set<memory::format> s_filter_formats{
memory::format::oihw,
memory::format::ihwo,
memory::format::hwio,
// memory::format::oIhw8i, // These currently map to nChw8c and nChw16c
// memory::format::oIhw16i,
memory::format::OIhw8i8o,
memory::format::OIhw16i16o,
memory::format::IOhw16o16i,
memory::format::OIhw8o8i,
memory::format::OIhw16o16i,
memory::format::Oihw8o,
memory::format::Oihw16o,
memory::format::Ohwi8o,
memory::format::Ohwi16o,
memory::format::OhIw16o4i};
bool runtime::cpu::mkldnn_utils::IsMKLDNNOp(ngraph::Node& op)
{
return (s_op_registry.find(TI(op)) != s_op_registry.end());
......@@ -157,16 +174,16 @@ const std::string& runtime::cpu::mkldnn_utils::get_mkldnn_format_string(memory::
}
mkldnn::memory::format runtime::cpu::mkldnn_utils::get_input_mkldnn_format(const Node* node,
int index)
size_t index)
{
auto tvl = node->get_inputs()[index].get_output().get_tensor_view()->get_tensor_view_layout();
return dynamic_cast<runtime::cpu::LayoutDescriptor&>(*tvl).get_mkldnn_format();
}
mkldnn::memory::format runtime::cpu::mkldnn_utils::get_output_mkldnn_format(const Node* node,
int index)
size_t index)
{
auto tvl = node->get_output_tensor_view(0)->get_tensor_view_layout();
auto tvl = node->get_output_tensor_view(index)->get_tensor_view_layout();
return dynamic_cast<runtime::cpu::LayoutDescriptor&>(*tvl).get_mkldnn_format();
}
......@@ -181,8 +198,8 @@ bool runtime::cpu::mkldnn_utils::use_mkldnn_kernel(const ngraph::Node* node)
bool runtime::cpu::mkldnn_utils::compare_mkldnn_formats(mkldnn::memory::format fmt1,
mkldnn::memory::format fmt2)
{
set<mkldnn::memory::format> similar_4d_formats{mkldnn::memory::format::nchw,
mkldnn::memory::format::oihw};
std::set<mkldnn::memory::format> similar_4d_formats{mkldnn::memory::format::nchw,
mkldnn::memory::format::oihw};
if ((fmt1 == fmt2) || (similar_4d_formats.find(fmt1) != similar_4d_formats.end() &&
similar_4d_formats.find(fmt2) != similar_4d_formats.end()))
{
......@@ -190,3 +207,12 @@ bool runtime::cpu::mkldnn_utils::compare_mkldnn_formats(mkldnn::memory::format f
}
return false;
}
bool runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(mkldnn::memory::format fmt)
{
if (s_filter_formats.find(fmt) != s_filter_formats.end())
{
return true;
}
return false;
}
......@@ -39,11 +39,12 @@ namespace ngraph
mkldnn::memory::data_type get_mkldnn_data_type(const ngraph::element::Type& type);
const std::string& get_mkldnn_format_string(mkldnn::memory::format fmt);
mkldnn::memory::format get_input_mkldnn_format(const Node* node, int index);
mkldnn::memory::format get_output_mkldnn_format(const Node* node, int index);
mkldnn::memory::format get_input_mkldnn_format(const Node* node, size_t index);
mkldnn::memory::format get_output_mkldnn_format(const Node* node, size_t index);
bool use_mkldnn_kernel(const ngraph::Node* node);
bool compare_mkldnn_formats(mkldnn::memory::format fmt1,
mkldnn::memory::format fmt2);
bool is_mkldnn_filter_format(mkldnn::memory::format fmt);
}
}
}
......
......@@ -27,6 +27,7 @@
#include "ngraph/descriptor/output.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/relu.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
......@@ -225,6 +226,16 @@ namespace ngraph
avg_pool->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNorm)
{
auto batchnorm = static_cast<op::BatchNorm*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
batchnorm->set_op_annotations(op_annotations);
}
}
}
}
......@@ -234,6 +245,7 @@ namespace ngraph
static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::Add), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Add>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNorm>},
{TI(ngraph::op::Convolution),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionBackpropData),
......
......@@ -28,7 +28,9 @@
#include "ngraph/log.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/get_output_element.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/ops/relu.hpp"
#include "ngraph/ops/result.hpp"
......@@ -640,6 +642,17 @@ namespace ngraph
set_output_layouts(node, prim_output_formats);
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::GetOutputElement)
{
auto goe = static_cast<const ngraph::op::GetOutputElement*>(node.get());
auto input_layout = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(
node.get(), goe->get_n());
vector<memory::format> prim_output_formats;
prim_output_formats.push_back(input_layout);
set_output_layouts(node, prim_output_formats);
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Relu)
{
......@@ -680,6 +693,32 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::BatchNorm)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 2);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(input_layout);
prim_output_formats.push_back(input_layout);
prim_output_formats.push_back(memory::format::x);
prim_output_formats.push_back(memory::format::x);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
throw ngraph_error("Batchnorm only supported in MKLDNN for now");
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Add)
{
......@@ -719,6 +758,9 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNorm>},
{TI(ngraph::op::GetOutputElement),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::GetOutputElement>},
{TI(ngraph::op::Relu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Relu>},
{TI(ngraph::op::Result), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Result>},
{TI(ngraph::op::ReluBackprop),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment