Commit f0bc6c12 authored by Jayaram Bobba, committed by Scott Cyphers

Support non-mkldnn fallback for Batchnorm training bprop in CPU backend (#3688)

* Support non-mkldnn fallback for Batchnorm training bprop in CPU backend

* Skip unit test on PlaidML backend
parent 5a1de88e
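
For context, the fallback path added by this commit dispatches BatchNormTrainingBackprop through SELECT_KERNEL to runtime::cpu::kernel::batch_norm_backprop, which forwards to the reference::batch_norm_backprop kernel instead of building an mkl-dnn primitive. Below is a minimal standalone sketch, not taken from the nGraph sources, of the standard batch-norm training gradients such a reference kernel has to produce; the function name, NCHW layout, typed vectors, and per-channel statistics arguments are illustrative assumptions.

// Sketch only: standard batch-norm training backprop for an NCHW tensor.
// Forward: y = gamma * (x - mean) / sqrt(var + eps) + beta, where mean/var are
// the per-channel batch statistics saved by the forward training pass.
#include <cmath>
#include <cstddef>
#include <vector>

void batch_norm_backprop_sketch(double eps,
                                const std::vector<float>& gamma,  // [C]
                                const std::vector<float>& x,      // [N*C*H*W]
                                const std::vector<float>& mean,   // [C]
                                const std::vector<float>& var,    // [C]
                                const std::vector<float>& dy,     // [N*C*H*W]
                                std::vector<float>& dx,           // [N*C*H*W]
                                std::vector<float>& dgamma,       // [C]
                                std::vector<float>& dbeta,        // [C]
                                std::size_t N,
                                std::size_t C,
                                std::size_t HW)
{
    const double M = static_cast<double>(N * HW); // elements reduced per channel
    for (std::size_t c = 0; c < C; ++c)
    {
        const double inv_std = 1.0 / std::sqrt(var[c] + eps);
        double sum_dy = 0.0;
        double sum_dy_xhat = 0.0;
        for (std::size_t n = 0; n < N; ++n)
        {
            for (std::size_t i = 0; i < HW; ++i)
            {
                const std::size_t idx = (n * C + c) * HW + i;
                const double xhat = (x[idx] - mean[c]) * inv_std;
                sum_dy += dy[idx];
                sum_dy_xhat += dy[idx] * xhat;
            }
        }
        dbeta[c] = static_cast<float>(sum_dy);       // d(loss)/d(beta)
        dgamma[c] = static_cast<float>(sum_dy_xhat); // d(loss)/d(gamma)
        for (std::size_t n = 0; n < N; ++n)
        {
            for (std::size_t i = 0; i < HW; ++i)
            {
                const std::size_t idx = (n * C + c) * HW + i;
                const double xhat = (x[idx] - mean[c]) * inv_std;
                // Training-mode dx: mean and variance are themselves functions of x.
                dx[idx] = static_cast<float>(
                    gamma[c] * inv_std *
                    (dy[idx] - sum_dy / M - xhat * sum_dy_xhat / M));
            }
        }
    }
}

The commit's actual kernel wrapper (runtime::cpu::kernel::batch_norm_backprop, further down in the diff) takes void* buffers and casts them per element type; the sketch uses typed vectors purely for clarity.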
@@ -395,18 +395,87 @@ namespace ngraph

template <>
void Builder::BUILDER_DECL(ngraph::op::BatchNormTrainingBackprop)
{
    if (!mkldnn_utils::use_mkldnn_kernel(node))
    {
        const ngraph::op::BatchNormTrainingBackprop* batchnorm =
            static_cast<const ngraph::op::BatchNormTrainingBackprop*>(node);
        auto& functors = external_function->get_functors();

        std::function<decltype(runtime::cpu::kernel::batch_norm_backprop<float>)>
            kernel;

        SELECT_KERNEL(kernel,
                      args[0].get_element_type(),
                      runtime::cpu::kernel::batch_norm_backprop)

        auto arg2_shape = args[2].get_shape();
        auto arg0_buffer_index =
            external_function->get_buffer_index(args[0].get_name()); /* gamma */
        auto arg1_buffer_index =
            external_function->get_buffer_index(args[1].get_name()); /* beta */
        auto arg2_buffer_index =
            external_function->get_buffer_index(args[2].get_name()); /* input */
        auto arg3_buffer_index =
            external_function->get_buffer_index(args[3].get_name()); /* mean */
        auto arg4_buffer_index =
            external_function->get_buffer_index(args[4].get_name()); /* variance */
        auto arg5_buffer_index =
            external_function->get_buffer_index(args[5].get_name()); /* delta */
        auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
        auto out1_buffer_index = external_function->get_buffer_index(out[1].get_name());
        auto out2_buffer_index = external_function->get_buffer_index(out[2].get_name());
        auto eps = batchnorm->get_eps_value();

        auto functor = [&,
                        kernel,
                        arg2_shape,
                        eps,
                        arg0_buffer_index,
                        arg1_buffer_index,
                        arg2_buffer_index,
                        arg3_buffer_index,
                        arg4_buffer_index,
                        arg5_buffer_index,
                        out0_buffer_index,
                        out1_buffer_index,
                        out2_buffer_index](CPURuntimeContext* ctx,
                                           CPUExecutionContext* /* ectx */) {
            kernel(eps,
                   ctx->buffer_data[arg0_buffer_index],
                   ctx->buffer_data[arg1_buffer_index],
                   ctx->buffer_data[arg2_buffer_index],
                   ctx->buffer_data[arg3_buffer_index],
                   ctx->buffer_data[arg4_buffer_index],
                   ctx->buffer_data[arg5_buffer_index],
                   ctx->buffer_data[out0_buffer_index],
                   ctx->buffer_data[out1_buffer_index],
                   ctx->buffer_data[out2_buffer_index],
                   arg2_shape);
        };
        functors.emplace_back(functor);
    }
    else
    {
        auto& functors = external_function->get_functors();
        auto arg0_buffer_index =
            external_function->get_buffer_index(args[0].get_name());
        auto arg1_buffer_index =
            external_function->get_buffer_index(args[1].get_name());
        auto arg2_buffer_index =
            external_function->get_buffer_index(args[2].get_name());
        auto arg3_buffer_index =
            external_function->get_buffer_index(args[3].get_name());
        auto arg4_buffer_index =
            external_function->get_buffer_index(args[4].get_name());
        auto arg5_buffer_index =
            external_function->get_buffer_index(args[5].get_name());
        auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
        auto out1_buffer_index = external_function->get_buffer_index(out[1].get_name());
        auto out2_buffer_index = external_function->get_buffer_index(out[2].get_name());

        // Kill clang diagnostics bug
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wmissing-braces"
#endif
        array<size_t, 2> weight_sizes{
            args[0].get_size() * args[0].get_element_type().size(),
            args[1].get_size() * args[1].get_element_type().size()};
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
        shared_ptr<uint8_t> stacked_weights(
            new uint8_t[weight_sizes[0] + weight_sizes[1]],
            std::default_delete<uint8_t[]>());
        shared_ptr<uint8_t> stacked_dweights(
            new uint8_t[weight_sizes[0] + weight_sizes[1]],
            std::default_delete<uint8_t[]>());

        auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
        auto batchnorm_desc = mkldnn_emitter->get_batchnorm_backward_desc(node);
        auto weights_shape = Shape{2, args[0].get_size()};
        auto weights_desc = mkldnn_emitter->build_memory_descriptor(
            weights_shape, args[0].get_element_type(), mkldnn::memory::FORMAT::nc);
        auto dweights_desc = mkldnn_emitter->build_memory_descriptor(
            weights_shape, args[0].get_element_type(), mkldnn::memory::FORMAT::nc);
        auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 2);

        // batchnorm backward needs 8 primitives: weights, input, mean, variance,
        // dinput, dweights, and batch_normalization_backward.
        auto batchnorm_index = mkldnn_emitter->reserve_primitive_space(8);
        auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);

        const ngraph::op::BatchNormTrainingBackprop* batchnorm =
            static_cast<const ngraph::op::BatchNormTrainingBackprop*>(node);
        auto eps = batchnorm->get_eps_value();
        (void)eps; // Use depends on mkl-dnn version
        QUERY_SCRATCHPAD_3ARGS(batchnorm_backward, batchnorm_desc, input_desc, eps);

        auto functor = [&,
                        batchnorm_desc,
                        input_desc,
                        weights_desc,
                        dweights_desc,
                        batchnorm_index,
                        stacked_weights,
                        stacked_dweights,
                        weight_sizes,
                        arg0_buffer_index,
                        arg1_buffer_index,
                        arg2_buffer_index,
                        arg3_buffer_index,
                        arg4_buffer_index,
                        arg5_buffer_index,
                        out0_buffer_index,
                        out1_buffer_index,
                        out2_buffer_index](CPURuntimeContext* ctx,
                                           CPUExecutionContext* /* ectx */) {
            if (ctx->first_iteration)
            {
                mkldnn_emitter->build_batchnorm_backward(ctx->mkldnn_memories,
                                                         ctx->mkldnn_primitives,
                                                         ctx->mkldnn_scratchpad_mds,
                                                         batchnorm_desc,
                                                         input_desc,
                                                         weights_desc,
                                                         dweights_desc,
                                                         eps,
                                                         deps,
                                                         batchnorm_index);
            }
            memcpy(stacked_weights.get(),
                   ctx->buffer_data[arg0_buffer_index],
                   weight_sizes[0]);
            memcpy(stacked_weights.get() + weight_sizes[0],
                   ctx->buffer_data[arg1_buffer_index],
                   weight_sizes[1]);

            cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], stacked_weights.get());
            cpu::mkldnn_utils::set_memory_ptr(
                ctx, deps[1], ctx->buffer_data[arg2_buffer_index]);
            cpu::mkldnn_utils::set_memory_ptr(
                ctx, deps[2], ctx->buffer_data[arg3_buffer_index]);
            cpu::mkldnn_utils::set_memory_ptr(
                ctx, deps[3], ctx->buffer_data[arg4_buffer_index]);
            cpu::mkldnn_utils::set_memory_ptr(
                ctx, deps[4], ctx->buffer_data[arg5_buffer_index]);
            cpu::mkldnn_utils::set_memory_ptr(
                ctx, deps[5], ctx->buffer_data[out0_buffer_index]);
            cpu::mkldnn_utils::set_memory_ptr(ctx, deps[6], stacked_dweights.get());

            cpu::mkldnn_utils::mkldnn_invoke_primitive(
                ctx,
                batchnorm_index,
                deps,
                cpu::mkldnn_utils::OpType::BATCHNORMBACKPROP);

            memcpy(ctx->buffer_data[out1_buffer_index],
                   stacked_dweights.get(),
                   weight_sizes[0]);
            memcpy(ctx->buffer_data[out2_buffer_index],
                   stacked_dweights.get() + weight_sizes[0],
                   weight_sizes[1]);
        };
        functors.emplace_back(functor);
    }
}

template <>
...
@@ -825,46 +825,67 @@ namespace ngraph

void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNormTrainingBackprop)
{
    writer.block_begin();
    if (!mkldnn_utils::use_mkldnn_kernel(node))
    {
        const ngraph::op::BatchNormTrainingBackprop* batchnorm =
            static_cast<const ngraph::op::BatchNormTrainingBackprop*>(node);

        writer << "reference::batch_norm_backprop(" << batchnorm->get_eps_value()
               << ",\n";
        writer << " " << args[0].get_name() << ",\n";
        writer << " " << args[1].get_name() << ",\n";
        writer << " " << args[2].get_name() << ",\n";
        writer << " " << args[3].get_name() << ",\n";
        writer << " " << args[4].get_name() << ",\n";
        writer << " " << args[5].get_name() << ",\n";
        writer << " " << out[0].get_name() << ",\n";
        writer << " " << out[1].get_name() << ",\n";
        writer << " " << out[2].get_name() << ",\n";
        writer << " {" << join(args[2].get_shape()) << "});\n";
    }
    else
    {
        // define weights
        writer << "std::vector<" << args[0].get_element_type().c_type_string()
               << ">bn_weights(2*" << args[0].get_size() << ");\n";
        writer << "std::vector<" << args[0].get_element_type().c_type_string()
               << ">bn_dweights(2*" << args[0].get_size() << ");\n";

        writer << "memcpy(&bn_weights[0], " << args[0].get_name() << ", "
               << args[0].get_size() * args[0].get_element_type().size() << ");\n";
        writer << "memcpy(&bn_weights[0]+" << args[0].get_size() << ", "
               << args[1].get_name() << ", "
               << args[1].get_size() * args[1].get_element_type().size() << ");\n";

        size_t batchnorm_index;
        std::vector<std::size_t> deps;
        emit_build_primitives(external_function, node, writer, batchnorm_index, deps);

        writer << "cg_ctx->set_memory_ptr(" << to_string(deps[0])
               << ", bn_weights.data());\n";
        writer << "cg_ctx->set_memory_ptr(" << to_string(deps[1]) << ", "
               << args[2].get_name() << ");\n";
        writer << "cg_ctx->set_memory_ptr(" << to_string(deps[2]) << ", "
               << args[3].get_name() << ");\n";
        writer << "cg_ctx->set_memory_ptr(" << to_string(deps[3]) << ", "
               << args[4].get_name() << ");\n";
        writer << "cg_ctx->set_memory_ptr(" << to_string(deps[4]) << ", "
               << args[5].get_name() << ");\n";
        writer << "cg_ctx->set_memory_ptr(" << to_string(deps[5]) << ", "
               << out[0].get_name() << ");\n";
        writer << "cg_ctx->set_memory_ptr(" << to_string(deps[6])
               << ", bn_dweights.data());\n";

        writer << "std::vector<size_t> deps{" << join(deps) << "};\n";
        writer << "cg_ctx->mkldnn_invoke_primitive(" << to_string(batchnorm_index)
               << ", deps, OpType::BATCHNORMBACKPROP);\n";

        writer << "memcpy(" << out[1].get_name() << ", &bn_dweights[0], "
               << args[0].get_size() * args[0].get_element_type().size() << ");\n";
        writer << "memcpy(" << out[2].get_name() << ", &bn_dweights[0]+"
               << args[0].get_size() << ", "
               << args[1].get_size() * args[1].get_element_type().size() << ");\n";
    }
    writer.block_end();
}
...
@@ -66,6 +66,32 @@ namespace ngraph
                   static_cast<ElementType*>(out0),
                   arg2_shape);
}

template <typename ElementType>
void batch_norm_backprop(double eps,
                         const void* arg0,
                         const void* arg1,
                         const void* arg2,
                         const void* arg3,
                         const void* arg4,
                         const void* arg5,
                         void* out0,
                         void* out1,
                         void* out2,
                         const Shape& arg2_shape)
{
    reference::batch_norm_backprop(eps,
                                   static_cast<const ElementType*>(arg0),
                                   static_cast<const ElementType*>(arg1),
                                   static_cast<const ElementType*>(arg2),
                                   static_cast<const ElementType*>(arg3),
                                   static_cast<const ElementType*>(arg4),
                                   static_cast<const ElementType*>(arg5),
                                   static_cast<ElementType*>(out0),
                                   static_cast<ElementType*>(out1),
                                   static_cast<ElementType*>(out2),
                                   arg2_shape);
}
}
}
}
...
@@ -2326,7 +2326,7 @@ namespace ngraph
    }
    else
    {
        set_native_layouts(external_function, node);
    }
}
...
@@ -188,6 +188,7 @@ max_trivial_5d_int32
floor_int32
any_trivial
any_2x2x3_eliminate_dim_0
backwards_batch_norm_training_3d
# unsupported op: `GenerateMask`
generate_mask
...
@@ -1659,7 +1659,7 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2)
    ASSERT_TRUE(read_vector<float>(output) == expected);
}

NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training_4d)
{
    const Shape input_shape{10, 4, 5, 5};
    const Shape channel_shape{input_shape.at(1)};
@@ -1697,6 +1697,44 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training)
        autodiff_numeric_compare<T>(backend.get(), make_graph, {input, gamma, beta}, .005, .005));
}

NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training_3d)
{
    const Shape input_shape{10, 4, 5};
    const Shape channel_shape{input_shape.at(1)};
    const double eps = 1e-3;

    // Need to keep the output elements for mean and variance from going out of scope
    // and getting freed.
    NodeVector goes;

    auto make_graph = [&input_shape, &channel_shape, &eps, &goes] {
        const element::Type& et = element::f32;
        auto input = make_shared<op::Parameter>(et, input_shape);
        auto gamma = make_shared<op::Parameter>(et, channel_shape);
        auto beta = make_shared<op::Parameter>(et, channel_shape);
        auto BN = make_shared<op::BatchNormTraining>(input, gamma, beta, eps);
        auto normed_input = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 0));
        auto mean = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 1));
        auto variance = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 2));
        goes.push_back(mean);
        goes.push_back(variance);
        // TODO autodiff testing with more than one result
        auto f =
            make_shared<Function>(ResultVector{normed_input}, ParameterVector{input, gamma, beta});
        return f;
    };

    auto backend = runtime::Backend::create("${BACKEND_NAME}");

    using T = float;
    test::Uniform<T> rng(-5.0, 2.0);
    auto input = rng.initialize(backend->create_tensor<T>(input_shape));
    auto gamma = rng.initialize(backend->create_tensor<T>(channel_shape));
    auto beta = rng.initialize(backend->create_tensor<T>(channel_shape));

    EXPECT_TRUE(
        autodiff_numeric_compare<T>(backend.get(), make_graph, {input, gamma, beta}, .005, .005));
}

NGRAPH_TEST(${BACKEND_NAME}, backwards_reverse_sequence_n3_c2_h3)
{
    auto backend = runtime::Backend::create("${BACKEND_NAME}");
...