Commit f0bc6c12 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Support non-mkldnn fallback for Batchnorm training bprop in CPU backend (#3688)

* Support non-mkldnn fallback for Batchnorm training bprop in CPU backend

* Skip unit test on PlaidML backend
parent 5a1de88e
......@@ -395,18 +395,87 @@ namespace ngraph
// CPU-backend builder for ngraph::op::BatchNormTrainingBackprop.
// Registers an execution functor that computes batchnorm-training backprop,
// either through a type-dispatched reference kernel (non-mkldnn fallback,
// added by this commit) or through an MKL-DNN batch_normalization_backward
// primitive.
//
// NOTE(review): this span is a diff render that interleaves removed and
// added lines without +/- markers — several declarations appear twice
// (buffer indices, `weight_sizes`, the whole mkldnn functor). It is not
// compilable as written here; comments below mark the duplicated runs.
template <>
void Builder::BUILDER_DECL(ngraph::op::BatchNormTrainingBackprop)
{
auto& functors = external_function->get_functors();
// Fallback path: node was not assigned an mkldnn kernel.
if (!mkldnn_utils::use_mkldnn_kernel(node))
{
const ngraph::op::BatchNormTrainingBackprop* batchnorm =
static_cast<const ngraph::op::BatchNormTrainingBackprop*>(node);
// NOTE(review): pre-change (removed) variant of the buffer-index
// declarations; the post-change, commented variant follows below.
auto arg0_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto arg1_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto arg2_buffer_index = external_function->get_buffer_index(args[2].get_name());
auto arg3_buffer_index = external_function->get_buffer_index(args[3].get_name());
auto arg4_buffer_index = external_function->get_buffer_index(args[4].get_name());
auto arg5_buffer_index = external_function->get_buffer_index(args[5].get_name());
auto& functors = external_function->get_functors();
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto out1_buffer_index = external_function->get_buffer_index(out[1].get_name());
auto out2_buffer_index = external_function->get_buffer_index(out[2].get_name());
// Pick the reference-kernel instantiation matching the element type.
std::function<decltype(runtime::cpu::kernel::batch_norm_backprop<float>)>
kernel;
SELECT_KERNEL(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::batch_norm_backprop)
auto arg2_shape = args[2].get_shape();
// Post-change variant: same indices, annotated with each buffer's role.
auto arg0_buffer_index =
external_function->get_buffer_index(args[0].get_name()); /* gamma */
auto arg1_buffer_index =
external_function->get_buffer_index(args[1].get_name()); /* beta */
auto arg2_buffer_index =
external_function->get_buffer_index(args[2].get_name()); /* input */
auto arg3_buffer_index =
external_function->get_buffer_index(args[3].get_name()); /* mean */
auto arg4_buffer_index =
external_function->get_buffer_index(args[4].get_name()); /* variance */
auto arg5_buffer_index =
external_function->get_buffer_index(args[5].get_name()); /* delta */
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto out1_buffer_index = external_function->get_buffer_index(out[1].get_name());
auto out2_buffer_index = external_function->get_buffer_index(out[2].get_name());
auto eps = batchnorm->get_eps_value();
// Indices are captured by value; actual buffers are resolved from
// ctx->buffer_data at execution time.
auto functor = [&,
kernel,
arg2_shape,
eps,
arg0_buffer_index,
arg1_buffer_index,
arg2_buffer_index,
arg3_buffer_index,
arg4_buffer_index,
arg5_buffer_index,
out0_buffer_index,
out1_buffer_index,
out2_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
kernel(eps,
ctx->buffer_data[arg0_buffer_index],
ctx->buffer_data[arg1_buffer_index],
ctx->buffer_data[arg2_buffer_index],
ctx->buffer_data[arg3_buffer_index],
ctx->buffer_data[arg4_buffer_index],
ctx->buffer_data[arg5_buffer_index],
ctx->buffer_data[out0_buffer_index],
ctx->buffer_data[out1_buffer_index],
ctx->buffer_data[out2_buffer_index],
arg2_shape);
};
functors.emplace_back(functor);
}
else
{
// MKL-DNN path: gamma and beta are stacked into one contiguous
// weights buffer (and dgamma/dbeta into stacked_dweights) because the
// mkldnn primitive takes a single 2xC weights memory.
auto& functors = external_function->get_functors();
auto arg0_buffer_index =
external_function->get_buffer_index(args[0].get_name());
auto arg1_buffer_index =
external_function->get_buffer_index(args[1].get_name());
auto arg2_buffer_index =
external_function->get_buffer_index(args[2].get_name());
auto arg3_buffer_index =
external_function->get_buffer_index(args[3].get_name());
auto arg4_buffer_index =
external_function->get_buffer_index(args[4].get_name());
auto arg5_buffer_index =
external_function->get_buffer_index(args[5].get_name());
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto out1_buffer_index = external_function->get_buffer_index(out[1].get_name());
auto out2_buffer_index = external_function->get_buffer_index(out[2].get_name());
// Kill clang diagnostics bug
#if defined(__clang__)
......@@ -414,101 +483,107 @@ namespace ngraph
#pragma clang diagnostic ignored "-Wmissing-braces"
#endif
// Byte sizes of gamma and beta, used for the stacking memcpys below.
// NOTE(review): `weight_sizes` appears twice — diff-render duplication
// (old formatting then new formatting).
array<size_t, 2> weight_sizes{
args[0].get_size() * args[0].get_element_type().size(),
args[1].get_size() * args[1].get_element_type().size()};
array<size_t, 2> weight_sizes{
args[0].get_size() * args[0].get_element_type().size(),
args[1].get_size() * args[1].get_element_type().size()};
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
// NOTE(review): everything from here to the first
// `functors.emplace_back(functor);` is the pre-change (removed)
// variant of the mkldnn setup + functor; the reformatted post-change
// duplicate follows it.
shared_ptr<uint8_t> stacked_weights(new uint8_t[weight_sizes[0] + weight_sizes[1]],
std::default_delete<uint8_t[]>());
shared_ptr<uint8_t> stacked_dweights(new uint8_t[weight_sizes[0] + weight_sizes[1]],
std::default_delete<uint8_t[]>());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto batchnorm_desc = mkldnn_emitter->get_batchnorm_backward_desc(node);
auto weights_shape = Shape{2, args[0].get_size()};
auto weights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::FORMAT::nc);
auto dweights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::FORMAT::nc);
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 2);
// batchnorm backward needs 8 primitives: weights, input, mean, variance,
// dinput, dweights, and batch_normalization_backward.
auto batchnorm_index = mkldnn_emitter->reserve_primitive_space(8);
auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
const ngraph::op::BatchNormTrainingBackprop* batchnorm =
static_cast<const ngraph::op::BatchNormTrainingBackprop*>(node);
auto eps = batchnorm->get_eps_value();
(void)eps; // Use depends on mkl-dnn version
QUERY_SCRATCHPAD_3ARGS(batchnorm_backward, batchnorm_desc, input_desc, eps);
auto functor = [&,
batchnorm_desc,
input_desc,
weights_desc,
dweights_desc,
batchnorm_index,
stacked_weights,
stacked_dweights,
weight_sizes,
arg0_buffer_index,
arg1_buffer_index,
arg2_buffer_index,
arg3_buffer_index,
arg4_buffer_index,
arg5_buffer_index,
out0_buffer_index,
out1_buffer_index,
out2_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_batchnorm_backward(ctx->mkldnn_memories,
ctx->mkldnn_primitives,
ctx->mkldnn_scratchpad_mds,
batchnorm_desc,
input_desc,
weights_desc,
dweights_desc,
eps,
deps,
batchnorm_index);
}
memcpy(stacked_weights.get(),
ctx->buffer_data[arg0_buffer_index],
weight_sizes[0]);
memcpy(stacked_weights.get() + weight_sizes[0],
ctx->buffer_data[arg1_buffer_index],
weight_sizes[1]);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], stacked_weights.get());
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[arg2_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[2], ctx->buffer_data[arg3_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[3], ctx->buffer_data[arg4_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[4], ctx->buffer_data[arg5_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[5], ctx->buffer_data[out0_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[6], stacked_dweights.get());
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, batchnorm_index, deps, cpu::mkldnn_utils::OpType::BATCHNORMBACKPROP);
memcpy(ctx->buffer_data[out1_buffer_index],
stacked_dweights.get(),
weight_sizes[0]);
memcpy(ctx->buffer_data[out2_buffer_index],
stacked_dweights.get() + weight_sizes[0],
weight_sizes[1]);
};
functors.emplace_back(functor);
// NOTE(review): post-change (re-indented) duplicate of the same mkldnn
// setup + functor begins here.
shared_ptr<uint8_t> stacked_weights(
new uint8_t[weight_sizes[0] + weight_sizes[1]],
std::default_delete<uint8_t[]>());
shared_ptr<uint8_t> stacked_dweights(
new uint8_t[weight_sizes[0] + weight_sizes[1]],
std::default_delete<uint8_t[]>());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto batchnorm_desc = mkldnn_emitter->get_batchnorm_backward_desc(node);
auto weights_shape = Shape{2, args[0].get_size()};
auto weights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::FORMAT::nc);
auto dweights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::FORMAT::nc);
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 2);
// batchnorm backward needs 8 primitives: weights, input, mean, variance,
// dinput, dweights, and batch_normalization_backward.
auto batchnorm_index = mkldnn_emitter->reserve_primitive_space(8);
auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
const ngraph::op::BatchNormTrainingBackprop* batchnorm =
static_cast<const ngraph::op::BatchNormTrainingBackprop*>(node);
auto eps = batchnorm->get_eps_value();
(void)eps; // Use depends on mkl-dnn version
QUERY_SCRATCHPAD_3ARGS(batchnorm_backward, batchnorm_desc, input_desc, eps);
auto functor = [&,
batchnorm_desc,
input_desc,
weights_desc,
dweights_desc,
batchnorm_index,
stacked_weights,
stacked_dweights,
weight_sizes,
arg0_buffer_index,
arg1_buffer_index,
arg2_buffer_index,
arg3_buffer_index,
arg4_buffer_index,
arg5_buffer_index,
out0_buffer_index,
out1_buffer_index,
out2_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
// Primitives are built lazily on the first invocation.
if (ctx->first_iteration)
{
mkldnn_emitter->build_batchnorm_backward(ctx->mkldnn_memories,
ctx->mkldnn_primitives,
ctx->mkldnn_scratchpad_mds,
batchnorm_desc,
input_desc,
weights_desc,
dweights_desc,
eps,
deps,
batchnorm_index);
}
// Stack gamma then beta into the single weights buffer.
memcpy(stacked_weights.get(),
ctx->buffer_data[arg0_buffer_index],
weight_sizes[0]);
memcpy(stacked_weights.get() + weight_sizes[0],
ctx->buffer_data[arg1_buffer_index],
weight_sizes[1]);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], stacked_weights.get());
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[arg2_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[2], ctx->buffer_data[arg3_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[3], ctx->buffer_data[arg4_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[4], ctx->buffer_data[arg5_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[5], ctx->buffer_data[out0_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[6], stacked_dweights.get());
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx,
batchnorm_index,
deps,
cpu::mkldnn_utils::OpType::BATCHNORMBACKPROP);
// Unstack dgamma/dbeta from the single dweights buffer into the
// separate output tensors.
memcpy(ctx->buffer_data[out1_buffer_index],
stacked_dweights.get(),
weight_sizes[0]);
memcpy(ctx->buffer_data[out2_buffer_index],
stacked_dweights.get() + weight_sizes[0],
weight_sizes[1]);
};
functors.emplace_back(functor);
}
}
template <>
......
......@@ -825,46 +825,67 @@ namespace ngraph
// Codegen emitter for ngraph::op::BatchNormTrainingBackprop: writes C++
// source into `writer` that either calls the reference implementation
// (non-mkldnn fallback) or drives the prebuilt mkldnn primitive.
//
// NOTE(review): like the builder above, this span is a diff render with
// removed and added lines interleaved (duplicate bn_weights definitions,
// duplicate set_memory_ptr/invoke sequences); it is not compilable as-is.
void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNormTrainingBackprop)
{
writer.block_begin();
// define weights
// NOTE(review): pre-change variant — these vector definitions were moved
// into the mkldnn-only branch below by this commit.
writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">bn_weights(2*" << args[0].get_size() << ");\n";
writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">bn_dweights(2*" << args[0].get_size() << ");\n";
if (!mkldnn_utils::use_mkldnn_kernel(node))
{
const ngraph::op::BatchNormTrainingBackprop* batchnorm =
static_cast<const ngraph::op::BatchNormTrainingBackprop*>(node);
// NOTE(review): pre-change leftover — weight stacking is not needed on
// the reference path.
writer << "memcpy(&bn_weights[0], " << args[0].get_name() << ", "
<< args[0].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(&bn_weights[0]+" << args[0].get_size() << ", "
<< args[1].get_name() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
// Emit a direct call to the reference batchnorm backprop.
writer << "reference::batch_norm_backprop(" << batchnorm->get_eps_value()
<< ",\n";
writer << "            " << args[0].get_name() << ",\n";
writer << "            " << args[1].get_name() << ",\n";
writer << "            " << args[2].get_name() << ",\n";
writer << "            " << args[3].get_name() << ",\n";
writer << "            " << args[4].get_name() << ",\n";
writer << "            " << args[5].get_name() << ",\n";
writer << "            " << out[0].get_name() << ",\n";
writer << "            " << out[1].get_name() << ",\n";
writer << "            " << out[2].get_name() << ",\n";
writer << "            {" << join(args[2].get_shape()) << "});\n";
}
else
{
// define weights
writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">bn_weights(2*" << args[0].get_size() << ");\n";
writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">bn_dweights(2*" << args[0].get_size() << ");\n";
size_t batchnorm_index;
std::vector<std::size_t> deps;
emit_build_primitives(external_function, node, writer, batchnorm_index, deps);
// Stack gamma then beta into bn_weights for the mkldnn primitive.
writer << "memcpy(&bn_weights[0], " << args[0].get_name() << ", "
<< args[0].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(&bn_weights[0]+" << args[0].get_size() << ", "
<< args[1].get_name() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[0])
<< ", bn_weights.data());\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[1]) << ", "
<< args[2].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[2]) << ", "
<< args[3].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[3]) << ", "
<< args[4].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[4]) << ", "
<< args[5].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[5]) << ", "
<< out[0].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[6])
<< ", bn_dweights.data());\n";
// NOTE(review): from here down, pre-change and post-change emission
// sequences are interleaved (duplicate primitive build, duplicate
// set_memory_ptr run, duplicate invoke and output memcpys).
size_t batchnorm_index;
std::vector<std::size_t> deps;
emit_build_primitives(external_function, node, writer, batchnorm_index, deps);
writer << "std::vector<size_t> deps{" << join(deps) << "};\n";
writer << "cg_ctx->mkldnn_invoke_primitive(" << to_string(batchnorm_index)
<< ", deps, OpType::BATCHNORMBACKPROP);\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[0])
<< ", bn_weights.data());\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[1]) << ", "
<< args[2].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[2]) << ", "
<< args[3].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[3]) << ", "
<< args[4].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[4]) << ", "
<< args[5].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[5]) << ", "
<< out[0].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[6])
<< ", bn_dweights.data());\n";
// Unstack dgamma/dbeta from bn_dweights into the two output tensors.
writer << "memcpy(" << out[1].get_name() << ", &bn_dweights[0], "
<< args[0].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(" << out[2].get_name() << ", &bn_dweights[0]+"
<< args[0].get_size() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
writer << "std::vector<size_t> deps{" << join(deps) << "};\n";
writer << "cg_ctx->mkldnn_invoke_primitive(" << to_string(batchnorm_index)
<< ", deps, OpType::BATCHNORMBACKPROP);\n";
writer << "memcpy(" << out[1].get_name() << ", &bn_dweights[0], "
<< args[0].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(" << out[2].get_name() << ", &bn_dweights[0]+"
<< args[0].get_size() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
}
writer.block_end();
}
......
......@@ -66,6 +66,32 @@ namespace ngraph
static_cast<ElementType*>(out0),
arg2_shape);
}
// Non-mkldnn fallback kernel for BatchNormTrainingBackprop: casts the
// type-erased runtime buffers to ElementType and forwards to the reference
// implementation.
//
// Buffer roles (per the CPU builder's annotations): arg0 = gamma,
// arg1 = beta, arg2 = input, arg3 = mean, arg4 = variance, arg5 = delta;
// out0/out1/out2 receive dinput/dgamma/dbeta — TODO confirm output order
// against reference::batch_norm_backprop.
template <typename ElementType>
void batch_norm_backprop(double eps,
                         const void* arg0,
                         const void* arg1,
                         const void* arg2,
                         const void* arg3,
                         const void* arg4,
                         const void* arg5,
                         void* out0,
                         void* out1,
                         void* out2,
                         const Shape& arg2_shape)
{
    // Name the typed views once instead of casting inline at the call site.
    const auto* gamma = static_cast<const ElementType*>(arg0);
    const auto* beta = static_cast<const ElementType*>(arg1);
    const auto* input = static_cast<const ElementType*>(arg2);
    const auto* mean = static_cast<const ElementType*>(arg3);
    const auto* variance = static_cast<const ElementType*>(arg4);
    const auto* delta = static_cast<const ElementType*>(arg5);
    auto* grad_input = static_cast<ElementType*>(out0);
    auto* grad_gamma = static_cast<ElementType*>(out1);
    auto* grad_beta = static_cast<ElementType*>(out2);

    reference::batch_norm_backprop(eps,
                                   gamma,
                                   beta,
                                   input,
                                   mean,
                                   variance,
                                   delta,
                                   grad_input,
                                   grad_gamma,
                                   grad_beta,
                                   arg2_shape);
}
}
}
}
......
......@@ -2326,7 +2326,7 @@ namespace ngraph
}
else
{
throw ngraph_error("Batchnorm Backprop only supported in MKLDNN for now");
set_native_layouts(external_function, node);
}
}
......
......@@ -188,6 +188,7 @@ max_trivial_5d_int32
floor_int32
any_trivial
any_2x2x3_eliminate_dim_0
backwards_batch_norm_training_3d
# unsupported op: `GenerateMask`
generate_mask
......
......@@ -1659,7 +1659,7 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2)
ASSERT_TRUE(read_vector<float>(output) == expected);
}
NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training)
NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training_4d)
{
const Shape input_shape{10, 4, 5, 5};
const Shape channel_shape{input_shape.at(1)};
......@@ -1697,6 +1697,44 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training)
autodiff_numeric_compare<T>(backend.get(), make_graph, {input, gamma, beta}, .005, .005));
}
// Autodiff check for BatchNormTraining backprop on a 3-D (10x4x5) input.
// Per the commit message, mkldnn does not cover this case, so on the CPU
// backend this exercises the new non-mkldnn fallback path. Gradients w.r.t.
// input, gamma, and beta are compared against numeric differentiation with
// 0.005 step / 0.005 tolerance (presumably — verify the last two arguments'
// meaning against autodiff_numeric_compare).
NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training_3d)
{
const Shape input_shape{10, 4, 5};
// Channel count is dimension 1 of the input.
const Shape channel_shape{input_shape.at(1)};
const double eps = 1e-3;
// Need to keep the output elements for mean and variance from going out of scope
// and getting freed.
NodeVector goes;
auto make_graph = [&input_shape, &channel_shape, &eps, &goes] {
const element::Type& et = element::f32;
auto input = make_shared<op::Parameter>(et, input_shape);
auto gamma = make_shared<op::Parameter>(et, channel_shape);
auto beta = make_shared<op::Parameter>(et, channel_shape);
auto BN = make_shared<op::BatchNormTraining>(input, gamma, beta, eps);
// Only output 0 (the normalized input) is differentiated; mean and
// variance outputs are parked in `goes` to keep them alive.
auto normed_input = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 0));
auto mean = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 1));
auto variance = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 2));
goes.push_back(mean);
goes.push_back(variance);
// TODO autodiff testing with more than one result
auto f =
make_shared<Function>(ResultVector{normed_input}, ParameterVector{input, gamma, beta});
return f;
};
auto backend = runtime::Backend::create("${BACKEND_NAME}");
using T = float;
// Random inputs in [-5, 2).
test::Uniform<T> rng(-5.0, 2.0);
auto input = rng.initialize(backend->create_tensor<T>(input_shape));
auto gamma = rng.initialize(backend->create_tensor<T>(channel_shape));
auto beta = rng.initialize(backend->create_tensor<T>(channel_shape));
EXPECT_TRUE(
autodiff_numeric_compare<T>(backend.get(), make_graph, {input, gamma, beta}, .005, .005));
}
NGRAPH_TEST(${BACKEND_NAME}, backwards_reverse_sequence_n3_c2_h3)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment