Unverified Commit e46184a1 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by GitHub

Merge pull request #613 from NervanaSystems/jbobba/batchnorm-layouts

Jbobba/batchnorm layouts
parents 9cca4073 a94d46d4
...@@ -3169,6 +3169,19 @@ namespace ngraph ...@@ -3169,6 +3169,19 @@ namespace ngraph
auto output_format = auto output_format =
dynamic_cast<runtime::cpu::LayoutDescriptor&>(*output_tvl).get_mkldnn_format(); dynamic_cast<runtime::cpu::LayoutDescriptor&>(*output_tvl).get_mkldnn_format();
// MKLDNN relies on format names for selecting optimized kernel implementations
// Hacky way to deal with this until they move to using canonicalized layouts
if (input_format == mkldnn::memory::format::nchw &&
runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(output_format))
{
input_format = mkldnn::memory::format::oihw;
}
if (output_format == mkldnn::memory::format::nchw &&
runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(input_format))
{
output_format = mkldnn::memory::format::oihw;
}
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_memory_descriptor(args[0], input_format); auto input_desc = mkldnn_emitter->build_memory_descriptor(args[0], input_format);
auto result_desc = mkldnn_emitter->build_memory_descriptor(out[0], output_format); auto result_desc = mkldnn_emitter->build_memory_descriptor(out[0], output_format);
......
...@@ -110,6 +110,23 @@ static const std::map<memory::format, const std::string> s_mkldnn_format_string_ ...@@ -110,6 +110,23 @@ static const std::map<memory::format, const std::string> s_mkldnn_format_string_
{memory::format::OhIw16o4i, "memory::format::OhIw16o4i"}, {memory::format::OhIw16o4i, "memory::format::OhIw16o4i"},
}; };
static const std::set<memory::format> s_filter_formats{
memory::format::oihw,
memory::format::ihwo,
memory::format::hwio,
// memory::format::oIhw8i, // These currently map to nChw8c and nChw16c
// memory::format::oIhw16i,
memory::format::OIhw8i8o,
memory::format::OIhw16i16o,
memory::format::IOhw16o16i,
memory::format::OIhw8o8i,
memory::format::OIhw16o16i,
memory::format::Oihw8o,
memory::format::Oihw16o,
memory::format::Ohwi8o,
memory::format::Ohwi16o,
memory::format::OhIw16o4i};
bool runtime::cpu::mkldnn_utils::IsMKLDNNOp(ngraph::Node& op) bool runtime::cpu::mkldnn_utils::IsMKLDNNOp(ngraph::Node& op)
{ {
return (s_op_registry.find(TI(op)) != s_op_registry.end()); return (s_op_registry.find(TI(op)) != s_op_registry.end());
...@@ -157,16 +174,16 @@ const std::string& runtime::cpu::mkldnn_utils::get_mkldnn_format_string(memory:: ...@@ -157,16 +174,16 @@ const std::string& runtime::cpu::mkldnn_utils::get_mkldnn_format_string(memory::
} }
mkldnn::memory::format runtime::cpu::mkldnn_utils::get_input_mkldnn_format(const Node* node, mkldnn::memory::format runtime::cpu::mkldnn_utils::get_input_mkldnn_format(const Node* node,
int index) size_t index)
{ {
auto tvl = node->get_inputs()[index].get_output().get_tensor_view()->get_tensor_view_layout(); auto tvl = node->get_inputs()[index].get_output().get_tensor_view()->get_tensor_view_layout();
return dynamic_cast<runtime::cpu::LayoutDescriptor&>(*tvl).get_mkldnn_format(); return dynamic_cast<runtime::cpu::LayoutDescriptor&>(*tvl).get_mkldnn_format();
} }
mkldnn::memory::format runtime::cpu::mkldnn_utils::get_output_mkldnn_format(const Node* node, mkldnn::memory::format runtime::cpu::mkldnn_utils::get_output_mkldnn_format(const Node* node,
int index) size_t index)
{ {
auto tvl = node->get_output_tensor_view(0)->get_tensor_view_layout(); auto tvl = node->get_output_tensor_view(index)->get_tensor_view_layout();
return dynamic_cast<runtime::cpu::LayoutDescriptor&>(*tvl).get_mkldnn_format(); return dynamic_cast<runtime::cpu::LayoutDescriptor&>(*tvl).get_mkldnn_format();
} }
...@@ -181,7 +198,7 @@ bool runtime::cpu::mkldnn_utils::use_mkldnn_kernel(const ngraph::Node* node) ...@@ -181,7 +198,7 @@ bool runtime::cpu::mkldnn_utils::use_mkldnn_kernel(const ngraph::Node* node)
bool runtime::cpu::mkldnn_utils::compare_mkldnn_formats(mkldnn::memory::format fmt1, bool runtime::cpu::mkldnn_utils::compare_mkldnn_formats(mkldnn::memory::format fmt1,
mkldnn::memory::format fmt2) mkldnn::memory::format fmt2)
{ {
set<mkldnn::memory::format> similar_4d_formats{mkldnn::memory::format::nchw, std::set<mkldnn::memory::format> similar_4d_formats{mkldnn::memory::format::nchw,
mkldnn::memory::format::oihw}; mkldnn::memory::format::oihw};
if ((fmt1 == fmt2) || (similar_4d_formats.find(fmt1) != similar_4d_formats.end() && if ((fmt1 == fmt2) || (similar_4d_formats.find(fmt1) != similar_4d_formats.end() &&
similar_4d_formats.find(fmt2) != similar_4d_formats.end())) similar_4d_formats.find(fmt2) != similar_4d_formats.end()))
...@@ -190,3 +207,12 @@ bool runtime::cpu::mkldnn_utils::compare_mkldnn_formats(mkldnn::memory::format f ...@@ -190,3 +207,12 @@ bool runtime::cpu::mkldnn_utils::compare_mkldnn_formats(mkldnn::memory::format f
} }
return false; return false;
} }
bool runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(mkldnn::memory::format fmt)
{
if (s_filter_formats.find(fmt) != s_filter_formats.end())
{
return true;
}
return false;
}
...@@ -39,11 +39,12 @@ namespace ngraph ...@@ -39,11 +39,12 @@ namespace ngraph
mkldnn::memory::data_type get_mkldnn_data_type(const ngraph::element::Type& type); mkldnn::memory::data_type get_mkldnn_data_type(const ngraph::element::Type& type);
const std::string& get_mkldnn_format_string(mkldnn::memory::format fmt); const std::string& get_mkldnn_format_string(mkldnn::memory::format fmt);
mkldnn::memory::format get_input_mkldnn_format(const Node* node, int index); mkldnn::memory::format get_input_mkldnn_format(const Node* node, size_t index);
mkldnn::memory::format get_output_mkldnn_format(const Node* node, int index); mkldnn::memory::format get_output_mkldnn_format(const Node* node, size_t index);
bool use_mkldnn_kernel(const ngraph::Node* node); bool use_mkldnn_kernel(const ngraph::Node* node);
bool compare_mkldnn_formats(mkldnn::memory::format fmt1, bool compare_mkldnn_formats(mkldnn::memory::format fmt1,
mkldnn::memory::format fmt2); mkldnn::memory::format fmt2);
bool is_mkldnn_filter_format(mkldnn::memory::format fmt);
} }
} }
} }
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/ops/add.hpp" #include "ngraph/ops/add.hpp"
#include "ngraph/ops/avg_pool.hpp" #include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp" #include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/relu.hpp" #include "ngraph/ops/relu.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp" #include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
...@@ -225,6 +226,16 @@ namespace ngraph ...@@ -225,6 +226,16 @@ namespace ngraph
avg_pool->set_op_annotations(op_annotations); avg_pool->set_op_annotations(op_annotations);
} }
} }
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNorm)
{
auto batchnorm = static_cast<op::BatchNorm*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
batchnorm->set_op_annotations(op_annotations);
}
} }
} }
} }
...@@ -234,6 +245,7 @@ namespace ngraph ...@@ -234,6 +245,7 @@ namespace ngraph
static const runtime::cpu::pass::AssignOpMap s_dispatcher{ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::Add), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Add>}, {TI(ngraph::op::Add), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Add>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNorm>},
{TI(ngraph::op::Convolution), {TI(ngraph::op::Convolution),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>}, &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionBackpropData), {TI(ngraph::op::ConvolutionBackpropData),
......
...@@ -28,7 +28,9 @@ ...@@ -28,7 +28,9 @@
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ops/add.hpp" #include "ngraph/ops/add.hpp"
#include "ngraph/ops/avg_pool.hpp" #include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp" #include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/get_output_element.hpp"
#include "ngraph/ops/op.hpp" #include "ngraph/ops/op.hpp"
#include "ngraph/ops/relu.hpp" #include "ngraph/ops/relu.hpp"
#include "ngraph/ops/result.hpp" #include "ngraph/ops/result.hpp"
...@@ -640,6 +642,17 @@ namespace ngraph ...@@ -640,6 +642,17 @@ namespace ngraph
set_output_layouts(node, prim_output_formats); set_output_layouts(node, prim_output_formats);
} }
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::GetOutputElement)
{
auto goe = static_cast<const ngraph::op::GetOutputElement*>(node.get());
auto input_layout = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(
node.get(), goe->get_n());
vector<memory::format> prim_output_formats;
prim_output_formats.push_back(input_layout);
set_output_layouts(node, prim_output_formats);
}
template <> template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Relu) void CPULayout::LAYOUT_DECL(ngraph::op::Relu)
{ {
...@@ -680,6 +693,32 @@ namespace ngraph ...@@ -680,6 +693,32 @@ namespace ngraph
} }
} }
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::BatchNorm)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 2);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(input_layout);
prim_output_formats.push_back(input_layout);
prim_output_formats.push_back(memory::format::x);
prim_output_formats.push_back(memory::format::x);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
throw ngraph_error("Batchnorm only supported in MKLDNN for now");
}
}
template <> template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Add) void CPULayout::LAYOUT_DECL(ngraph::op::Add)
{ {
...@@ -719,6 +758,9 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{ ...@@ -719,6 +758,9 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>}, {TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop), {TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNorm>},
{TI(ngraph::op::GetOutputElement),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::GetOutputElement>},
{TI(ngraph::op::Relu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Relu>}, {TI(ngraph::op::Relu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Relu>},
{TI(ngraph::op::Result), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Result>}, {TI(ngraph::op::Result), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Result>},
{TI(ngraph::op::ReluBackprop), {TI(ngraph::op::ReluBackprop),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment