Commit da3184ec authored by Jayaram Bobba

Added batchnorm bprop layouts and moved batchnorm ops to mkldnn emitter

parent 5885c09a
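This commit replaces the inline MKLDNN codegen for BatchNorm (engine, memory descriptors, primitives, and stream all emitted as strings into the generated source) with primitives that are built once through `MKLDNNEmitter` and invoked by index at run time; the generated code only rebinds memory pointers. A minimal, self-contained sketch of that build-once/invoke-by-index pattern follows — `ToyEmitter` and `build_scale` are invented for illustration and are not nGraph code:

```cpp
#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

// Toy registry (not nGraph code) mirroring the MKLDNNEmitter pattern:
// "compile time" builds a primitive once and records the memory slots it
// touches; "run time" rebinds slot pointers and invokes the primitive by index.
struct ToyEmitter
{
    std::vector<void*> memory_slots;                      // one slot per tensor
    std::vector<std::function<void()>> primitives;        // cached primitives
    std::vector<std::vector<std::size_t>> primitive_deps; // slots per primitive

    std::size_t build_scale(float factor, std::size_t n)
    {
        std::size_t in = memory_slots.size();
        memory_slots.push_back(nullptr);
        std::size_t out = memory_slots.size();
        memory_slots.push_back(nullptr);
        primitives.push_back([this, in, out, factor, n] {
            const float* src = static_cast<const float*>(memory_slots[in]);
            float* dst = static_cast<float*>(memory_slots[out]);
            for (std::size_t i = 0; i < n; i++)
                dst[i] = factor * src[i];
        });
        primitive_deps.push_back({in, out});
        return primitives.size() - 1; // the primitive's index
    }
};

int main()
{
    ToyEmitter emitter;
    std::size_t idx = emitter.build_scale(2.0f, 3); // "compile time"

    float src[3] = {1, 2, 3}, dst[3] = {0, 0, 0};   // "run time" from here on
    auto& deps = emitter.primitive_deps[idx];
    emitter.memory_slots[deps[0]] = src;
    emitter.memory_slots[deps[1]] = dst;
    emitter.primitives[idx]();
    std::printf("%g %g %g\n", dst[0], dst[1], dst[2]); // prints: 2 4 6
}
```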
@@ -386,99 +386,55 @@ namespace ngraph
                 const ngraph::op::BatchNorm* batchnorm =
                     static_cast<const ngraph::op::BatchNorm*>(node);
-                // get the shape of all the inputs and output to batchnorm
-                auto gamma_shape = args[0].get_shape();
-                auto beta_shape = args[1].get_shape();
-                auto input_shape = args[2].get_shape();
-                auto result_shape = out[0].get_shape();
-                auto mean_shape = out[1].get_shape();
-                auto variance_shape = out[2].get_shape();
-                // get input element type
-                const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
-                    args[2].get_element_type());
-                const string& gamma_format = runtime::cpu::mkldnn_utils::get_mkldnn_format_string(
-                    runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
-                const string& beta_format = runtime::cpu::mkldnn_utils::get_mkldnn_format_string(
-                    runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1));
-                if (gamma_format.compare("memory::format::x") != 0 &&
-                    beta_format.compare("memory::format::x") != 0)
-                {
-                    throw std::runtime_error(
-                        "gamma layout->" + gamma_format + ", beta layout->" + beta_format +
-                        " should match and both should have memory::format::x format");
-                }
-                writer << "{\n";
                 writer.indent++;
-                writer << "{\n";
                 // define weights
                 writer << "std::vector<" << args[0].get_element_type().c_type_string()
-                       << ">bn_weights(2*" << input_shape[1] << ");\n";
-                auto weights_shape = Shape{2, input_shape[1]};
-
-                // push gamma and beta
-                writer << "auto gamma = " << args[0].get_name() << ";\n";
-                writer << "auto beta = " << args[1].get_name() << ";\n";
-                writer << "memcpy(&bn_weights[0], gamma,"
-                       << args[1].get_size() * args[0].get_element_type().size() << ");\n";
-                writer << "memcpy(&bn_weights[0]+" << args[1].get_size() << ", beta, "
+                       << ">bn_weights(2*" << args[0].get_size() << ");\n";
+                writer << "memcpy(&bn_weights[0], " << args[0].get_name() << ", "
+                       << args[0].get_size() * args[0].get_element_type().size() << ");\n";
+                writer << "memcpy(&bn_weights[0]+" << args[0].get_size() << ", "
+                       << args[1].get_name() << ", "
                        << args[1].get_size() * args[1].get_element_type().size() << ");\n";
-                // get the eps value from the bn node
-                writer << "auto epsilon = " << batchnorm->get_eps_value() << ";\n";
-
-                const string& input_format = runtime::cpu::mkldnn_utils::get_mkldnn_format_string(
-                    runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2));
-                const string& result_format = runtime::cpu::mkldnn_utils::get_mkldnn_format_string(
-                    runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
-                // Bind to CPU engine
-                writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
-                // create memory descriptors
-                writer << "memory::desc input_data_desc = memory::desc({" << join(input_shape)
-                       << "}, " << et << ", " << input_format << ");\n";
-                // TODO define weights by stacking gamma and beta values
-                writer << "memory::desc weights_desc = memory::desc({" << join(weights_shape)
-                       << "}, " << et << ", memory::format::nc);\n";
-                writer << "memory::desc result_desc = memory::desc({" << join(result_shape) << "}, "
-                       << et << ", " << result_format << ");\n";
-                writer << "memory::desc mean_desc = memory::desc({" << join(mean_shape) << "}, "
-                       << et << ", memory::format::x);\n";
-                writer << "memory::desc variance_desc = memory::desc({" << join(variance_shape)
-                       << "}, " << et << ", memory::format::x);\n";
-
-                // Define memory for the user data
-                writer << "memory input_data = memory({input_data_desc, cpu_engine}, "
+
+                auto input_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
+                auto result_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
+                auto mean_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 1);
+                auto variance_format =
+                    runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 2);
+
+                auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
+                auto weights_shape = Shape{2, args[0].get_size()};
+                auto input_desc = mkldnn_emitter->build_memory_descriptor(args[2], input_format);
+                auto weights_desc = mkldnn_emitter->build_memory_descriptor(
+                    weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
+                auto results_desc = mkldnn_emitter->build_memory_descriptor(out[0], result_format);
+                auto mean_desc = mkldnn_emitter->build_memory_descriptor(out[1], mean_format);
+                auto variance_desc =
+                    mkldnn_emitter->build_memory_descriptor(out[2], variance_format);
+
+                auto batchnorm_index =
+                    mkldnn_emitter->build_batchnorm_forward(input_desc,
+                                                            weights_desc,
+                                                            results_desc,
+                                                            mean_desc,
+                                                            variance_desc,
+                                                            batchnorm->get_eps_value());
+
+                auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
+                writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", "
                        << args[2].get_name() << ");\n";
-                writer << "memory weights = memory({weights_desc, cpu_engine}, bn_weights.data()"
-                       << ");\n";
-                writer << "memory result = memory({result_desc, cpu_engine}, " << out[0].get_name()
-                       << ");\n";
-                writer << "memory mean = memory({mean_desc, cpu_engine}, " << out[1].get_name()
-                       << ");\n";
-                writer << "memory variance = memory({variance_desc, cpu_engine}, "
-                       << out[2].get_name() << ");\n";
-
-                // create batchnorm descriptor
-                writer << "batch_normalization_forward::desc bn_fprop_desc = "
-                          "batch_normalization_forward::desc(forward_training,"
-                       << "input_data_desc, epsilon, use_scale_shift);\n";
-                // bn fprop primitive descriptor
-                writer
-                    << "batch_normalization_forward::primitive_desc bn_fprop_prim_desc = "
-                       "batch_normalization_forward::primitive_desc(bn_fprop_desc, cpu_engine);\n";
-                // create a batchnorm fprop primitive
-                writer << "batch_normalization_forward bn_fprop = "
-                          "batch_normalization_forward(bn_fprop_prim_desc, "
-                          "primitive::at(input_data),"
-                       << "primitive::at(weights), result, mean, variance); \n";
-                // create stream and execute
-                writer << "stream s = stream(stream::kind::eager);\n"
-                       << "s.submit({bn_fprop}).wait();\n";
+                writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
+                       << ", bn_weights.data());\n";
+                writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", "
+                       << out[0].get_name() << ");\n";
+                writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[3]) << ", "
+                       << out[1].get_name() << ");\n";
+                writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[4]) << ", "
+                       << out[2].get_name() << ");\n";
+                writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
+                       << to_string(batchnorm_index) << ");\n";
                 writer.indent--;
                 writer << "}\n";
             }
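The forward emitter packs gamma and beta into a single `{2, C}` buffer in `nc` format because MKLDNN's `use_scale_shift` flag expects scale and shift stacked in one weights tensor; the two `memcpy`s in the generated code do exactly this (and the new code also corrects the copy sizes to use `args[0]` for gamma, where the old code used `args[1]`). A minimal sketch of the packing, assuming float tensors — `pack_bn_weights` is an illustrative helper, not code from this commit:

```cpp
#include <cstddef>
#include <cstring>
#include <vector>

// Pack gamma (scale) and beta (shift) for C channels into one {2, C} buffer,
// mirroring the memcpy pair the generated code performs on bn_weights.
std::vector<float> pack_bn_weights(const float* gamma, const float* beta, std::size_t c)
{
    std::vector<float> bn_weights(2 * c);
    std::memcpy(bn_weights.data(), gamma, c * sizeof(float));    // row 0: gamma
    std::memcpy(bn_weights.data() + c, beta, c * sizeof(float)); // row 1: beta
    return bn_weights;
}
```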
@@ -488,108 +488,74 @@ namespace ngraph
             {
                 const ngraph::op::BatchNormBackprop* batchnorm =
                     static_cast<const ngraph::op::BatchNormBackprop*>(node);
-                auto gamma_shape = args[0].get_shape();
-                auto beta_shape = args[1].get_shape();
-                auto input_shape = args[2].get_shape();
-                auto mean_shape = args[3].get_shape();
-                auto variance_shape = args[4].get_shape();
-                auto delta_shape = args[5].get_shape();
-                auto result_shape = out[0].get_shape();
-                // get input element type
-                const string& et =
-                    mkldnn_utils::get_mkldnn_data_type_string(args[2].get_element_type());
-                writer << "{\n";
                 writer.indent++;
-                writer << "{\n";
                 // define weights
                 writer << "std::vector<" << args[0].get_element_type().c_type_string()
-                       << ">bn_weights(" << input_shape[1] * 2 << ");\n";
+                       << ">bn_weights(2*" << args[0].get_size() << ");\n";
                 writer << "std::vector<" << args[0].get_element_type().c_type_string()
-                       << ">vdiff_weights(" << input_shape[1] * 2 << ");\n";
-                auto weights_shape = Shape{2, input_shape[1]};
-
-                // push gamma and beta
-                writer << "auto gamma = " << args[0].get_name() << ";\n";
-                writer << "auto beta = " << args[1].get_name() << ";\n";
-                writer << "memcpy(&bn_weights[0], gamma,"
-                       << args[1].get_size() * args[0].get_element_type().size() << ");\n";
-                writer << "memcpy(&bn_weights[0]+" << args[1].get_size() << ", beta, "
+                       << ">bn_dweights(2*" << args[0].get_size() << ");\n";
+
+                writer << "memcpy(&bn_weights[0], " << args[0].get_name() << ", "
+                       << args[0].get_size() * args[0].get_element_type().size() << ");\n";
+                writer << "memcpy(&bn_weights[0]+" << args[0].get_size() << ", "
+                       << args[1].get_name() << ", "
                        << args[1].get_size() * args[1].get_element_type().size() << ");\n";
-                // get the eps value from the bn node
-                writer << "auto epsilon = " << batchnorm->get_eps_value() << ";\n";
-                // Bind to CPU engine
-                writer << "using namespace mkldnn; \n";
-                writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
-                // create memory descriptors
-                writer << "memory::desc input_data_desc = memory::desc({" << join(input_shape)
-                       << "}, " << et << ", memory::format::nchw);\n";
-                // TODO define weights by stacking gamma and beta values
-                writer << "memory::desc weights_desc = memory::desc({" << join(weights_shape)
-                       << "}, " << et << ", memory::format::nc);\n";
-                writer << "memory::desc diff_weights_desc = memory::desc({" << join(weights_shape)
-                       << "}, " << et << ", memory::format::nc);\n";
-                writer << "memory::desc result_desc = memory::desc({" << join(result_shape) << "}, "
-                       << et << ", memory::format::nchw);\n";
-                writer << "memory::desc mean_desc = memory::desc({" << join(mean_shape) << "}, "
-                       << et << ", memory::format::x);\n";
-                writer << "memory::desc variance_desc = memory::desc({" << join(variance_shape)
-                       << "}, " << et << ", memory::format::x);\n";
-                writer << "memory::desc delta_desc = memory::desc({" << join(input_shape) << "}, "
-                       << et << ", memory::format::nchw);\n";
-
-                // Define memory for the user data
-                writer << "memory input_data = memory({input_data_desc, cpu_engine}, "
+
+                auto input_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
+                auto mean_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 3);
+                auto variance_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 4);
+                auto delta_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 5);
+                auto dinput_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
+
+                auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
+                auto weights_shape = Shape{2, args[0].get_size()};
+                auto weights_desc = mkldnn_emitter->build_memory_descriptor(
+                    weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
+                auto input_desc = mkldnn_emitter->build_memory_descriptor(args[2], input_format);
+                auto mean_desc = mkldnn_emitter->build_memory_descriptor(args[3], mean_format);
+                auto variance_desc =
+                    mkldnn_emitter->build_memory_descriptor(args[4], variance_format);
+                auto delta_desc = mkldnn_emitter->build_memory_descriptor(args[5], delta_format);
+                auto dinput_desc = mkldnn_emitter->build_memory_descriptor(out[0], dinput_format);
+                auto dweights_desc = mkldnn_emitter->build_memory_descriptor(
+                    weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
+
+                auto batchnorm_index =
+                    mkldnn_emitter->build_batchnorm_backward(weights_desc,
+                                                             input_desc,
+                                                             mean_desc,
+                                                             variance_desc,
+                                                             delta_desc,
+                                                             dinput_desc,
+                                                             dweights_desc,
+                                                             batchnorm->get_eps_value());
+
+                auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
+                writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
+                       << ", bn_weights.data());\n";
+                writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1]) << ", "
                        << args[2].get_name() << ");\n";
-                writer << "memory weights = memory({weights_desc, cpu_engine}, bn_weights.data()"
-                       << ");\n";
-                writer << "memory diff_weights = memory({diff_weights_desc, cpu_engine}, "
-                          "vdiff_weights.data()"
-                       << ");\n";
-                writer << "memory mean = memory({mean_desc, cpu_engine}, " << args[3].get_name()
-                       << ");\n";
-                writer << "memory variance = memory({variance_desc, cpu_engine}, "
+                writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", "
+                       << args[3].get_name() << ");\n";
+                writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[3]) << ", "
                        << args[4].get_name() << ");\n";
-                writer << "memory delta = memory({delta_desc, cpu_engine}, " << args[5].get_name()
-                       << ");\n";
-                writer << "memory result = memory({result_desc, cpu_engine}, " << out[0].get_name()
-                       << ");\n";
-
-                //create fprop batchnorm descriptor
-                writer << "batch_normalization_forward::desc bn_fprop_desc = "
-                          "batch_normalization_forward::desc(forward_training,"
-                       << "input_data_desc, epsilon, use_scale_shift);\n";
-                //bn fprop primitive descriptor
-                writer
-                    << "batch_normalization_forward::primitive_desc bn_fprop_prim_desc = "
-                       "batch_normalization_forward::primitive_desc(bn_fprop_desc, cpu_engine);\n";
-                //create bprop batchnorm descriptor
-                writer << "batch_normalization_backward::desc bn_bprop_desc = "
-                          "batch_normalization_backward::desc(backward, delta_desc, "
-                          "input_data_desc, epsilon, use_scale_shift);\n";
-                //bn bprop primitive descriptor
-                writer << "batch_normalization_backward::primitive_desc bn_bprop_prim_desc = "
-                          "batch_normalization_backward::primitive_desc(bn_bprop_desc, cpu_engine, "
-                          "bn_fprop_prim_desc);\n";
-                //create a batchnorm bprop primitive
-                writer << "batch_normalization_backward bn_bprop = "
-                          "batch_normalization_backward(bn_bprop_prim_desc, input_data, mean, "
-                          "variance, delta, weights, result, diff_weights);\n";
-                //create stream and execute
-                writer << "stream s = stream(stream::kind::eager);\n"
-                       << "s.submit({bn_bprop}).wait();\n";
-                writer << "memcpy(" << out[1].get_name() << ",&vdiff_weights[0],"
-                       << args[1].get_size() * args[0].get_element_type().size() << ");\n";
-                writer << "memcpy(" << out[2].get_name() << ",&vdiff_weights[0] + "
-                       << args[1].get_size() << ","
-                       << args[1].get_size() * args[1].get_element_type().size() << ");\n";
+                writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[4]) << ", "
+                       << args[5].get_name() << ");\n";
+                writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[5]) << ", "
+                       << out[0].get_name() << ");\n";
+                writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[6])
+                       << ", bn_dweights.data());\n";
+                writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
+                       << to_string(batchnorm_index) << ");\n";
+
+                writer << "memcpy(" << out[1].get_name() << ", &bn_dweights[0], "
+                       << args[0].get_size() * args[0].get_element_type().size() << ");\n";
+                writer << "memcpy(" << out[2].get_name() << ", &bn_dweights[0]+"
+                       << args[0].get_size() << ", "
+                       << args[1].get_size() * args[1].get_element_type().size() << ");\n";
                 writer.indent--;
                 writer << "}\n";
             }
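On the backward path the same packing is used for the incoming gamma/beta, and MKLDNN writes dgamma and dbeta into a single `{2, C}` diff-weights buffer (`bn_dweights`), which the trailing `memcpy`s split into `out[1]` and `out[2]`. A sketch of that unpacking step, again assuming float tensors — `unpack_bn_dweights` is illustrative only:

```cpp
#include <cstddef>
#include <cstring>
#include <vector>

// Split the {2, C} diff-weights buffer filled by the backward primitive into
// the separate dgamma and dbeta outputs, as the emitted memcpys do.
void unpack_bn_dweights(const std::vector<float>& bn_dweights,
                        float* dgamma,
                        float* dbeta,
                        std::size_t c)
{
    std::memcpy(dgamma, bn_dweights.data(), c * sizeof(float));    // row 0
    std::memcpy(dbeta, bn_dweights.data() + c, c * sizeof(float)); // row 1
}
```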
...
@@ -22,6 +22,7 @@
 #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
 #include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
 #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
+#include "ngraph/types/element_type.hpp"
 
 using namespace ngraph::runtime::cpu;
@@ -58,6 +59,15 @@ mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrap
     return build_memory_descriptor(tvw, layout->get_mkldnn_format());
 }
 
+mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const Shape& shape,
+                                                            const ngraph::element::Type& et,
+                                                            mkldnn::memory::format fmt) const
+{
+    return mkldnn::memory::desc(mkldnn::memory::dims(shape.begin(), shape.end()),
+                                mkldnn_utils::get_mkldnn_data_type(et),
+                                fmt);
+}
+
 mkldnn::memory MKLDNNEmitter::build_memory_primitive(const TensorViewWrapper& tvw) const
 {
     return mkldnn::memory({build_memory_descriptor(tvw), mkldnn_utils::global_cpu_engine}, nullptr);
@@ -327,3 +337,80 @@ size_t MKLDNNEmitter::build_elementwise_add(
     m_primitive_deps[add_index] = {input0_data_index, input1_data_index, result_index};
     return add_index;
 }
+
+size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_desc,
+                                              const mkldnn::memory::desc& weights_desc,
+                                              const mkldnn::memory::desc& result_desc,
+                                              const mkldnn::memory::desc& mean_desc,
+                                              const mkldnn::memory::desc& variance_desc,
+                                              const double eps)
+{
+    size_t input_index = build_memory_primitive(input_desc);
+    size_t weights_index = build_memory_primitive(weights_desc);
+    size_t result_index = build_memory_primitive(result_desc);
+    size_t mean_index = build_memory_primitive(mean_desc);
+    size_t variance_index = build_memory_primitive(variance_desc);
+
+    size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_forward(
+        {{mkldnn::prop_kind::forward_training,
+          input_desc,
+          eps,
+          mkldnn::batch_normalization_flag::use_scale_shift},
+         mkldnn_utils::global_cpu_engine},
+        mkldnn::primitive::at(*m_mkldnn_primitives[input_index]),
+        mkldnn::primitive::at(*m_mkldnn_primitives[weights_index]),
+        static_cast<mkldnn::memory>(*m_mkldnn_primitives[result_index]),
+        *m_mkldnn_primitives[mean_index],
+        *m_mkldnn_primitives[variance_index]));
+
+    m_primitive_deps[batchnorm_index] = {
+        input_index, weights_index, result_index, mean_index, variance_index};
+    return batchnorm_index;
+}
+
+size_t MKLDNNEmitter::build_batchnorm_backward(const mkldnn::memory::desc& weights_desc,
+                                               const mkldnn::memory::desc& input_desc,
+                                               const mkldnn::memory::desc& mean_desc,
+                                               const mkldnn::memory::desc& variance_desc,
+                                               const mkldnn::memory::desc& delta_desc,
+                                               const mkldnn::memory::desc& dinput_desc,
+                                               const mkldnn::memory::desc& dweights_desc,
+                                               const double eps)
+{
+    size_t weights_index = build_memory_primitive(weights_desc);
+    size_t input_index = build_memory_primitive(input_desc);
+    size_t mean_index = build_memory_primitive(mean_desc);
+    size_t variance_index = build_memory_primitive(variance_desc);
+    size_t delta_index = build_memory_primitive(delta_desc);
+    size_t dinput_index = build_memory_primitive(dinput_desc);
+    size_t dweights_index = build_memory_primitive(dweights_desc);
+
+    size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_backward(
+        {{mkldnn::prop_kind::backward,
+          delta_desc,
+          input_desc,
+          eps,
+          mkldnn::batch_normalization_flag::use_scale_shift},
+         mkldnn_utils::global_cpu_engine,
+         {{mkldnn::prop_kind::forward_training,
+           input_desc,
+           eps,
+           mkldnn::batch_normalization_flag::use_scale_shift},
+          mkldnn_utils::global_cpu_engine}},
+        *m_mkldnn_primitives[input_index],
+        *m_mkldnn_primitives[mean_index],
+        *m_mkldnn_primitives[variance_index],
+        *m_mkldnn_primitives[delta_index],
+        *m_mkldnn_primitives[weights_index],
+        *m_mkldnn_primitives[dinput_index],
+        *m_mkldnn_primitives[dweights_index]));
+
+    m_primitive_deps[batchnorm_index] = {weights_index,
+                                         input_index,
+                                         mean_index,
+                                         variance_index,
+                                         delta_index,
+                                         dinput_index,
+                                         dweights_index};
+    return batchnorm_index;
+}
\ No newline at end of file
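Two things are worth noting about `build_batchnorm_backward`. First, in the MKLDNN v0.x API a backward primitive descriptor is constructed with the matching forward primitive descriptor as a hint, which is why the brace-initialized argument above nests a `forward_training` descriptor; this mirrors the deleted inline codegen. Second, the order of indices stored in `m_primitive_deps` is a contract with the emitter: the generated `set_memory_ptr` calls bind tensors by position, so the two lists must stay in sync. A schematic restatement of the hint pattern — `make_bn_bwd_pd` is an illustrative helper, not part of this commit:

```cpp
#include <mkldnn.hpp>

// Schematic of the v0.x API pattern used above: the backward primitive_desc
// is created with the forward primitive_desc as a hint.
mkldnn::batch_normalization_backward::primitive_desc
    make_bn_bwd_pd(const mkldnn::memory::desc& input_desc,
                   const mkldnn::memory::desc& delta_desc,
                   double eps,
                   const mkldnn::engine& engine)
{
    mkldnn::batch_normalization_forward::desc fwd_desc(
        mkldnn::prop_kind::forward_training,
        input_desc,
        eps,
        mkldnn::batch_normalization_flag::use_scale_shift);
    mkldnn::batch_normalization_forward::primitive_desc fwd_pd(fwd_desc, engine);

    mkldnn::batch_normalization_backward::desc bwd_desc(
        mkldnn::prop_kind::backward,
        delta_desc,
        input_desc,
        eps,
        mkldnn::batch_normalization_flag::use_scale_shift);
    return mkldnn::batch_normalization_backward::primitive_desc(bwd_desc, engine, fwd_pd);
}
```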
@@ -25,6 +25,7 @@
 #include "ngraph/coordinate_diff.hpp"
 #include "ngraph/shape.hpp"
 #include "ngraph/strides.hpp"
+#include "ngraph/types/element_type.hpp"
 
 namespace ngraph
 {
@@ -48,6 +49,9 @@ namespace ngraph
             mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw,
                                                          mkldnn::memory::format fmt) const;
             mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw) const;
+            mkldnn::memory::desc build_memory_descriptor(const Shape& shape,
+                                                         const ngraph::element::Type& et,
+                                                         mkldnn::memory::format fmt) const;
             mkldnn::memory build_memory_primitive(const TensorViewWrapper& tvw) const;
             size_t build_memory_primitive(const mkldnn::memory::desc& desc);
@@ -107,6 +111,22 @@ namespace ngraph
                 const std::vector<float>& scale_vector,
                 const std::vector<mkldnn::memory::primitive_desc>& input_pd);
 
+            size_t build_batchnorm_forward(const mkldnn::memory::desc& input_desc,
+                                           const mkldnn::memory::desc& weights_desc,
+                                           const mkldnn::memory::desc& result_desc,
+                                           const mkldnn::memory::desc& mean_desc,
+                                           const mkldnn::memory::desc& variance_desc,
+                                           const double eps);
+
+            size_t build_batchnorm_backward(const mkldnn::memory::desc& weights_desc,
+                                            const mkldnn::memory::desc& input_desc,
+                                            const mkldnn::memory::desc& mean_desc,
+                                            const mkldnn::memory::desc& variance_desc,
+                                            const mkldnn::memory::desc& delta_desc,
+                                            const mkldnn::memory::desc& dinput_desc,
+                                            const mkldnn::memory::desc& dweights_desc,
+                                            const double eps);
+
         private:
             std::vector<mkldnn::primitive*> m_mkldnn_primitives;
             std::vector<mkldnn::stream> m_mkldnn_streams;
...
@@ -250,6 +250,16 @@ namespace ngraph
                     op_annotations->set_mkldnn_op(true);
                     batchnorm->set_op_annotations(op_annotations);
                 }
+
+                template <>
+                void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNormBackprop)
+                {
+                    auto batchnorm = static_cast<op::BatchNormBackprop*>(node);
+                    auto op_annotations =
+                        std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
+                    op_annotations->set_mkldnn_op(true);
+                    batchnorm->set_op_annotations(op_annotations);
+                }
             }
         }
     }
@@ -260,6 +270,8 @@ namespace ngraph
 static const runtime::cpu::pass::AssignOpMap s_dispatcher{
     {TI(ngraph::op::Add), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Add>},
     {TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNorm>},
+    {TI(ngraph::op::BatchNormBackprop),
+     &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNormBackprop>},
     {TI(ngraph::op::Convolution),
      &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>},
     {TI(ngraph::op::ConvolutionBackpropData),
...
@@ -737,6 +737,36 @@ namespace ngraph
                 }
             }
 
+            template <>
+            void CPULayout::LAYOUT_DECL(ngraph::op::BatchNormBackprop)
+            {
+                if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
+                {
+                    auto delta_layout =
+                        runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 5);
+
+                    vector<memory::format> prim_input_formats;
+                    vector<memory::format> prim_output_formats;
+                    prim_input_formats.push_back(memory::format::x);  // gamma
+                    prim_input_formats.push_back(memory::format::x);  // beta
+                    prim_input_formats.push_back(delta_layout);       // input
+                    prim_input_formats.push_back(memory::format::x);  // mean
+                    prim_input_formats.push_back(memory::format::x);  // variance
+                    prim_input_formats.push_back(delta_layout);       // delta
+                    prim_output_formats.push_back(delta_layout);      // dinput
+                    prim_output_formats.push_back(memory::format::x); // dgamma
+                    prim_output_formats.push_back(memory::format::x); // dbeta
+
+                    node =
+                        insert_input_conversions(external_function, node, prim_input_formats);
+                    set_output_layouts(node, prim_output_formats);
+                }
+                else
+                {
+                    throw ngraph_error("Batchnorm Backprop only supported in MKLDNN for now");
+                }
+            }
+
             template <>
             void CPULayout::LAYOUT_DECL(ngraph::op::Add)
             {
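The layout pass above keys everything off the delta input (input 5): input and delta must share delta's preferred format, dinput is produced in that same format, and the 1-D tensors stay in `memory::format::x`; `insert_input_conversions` then adds reorders for any producer whose layout disagrees. In summary, the assignment is:

    args[0] gamma, args[1] beta, args[3] mean, args[4] variance  -> memory::format::x
    args[2] input, args[5] delta                                 -> delta's layout
    out[0] dinput -> delta's layout; out[1] dgamma, out[2] dbeta -> memory::format::x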
@@ -777,6 +807,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
     {TI(ngraph::op::AvgPoolBackprop),
      &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>},
     {TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNorm>},
+    {TI(ngraph::op::BatchNormBackprop),
+     &runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNormBackprop>},
     {TI(ngraph::op::GetOutputElement),
      &runtime::cpu::pass::CPULayout::layout<ngraph::op::GetOutputElement>},
     {TI(ngraph::op::Relu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Relu>},
...