Commit da3184ec authored by Jayaram Bobba

Added batchnorm bprop layouts and moved batchnorm ops to mkldnn emitter

parent 5885c09a
@@ -22,6 +22,7 @@
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/types/element_type.hpp"
using namespace ngraph::runtime::cpu;
@@ -58,6 +59,15 @@ mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrap
return build_memory_descriptor(tvw, layout->get_mkldnn_format());
}
mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const Shape& shape,
const ngraph::element::Type& et,
mkldnn::memory::format fmt) const
{
return mkldnn::memory::desc(mkldnn::memory::dims(shape.begin(), shape.end()),
mkldnn_utils::get_mkldnn_data_type(et),
fmt);
}
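A hedged sketch of how the new Shape-based overload can be used when no TensorViewWrapper is at hand (the emitter variable and shape values are illustrative):

// Describe a 4-D float32 activation tensor in NCHW layout directly from a
// Shape; element::f32 comes from ngraph/types/element_type.hpp.
ngraph::Shape shape{32, 64, 28, 28};
mkldnn::memory::desc desc = mkldnn_emitter.build_memory_descriptor(
    shape, ngraph::element::f32, mkldnn::memory::format::nchw);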
mkldnn::memory MKLDNNEmitter::build_memory_primitive(const TensorViewWrapper& tvw) const
{
return mkldnn::memory({build_memory_descriptor(tvw), mkldnn_utils::global_cpu_engine}, nullptr);
@@ -327,3 +337,80 @@ size_t MKLDNNEmitter::build_elementwise_add(
m_primitive_deps[add_index] = {input0_data_index, input1_data_index, result_index};
return add_index;
}
size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
const mkldnn::memory::desc& mean_desc,
const mkldnn::memory::desc& variance_desc,
const double eps)
{
size_t input_index = build_memory_primitive(input_desc);
size_t weights_index = build_memory_primitive(weights_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t mean_index = build_memory_primitive(mean_desc);
size_t variance_index = build_memory_primitive(variance_desc);
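// forward_training makes mean and variance primitive outputs, and
// use_scale_shift expects gamma and beta packed into one 2xC weights tensor.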
size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_forward(
{{mkldnn::prop_kind::forward_training,
input_desc,
eps,
mkldnn::batch_normalization_flag::use_scale_shift},
mkldnn_utils::global_cpu_engine},
mkldnn::primitive::at(*m_mkldnn_primitives[input_index]),
mkldnn::primitive::at(*m_mkldnn_primitives[weights_index]),
static_cast<mkldnn::memory>(*m_mkldnn_primitives[result_index]),
*m_mkldnn_primitives[mean_index],
*m_mkldnn_primitives[variance_index]));
m_primitive_deps[batchnorm_index] = {
input_index, weights_index, result_index, mean_index, variance_index};
return batchnorm_index;
}
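A minimal usage sketch, assuming an MKLDNNEmitter instance named mkldnn_emitter and descriptors already built with build_memory_descriptor; the epsilon value is illustrative:

size_t bn = mkldnn_emitter.build_batchnorm_forward(
    input_desc, weights_desc, result_desc, mean_desc, variance_desc, 1e-3);
// bn indexes the stored primitive; the recorded dependency list names the
// memory-primitive indices that must be bound to real buffers before execution.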
size_t MKLDNNEmitter::build_batchnorm_backward(const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& mean_desc,
const mkldnn::memory::desc& variance_desc,
const mkldnn::memory::desc& delta_desc,
const mkldnn::memory::desc& dinput_desc,
const mkldnn::memory::desc& dweights_desc,
const double eps)
{
size_t weights_index = build_memory_primitive(weights_desc);
size_t input_index = build_memory_primitive(input_desc);
size_t mean_index = build_memory_primitive(mean_desc);
size_t variance_index = build_memory_primitive(variance_desc);
size_t delta_index = build_memory_primitive(delta_desc);
size_t dinput_index = build_memory_primitive(dinput_desc);
size_t dweights_index = build_memory_primitive(dweights_desc);
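// The backward primitive descriptor is built with a forward primitive
// descriptor as a hint, which MKL-DNN uses to keep memory formats
// consistent between the forward and backward passes.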
size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_backward(
{{mkldnn::prop_kind::backward,
delta_desc,
input_desc,
eps,
mkldnn::batch_normalization_flag::use_scale_shift},
mkldnn_utils::global_cpu_engine,
{{mkldnn::prop_kind::forward_training,
input_desc,
eps,
mkldnn::batch_normalization_flag::use_scale_shift},
mkldnn_utils::global_cpu_engine}},
*m_mkldnn_primitives[input_index],
*m_mkldnn_primitives[mean_index],
*m_mkldnn_primitives[variance_index],
*m_mkldnn_primitives[delta_index],
*m_mkldnn_primitives[weights_index],
*m_mkldnn_primitives[dinput_index],
*m_mkldnn_primitives[dweights_index]));
m_primitive_deps[batchnorm_index] = {weights_index,
input_index,
mean_index,
variance_index,
delta_index,
dinput_index,
dweights_index};
return batchnorm_index;
}
\ No newline at end of file
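A matching sketch for the backward builder (names and epsilon again illustrative); note the argument order: weights first, then input, the saved statistics, delta, and the two gradient outputs:

size_t bn_bwd = mkldnn_emitter.build_batchnorm_backward(
    weights_desc, input_desc, mean_desc, variance_desc,
    delta_desc, dinput_desc, dweights_desc, 1e-3);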
@@ -25,6 +25,7 @@
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/types/element_type.hpp"
namespace ngraph
{
@@ -48,6 +49,9 @@ namespace ngraph
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw,
mkldnn::memory::format fmt) const;
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw) const;
mkldnn::memory::desc build_memory_descriptor(const Shape& shape,
const ngraph::element::Type& et,
mkldnn::memory::format fmt) const;
mkldnn::memory build_memory_primitive(const TensorViewWrapper& tvw) const;
size_t build_memory_primitive(const mkldnn::memory::desc& desc);
@@ -107,6 +111,22 @@ namespace ngraph
const std::vector<float>& scale_vector,
const std::vector<mkldnn::memory::primitive_desc>& input_pd);
size_t build_batchnorm_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
const mkldnn::memory::desc& mean_desc,
const mkldnn::memory::desc& variance_desc,
const double eps);
size_t build_batchnorm_backward(const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& mean_desc,
const mkldnn::memory::desc& variance_desc,
const mkldnn::memory::desc& delta_desc,
const mkldnn::memory::desc& dinput_desc,
const mkldnn::memory::desc& dweights_desc,
const double eps);
private:
std::vector<mkldnn::primitive*> m_mkldnn_primitives;
std::vector<mkldnn::stream> m_mkldnn_streams;
......
@@ -250,6 +250,16 @@ namespace ngraph
op_annotations->set_mkldnn_op(true);
batchnorm->set_op_annotations(op_annotations);
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNormBackprop)
{
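// Tag the node so downstream passes (the layout pass below, and code
// emission) take the MKLDNN path for BatchNormBackprop.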
auto batchnorm = static_cast<op::BatchNormBackprop*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
batchnorm->set_op_annotations(op_annotations);
}
}
}
}
@@ -260,6 +270,8 @@ namespace ngraph
static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::Add), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Add>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNorm>},
{TI(ngraph::op::BatchNormBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNormBackprop>},
{TI(ngraph::op::Convolution),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionBackpropData),
......
@@ -737,6 +737,36 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::BatchNormBackprop)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto delta_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 5);
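// BatchNormBackprop input order: 0=gamma, 1=beta, 2=input, 3=mean,
// 4=variance, 5=delta. Data-shaped tensors adopt delta's format, while
// the per-channel vectors stay in plain 1-D (x) format.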
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(memory::format::x); // gamma
prim_input_formats.push_back(memory::format::x); // beta
prim_input_formats.push_back(delta_layout); // input
prim_input_formats.push_back(memory::format::x); // mean
prim_input_formats.push_back(memory::format::x); // variance
prim_input_formats.push_back(delta_layout); // delta
prim_output_formats.push_back(delta_layout); // dinput
prim_output_formats.push_back(memory::format::x); // dgamma
prim_output_formats.push_back(memory::format::x); // dbeta
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
throw ngraph_error("BatchNormBackprop is only supported via MKLDNN for now");
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Add)
{
@@ -777,6 +807,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNorm>},
{TI(ngraph::op::BatchNormBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNormBackprop>},
{TI(ngraph::op::GetOutputElement),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::GetOutputElement>},
{TI(ngraph::op::Relu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Relu>},
......