Unverified Commit 6667bc0e authored by Fenglei's avatar Fenglei Committed by GitHub

Merge branch 'master' into tfl/gpu_dot_back

parents eec717c0 34a8b27d
......@@ -121,13 +121,6 @@ static string eigen_matrix_format(const ngraph::Shape& shape, const ngraph::Stri
return ss.str();
}
void runtime::cpu::CPU_Emitter::emit_mkldnn_preamble(codegen::CodeWriter& writer)
{
writer << "// MKLDNN Preamble\n";
writer << "#include <mkldnn.hpp>\n";
writer << "using namespace mkldnn;\n\n";
}
namespace ngraph
{
namespace runtime
......@@ -380,99 +373,55 @@ namespace ngraph
const ngraph::op::BatchNorm* batchnorm =
static_cast<const ngraph::op::BatchNorm*>(node);
// get the shape of all the inputs and output to batchnorm
auto gamma_shape = args[0].get_shape();
auto beta_shape = args[1].get_shape();
auto input_shape = args[2].get_shape();
auto result_shape = out[0].get_shape();
auto mean_shape = out[1].get_shape();
auto variance_shape = out[2].get_shape();
// get input element type
const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
args[2].get_element_type());
const string& gamma_format = runtime::cpu::mkldnn_utils::get_mkldnn_format_string(
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
const string& beta_format = runtime::cpu::mkldnn_utils::get_mkldnn_format_string(
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1));
if (gamma_format.compare("memory::format::x") != 0 &&
beta_format.compare("memory::format::x") != 0)
{
throw std::runtime_error(
"gamma layout->" + gamma_format + ", beta layout->" + beta_format +
" should match and both should have memory::format::x format");
}
writer << "{\n";
writer.indent++;
writer << "{\n";
// define weights
writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">bn_weights(2*" << input_shape[1] << ");\n";
auto weights_shape = Shape{2, input_shape[1]};
<< ">bn_weights(2*" << args[0].get_size() << ");\n";
writer << "memcpy(&bn_weights[0], " << args[0].get_name() << ", "
<< args[0].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(&bn_weights[0]+" << args[0].get_size() << ", "
<< args[1].get_name() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
// push gamma and beta
writer << "auto gamma = " << args[0].get_name() << ";\n";
writer << "auto beta = " << args[1].get_name() << ";\n";
auto input_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
auto result_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto mean_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 1);
auto variance_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 2);
writer << "memcpy(&bn_weights[0], gamma,"
<< args[1].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(&bn_weights[0]+" << args[1].get_size() << ", beta, "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto weights_shape = Shape{2, args[0].get_size()};
auto input_desc = mkldnn_emitter->build_memory_descriptor(args[2], input_format);
auto weights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
auto results_desc = mkldnn_emitter->build_memory_descriptor(out[0], result_format);
auto mean_desc = mkldnn_emitter->build_memory_descriptor(out[1], mean_format);
auto variance_desc =
mkldnn_emitter->build_memory_descriptor(out[2], variance_format);
auto batchnorm_index =
mkldnn_emitter->build_batchnorm_forward(input_desc,
weights_desc,
results_desc,
mean_desc,
variance_desc,
batchnorm->get_eps_value());
// get the eps value from the bn node
writer << "auto epsilon = " << batchnorm->get_eps_value() << ";\n";
const string& input_format = runtime::cpu::mkldnn_utils::get_mkldnn_format_string(
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2));
const string& result_format = runtime::cpu::mkldnn_utils::get_mkldnn_format_string(
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
// Bind to CPU engine
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
// create memory descriptors
writer << "memory::desc input_data_desc = memory::desc({" << join(input_shape)
<< "}, " << et << ", " << input_format << ");\n";
// TODO define weights by stacking gamma and beta values
writer << "memory::desc weights_desc = memory::desc({" << join(weights_shape)
<< "}, " << et << ", memory::format::nc);\n";
writer << "memory::desc result_desc = memory::desc({" << join(result_shape) << "}, "
<< et << ", " << result_format << ");\n";
writer << "memory::desc mean_desc = memory::desc({" << join(mean_shape) << "}, "
<< et << ", memory::format::x);\n";
writer << "memory::desc variance_desc = memory::desc({" << join(variance_shape)
<< "}, " << et << ", memory::format::x);\n";
// Define memory for the user data
writer << "memory input_data = memory({input_data_desc, cpu_engine}, "
auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", "
<< args[2].get_name() << ");\n";
writer << "memory weights = memory({weights_desc, cpu_engine}, bn_weights.data()"
<< ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, " << out[0].get_name()
<< ");\n";
writer << "memory mean = memory({mean_desc, cpu_engine}, " << out[1].get_name()
<< ");\n";
writer << "memory variance = memory({variance_desc, cpu_engine}, "
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", bn_weights.data());\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", "
<< out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[3]) << ", "
<< out[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[4]) << ", "
<< out[2].get_name() << ");\n";
// create batchnorm descriptor
writer << "batch_normalization_forward::desc bn_fprop_desc = "
"batch_normalization_forward::desc(forward_training,"
<< "input_data_desc, epsilon, use_scale_shift);\n";
// bn fprop primitive descriptor
writer
<< "batch_normalization_forward::primitive_desc bn_fprop_prim_desc = "
"batch_normalization_forward::primitive_desc(bn_fprop_desc, cpu_engine);\n";
// create a batchnorm fprop primitive
writer << "batch_normalization_forward bn_fprop = "
"batch_normalization_forward(bn_fprop_prim_desc, "
"primitive::at(input_data),"
<< "primitive::at(weights), result, mean, variance); \n";
// create stream and execute
writer << "stream s = stream(stream::kind::eager);\n"
<< "s.submit({bn_fprop}).wait();\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(batchnorm_index) << ");\n";
writer.indent--;
writer << "}\n";
}
......@@ -482,108 +431,74 @@ namespace ngraph
{
const ngraph::op::BatchNormBackprop* batchnorm =
static_cast<const ngraph::op::BatchNormBackprop*>(node);
auto gamma_shape = args[0].get_shape();
auto beta_shape = args[1].get_shape();
auto input_shape = args[2].get_shape();
auto mean_shape = args[3].get_shape();
auto variance_shape = args[4].get_shape();
auto delta_shape = args[5].get_shape();
auto result_shape = out[0].get_shape();
// get input element type
const string& et =
mkldnn_utils::get_mkldnn_data_type_string(args[2].get_element_type());
writer << "{\n";
writer.indent++;
writer << "{\n";
// define weights
writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">bn_weights(" << input_shape[1] * 2 << ");\n";
<< ">bn_weights(2*" << args[0].get_size() << ");\n";
writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">vdiff_weights(" << input_shape[1] * 2 << ");\n";
auto weights_shape = Shape{2, input_shape[1]};
// push gamma and beta
writer << "auto gamma = " << args[0].get_name() << ";\n";
writer << "auto beta = " << args[1].get_name() << ";\n";
<< ">bn_dweights(2*" << args[0].get_size() << ");\n";
writer << "memcpy(&bn_weights[0], gamma,"
<< args[1].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(&bn_weights[0]+" << args[1].get_size() << ", beta, "
writer << "memcpy(&bn_weights[0], " << args[0].get_name() << ", "
<< args[0].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(&bn_weights[0]+" << args[0].get_size() << ", "
<< args[1].get_name() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
// get the eps value from the bn node
writer << "auto epsilon = " << batchnorm->get_eps_value() << ";\n";
// Bind to CPU engine
writer << "using namespace mkldnn; \n";
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
// create memory descriptors
writer << "memory::desc input_data_desc = memory::desc({" << join(input_shape)
<< "}, " << et << ", memory::format::nchw);\n";
// TODO define weights by stacking gamma and beta values
writer << "memory::desc weights_desc = memory::desc({" << join(weights_shape)
<< "}, " << et << ", memory::format::nc);\n";
writer << "memory::desc diff_weights_desc = memory::desc({" << join(weights_shape)
<< "}, " << et << ", memory::format::nc);\n";
writer << "memory::desc result_desc = memory::desc({" << join(result_shape) << "}, "
<< et << ", memory::format::nchw);\n";
writer << "memory::desc mean_desc = memory::desc({" << join(mean_shape) << "}, "
<< et << ", memory::format::x);\n";
writer << "memory::desc variance_desc = memory::desc({" << join(variance_shape)
<< "}, " << et << ", memory::format::x);\n";
writer << "memory::desc delta_desc = memory::desc({" << join(input_shape) << "}, "
<< et << ", memory::format::nchw);\n";
// Define memory for the user data
writer << "memory input_data = memory({input_data_desc, cpu_engine}, "
auto input_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
auto mean_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 3);
auto variance_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 4);
auto delta_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 5);
auto dinput_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto weights_shape = Shape{2, args[0].get_size()};
auto weights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
auto input_desc = mkldnn_emitter->build_memory_descriptor(args[2], input_format);
auto mean_desc = mkldnn_emitter->build_memory_descriptor(args[3], mean_format);
auto variance_desc =
mkldnn_emitter->build_memory_descriptor(args[4], variance_format);
auto delta_desc = mkldnn_emitter->build_memory_descriptor(args[5], delta_format);
auto dinput_desc = mkldnn_emitter->build_memory_descriptor(out[0], dinput_format);
auto dweights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
auto batchnorm_index =
mkldnn_emitter->build_batchnorm_backward(weights_desc,
input_desc,
mean_desc,
variance_desc,
delta_desc,
dinput_desc,
dweights_desc,
batchnorm->get_eps_value());
auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", bn_weights.data());\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1]) << ", "
<< args[2].get_name() << ");\n";
writer << "memory weights = memory({weights_desc, cpu_engine}, bn_weights.data()"
<< ");\n";
writer << "memory diff_weights = memory({diff_weights_desc, cpu_engine}, "
"vdiff_weights.data()"
<< ");\n";
writer << "memory mean = memory({mean_desc, cpu_engine}, " << args[3].get_name()
<< ");\n";
writer << "memory variance = memory({variance_desc, cpu_engine}, "
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", "
<< args[3].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[3]) << ", "
<< args[4].get_name() << ");\n";
writer << "memory delta = memory({delta_desc, cpu_engine}, " << args[5].get_name()
<< ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, " << out[0].get_name()
<< ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[4]) << ", "
<< args[5].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[5]) << ", "
<< out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[6])
<< ", bn_dweights.data());\n";
//create fprop batchnorm descriptor
writer << "batch_normalization_forward::desc bn_fprop_desc = "
"batch_normalization_forward::desc(forward_training,"
<< "input_data_desc, epsilon, use_scale_shift);\n";
//bn fprop primitive descriptor
writer
<< "batch_normalization_forward::primitive_desc bn_fprop_prim_desc = "
"batch_normalization_forward::primitive_desc(bn_fprop_desc, cpu_engine);\n";
//create bprop batchnorm descriptor
writer << "batch_normalization_backward::desc bn_bprop_desc = "
"batch_normalization_backward::desc(backward, delta_desc, "
"input_data_desc, epsilon, use_scale_shift);\n";
//bn bprop primitive descriptor
writer << "batch_normalization_backward::primitive_desc bn_bprop_prim_desc = "
"batch_normalization_backward::primitive_desc(bn_bprop_desc, cpu_engine, "
"bn_fprop_prim_desc);\n";
//create a batchnorm fprop primitive
writer << " batch_normalization_backward bn_bprop = "
"batch_normalization_backward(bn_bprop_prim_desc, input_data, mean, "
"variance, delta, weights, result, diff_weights);\n ";
//create stream and execute
writer << "stream s = stream(stream::kind::eager);\n"
<< "s.submit({bn_bprop}).wait();\n";
writer << "memcpy(" << out[1].get_name() << ",&vdiff_weights[0],"
<< args[1].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(" << out[2].get_name() << ",&vdiff_weights[0] + "
<< args[1].get_size() << ","
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(batchnorm_index) << ");\n";
writer << "memcpy(" << out[1].get_name() << ", &bn_dweights[0], "
<< args[0].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(" << out[2].get_name() << ", &bn_dweights[0]+"
<< args[0].get_size() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
writer.indent--;
writer << "}\n";
}
......@@ -3265,78 +3180,29 @@ namespace ngraph
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::ReluBackprop)
{
const auto& arg_shape = args[0].get_shape();
const auto& result_shape = out[0].get_shape();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
args[0].get_element_type());
auto input_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto delta_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1);
if (!runtime::cpu::mkldnn_utils::compare_mkldnn_formats(input_format,
delta_format))
{
throw ngraph_error(
"mkldnn emitter: Relu backprop fprop input and delta layouts should be "
"the same");
}
auto result_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_memory_descriptor(
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
auto delta_desc = mkldnn_emitter->build_memory_descriptor(
args[1], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1));
auto result_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
writer << "{\n";
writer.indent++;
size_t relu_index =
mkldnn_emitter->build_relu_backward(input_desc, delta_desc, result_desc);
writer << "try {\n";
writer.indent++;
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
writer << "memory::desc input_data_desc = memory::desc({" << join(arg_shape)
<< "}, " << et << ", "
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(input_format)
<< ");\n";
writer << "memory::desc delta_data_desc = memory::desc({"
<< join(args[1].get_shape()) << "}, " << et << ", "
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(delta_format)
<< ");\n";
writer << "memory::desc result_desc = memory::desc({" << join(result_shape)
<< "}, " << et << ", "
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(result_format)
<< ");\n";
auto& deps = mkldnn_emitter->get_primitive_deps(relu_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << args[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
<< ", " << out[0].get_name() << ");\n";
writer << "memory input_data = memory({input_data_desc, cpu_engine}, "
<< args[0].get_name() << ");\n";
writer << "memory delta_data = memory({delta_data_desc, cpu_engine}, "
<< args[1].get_name() << ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, "
<< out[0].get_name() << ");\n";
writer << "relu_forward::desc relu_fwd_desc = "
"relu_forward::desc(prop_kind::forward, "
"algorithm::eltwise_relu, input_data_desc, 0, 0);\n";
writer << "relu_forward::primitive_desc relu_fwd_prim_desc = "
"relu_forward::primitive_desc(relu_fwd_desc, cpu_engine);\n";
writer << "relu_backward::desc relu_bwd_desc = "
"relu_backward::desc(algorithm::eltwise_relu, "
"delta_data_desc, input_data_desc, 0, 0);\n";
writer << "relu_backward::primitive_desc relu_bdw_prim_desc = "
"relu_backward::primitive_desc(relu_bwd_desc, cpu_engine, "
"relu_fwd_prim_desc);\n";
writer
<< "relu_backward relu_bwd= relu_backward(relu_bdw_prim_desc, input_data, "
"delta_data, result);\n";
writer << "stream s = stream(stream::kind::eager);\n"
"s.submit({relu_bwd}).wait();\n";
writer.indent--;
writer << "} catch (const mkldnn::error& e) {\n";
writer.indent++;
writer << "throw ngraph::ngraph_error(\"MKLDNN ERROR (\" + std::to_string("
"e.status) + \"): \" + e.message);\n";
writer.indent--;
writer << "}\n";
writer.indent--;
writer << "}\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(relu_index) << ");\n";
}
else
{
......
......@@ -58,8 +58,6 @@ namespace ngraph
{
}
static void emit_mkldnn_preamble(codegen::CodeWriter& writer);
private:
static std::string emit_vector(const TensorViewWrapper&,
const std::string& name = "");
......
......@@ -293,18 +293,6 @@ void runtime::cpu::CPU_ExternalFunction::compile()
codegen::CodeWriter writer;
bool include_mkldnn_headers = false;
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{
for (shared_ptr<Node> node : current_function->get_ordered_ops())
{
if (ngraph::runtime::cpu::mkldnn_utils::IsMKLDNNOp(*node))
{
include_mkldnn_headers = true;
}
}
}
writer +=
R"(// Generated by the nGraph CPU backend
#include <cmath>
......@@ -354,11 +342,6 @@ using namespace ngraph::runtime;
writer << "#include <tbb/flow_graph.h>\n";
}
if (include_mkldnn_headers)
{
runtime::cpu::CPU_Emitter::emit_mkldnn_preamble(writer);
}
string pch_header_source = writer.get_code();
// The "dso_handle" symbol is required by __cxa_atexit()
......
......@@ -23,6 +23,7 @@
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/types/element_type.hpp"
using namespace ngraph::runtime::cpu;
......@@ -77,6 +78,15 @@ mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrap
return build_memory_descriptor(tvw, layout->get_mkldnn_format());
}
mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const Shape& shape,
const ngraph::element::Type& et,
mkldnn::memory::format fmt) const
{
return mkldnn::memory::desc(mkldnn::memory::dims(shape.begin(), shape.end()),
mkldnn_utils::get_mkldnn_data_type(et),
fmt);
}
mkldnn::memory MKLDNNEmitter::build_memory_primitive(const TensorViewWrapper& tvw) const
{
return mkldnn::memory({build_memory_descriptor(tvw), mkldnn_utils::global_cpu_engine}, nullptr);
......@@ -462,6 +472,27 @@ size_t MKLDNNEmitter::build_relu_forward(const mkldnn::memory::desc& input_desc,
return primitive_index;
}
size_t MKLDNNEmitter::build_relu_backward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& delta_desc,
const mkldnn::memory::desc& result_desc)
{
size_t input_index = build_memory_primitive(input_desc);
size_t delta_index = build_memory_primitive(delta_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t primitive_index = insert_primitive(new mkldnn::relu_backward(
{{mkldnn::algorithm::eltwise_relu, delta_desc, input_desc, 0, 0},
mkldnn_utils::global_cpu_engine,
{{mkldnn::prop_kind::forward, mkldnn::algorithm::eltwise_relu, input_desc, 0, 0},
mkldnn_utils::global_cpu_engine}},
*m_mkldnn_primitives[input_index],
*m_mkldnn_primitives[delta_index],
*m_mkldnn_primitives[result_index]));
m_primitive_deps[primitive_index] = {input_index, delta_index, result_index};
return primitive_index;
}
size_t MKLDNNEmitter::build_sigmoid_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc)
{
......@@ -509,3 +540,80 @@ size_t MKLDNNEmitter::build_elementwise_add(
m_primitive_deps[add_index] = {input0_data_index, input1_data_index, result_index};
return add_index;
}
size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
const mkldnn::memory::desc& mean_desc,
const mkldnn::memory::desc& variance_desc,
const double eps)
{
size_t input_index = build_memory_primitive(input_desc);
size_t weights_index = build_memory_primitive(weights_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t mean_index = build_memory_primitive(mean_desc);
size_t variance_index = build_memory_primitive(variance_desc);
size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_forward(
{{mkldnn::prop_kind::forward_training,
input_desc,
eps,
mkldnn::batch_normalization_flag::use_scale_shift},
mkldnn_utils::global_cpu_engine},
mkldnn::primitive::at(*m_mkldnn_primitives[input_index]),
mkldnn::primitive::at(*m_mkldnn_primitives[weights_index]),
static_cast<mkldnn::memory>(*m_mkldnn_primitives[result_index]),
*m_mkldnn_primitives[mean_index],
*m_mkldnn_primitives[variance_index]));
m_primitive_deps[batchnorm_index] = {
input_index, weights_index, result_index, mean_index, variance_index};
return batchnorm_index;
}
size_t MKLDNNEmitter::build_batchnorm_backward(const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& mean_desc,
const mkldnn::memory::desc& variance_desc,
const mkldnn::memory::desc& delta_desc,
const mkldnn::memory::desc& dinput_desc,
const mkldnn::memory::desc& dweights_desc,
const double eps)
{
size_t weights_index = build_memory_primitive(weights_desc);
size_t input_index = build_memory_primitive(input_desc);
size_t mean_index = build_memory_primitive(mean_desc);
size_t variance_index = build_memory_primitive(variance_desc);
size_t delta_index = build_memory_primitive(delta_desc);
size_t dinput_index = build_memory_primitive(dinput_desc);
size_t dweights_index = build_memory_primitive(dweights_desc);
size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_backward(
{{mkldnn::prop_kind::backward,
delta_desc,
input_desc,
eps,
mkldnn::batch_normalization_flag::use_scale_shift},
mkldnn_utils::global_cpu_engine,
{{mkldnn::prop_kind::forward_training,
input_desc,
eps,
mkldnn::batch_normalization_flag::use_scale_shift},
mkldnn_utils::global_cpu_engine}},
*m_mkldnn_primitives[input_index],
*m_mkldnn_primitives[mean_index],
*m_mkldnn_primitives[variance_index],
*m_mkldnn_primitives[delta_index],
*m_mkldnn_primitives[weights_index],
*m_mkldnn_primitives[dinput_index],
*m_mkldnn_primitives[dweights_index]));
m_primitive_deps[batchnorm_index] = {weights_index,
input_index,
mean_index,
variance_index,
delta_index,
dinput_index,
dweights_index};
return batchnorm_index;
}
......@@ -25,6 +25,7 @@
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/types/element_type.hpp"
namespace ngraph
{
......@@ -60,6 +61,9 @@ namespace ngraph
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw,
mkldnn::memory::format fmt) const;
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw) const;
mkldnn::memory::desc build_memory_descriptor(const Shape& shape,
const ngraph::element::Type& et,
mkldnn::memory::format fmt) const;
mkldnn::memory build_memory_primitive(const TensorViewWrapper& tvw) const;
size_t build_memory_primitive(const mkldnn::memory::desc& desc);
......@@ -142,6 +146,10 @@ namespace ngraph
size_t build_relu_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
size_t build_relu_backward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& delta_desc,
const mkldnn::memory::desc& result_desc);
size_t build_sigmoid_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
......@@ -152,6 +160,22 @@ namespace ngraph
const std::vector<float>& scale_vector,
const std::vector<mkldnn::memory::primitive_desc>& input_pd);
size_t build_batchnorm_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
const mkldnn::memory::desc& mean_desc,
const mkldnn::memory::desc& variance_desc,
const double eps);
size_t build_batchnorm_backward(const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& mean_desc,
const mkldnn::memory::desc& variance_desc,
const mkldnn::memory::desc& delta_desc,
const mkldnn::memory::desc& dinput_desc,
const mkldnn::memory::desc& dweights_desc,
const double eps);
private:
std::vector<mkldnn::primitive*> m_mkldnn_primitives;
std::vector<mkldnn::stream> m_mkldnn_streams;
......
......@@ -344,6 +344,16 @@ namespace ngraph
op_annotations->set_mkldnn_op(true);
batchnorm->set_op_annotations(op_annotations);
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNormBackprop)
{
auto batchnorm = static_cast<op::BatchNormBackprop*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
batchnorm->set_op_annotations(op_annotations);
}
}
}
}
......@@ -357,6 +367,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNorm>},
{TI(ngraph::op::BatchNormBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNormBackprop>},
{TI(ngraph::op::Convolution),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionBackpropData),
......
......@@ -1009,6 +1009,36 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::BatchNormBackprop)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto delta_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 5);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(memory::format::x); // gamma
prim_input_formats.push_back(memory::format::x); // beta
prim_input_formats.push_back(delta_layout); // input
prim_input_formats.push_back(memory::format::x); // mean
prim_input_formats.push_back(memory::format::x); // variance
prim_input_formats.push_back(delta_layout); // delta
prim_output_formats.push_back(delta_layout); // dinput
prim_output_formats.push_back(memory::format::x); // dgamma
prim_output_formats.push_back(memory::format::x); // dbeta
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
throw ngraph_error("Batchnorm Backprop only supported in MKLDNN for now");
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Add)
{
......@@ -1056,6 +1086,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBiasBackpropFiltersBias>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNorm>},
{TI(ngraph::op::BatchNormBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNormBackprop>},
{TI(ngraph::op::GetOutputElement),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::GetOutputElement>},
{TI(ngraph::op::Relu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Relu>},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment