Commit f0bc6c12 authored by Jayaram Bobba, committed by Scott Cyphers

Support non-mkldnn fallback for Batchnorm training bprop in CPU backend (#3688)

* Support non-mkldnn fallback for Batchnorm training bprop in CPU backend

* Skip unit test on PlaidML backend
parent 5a1de88e
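For context on what the non-mkldnn fallback has to compute: batch-norm training backprop takes the input, gamma, beta, the batch mean and variance, and the output delta, and produces gradients for the input, gamma, and beta. The standard per-channel gradients (with N elements reduced per channel and delta_i = dL/dy_i) are the math any reference kernel of this kind is expected to implement; this is the textbook derivation, not a transcription of reference::batch_norm_backprop:

\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad y_i = \gamma\,\hat{x}_i + \beta

\frac{\partial L}{\partial \beta} = \sum_{i=1}^{N} \delta_i, \qquad
\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{N} \delta_i\,\hat{x}_i

\frac{\partial L}{\partial x_i} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}
\left( \delta_i - \frac{1}{N}\sum_{j} \delta_j - \frac{\hat{x}_i}{N}\sum_{j} \delta_j\,\hat{x}_j \right)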
@@ -395,14 +395,83 @@ namespace ngraph
template <>
void Builder::BUILDER_DECL(ngraph::op::BatchNormTrainingBackprop)
{
if (!mkldnn_utils::use_mkldnn_kernel(node))
{
const ngraph::op::BatchNormTrainingBackprop* batchnorm =
static_cast<const ngraph::op::BatchNormTrainingBackprop*>(node);
auto& functors = external_function->get_functors();
std::function<decltype(runtime::cpu::kernel::batch_norm_backprop<float>)>
    kernel;

SELECT_KERNEL(kernel,
              args[0].get_element_type(),
              runtime::cpu::kernel::batch_norm_backprop)
auto arg2_shape = args[2].get_shape();
auto arg0_buffer_index =
external_function->get_buffer_index(args[0].get_name()); /* gamma */
auto arg1_buffer_index =
external_function->get_buffer_index(args[1].get_name()); /* beta */
auto arg2_buffer_index =
external_function->get_buffer_index(args[2].get_name()); /* input */
auto arg3_buffer_index =
external_function->get_buffer_index(args[3].get_name()); /* mean */
auto arg4_buffer_index =
external_function->get_buffer_index(args[4].get_name()); /* variance */
auto arg5_buffer_index =
external_function->get_buffer_index(args[5].get_name()); /* delta */
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto out1_buffer_index = external_function->get_buffer_index(out[1].get_name());
auto out2_buffer_index = external_function->get_buffer_index(out[2].get_name());
auto eps = batchnorm->get_eps_value();
auto functor = [&,
kernel,
arg2_shape,
eps,
arg0_buffer_index,
arg1_buffer_index,
arg2_buffer_index,
arg3_buffer_index,
arg4_buffer_index,
arg5_buffer_index,
out0_buffer_index,
out1_buffer_index,
out2_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
kernel(eps,
ctx->buffer_data[arg0_buffer_index],
ctx->buffer_data[arg1_buffer_index],
ctx->buffer_data[arg2_buffer_index],
ctx->buffer_data[arg3_buffer_index],
ctx->buffer_data[arg4_buffer_index],
ctx->buffer_data[arg5_buffer_index],
ctx->buffer_data[out0_buffer_index],
ctx->buffer_data[out1_buffer_index],
ctx->buffer_data[out2_buffer_index],
arg2_shape);
};
functors.emplace_back(functor);
}
else
{
auto& functors = external_function->get_functors();
auto arg0_buffer_index =
external_function->get_buffer_index(args[0].get_name());
auto arg1_buffer_index =
external_function->get_buffer_index(args[1].get_name());
auto arg2_buffer_index =
external_function->get_buffer_index(args[2].get_name());
auto arg3_buffer_index =
external_function->get_buffer_index(args[3].get_name());
auto arg4_buffer_index =
external_function->get_buffer_index(args[4].get_name());
auto arg5_buffer_index =
external_function->get_buffer_index(args[5].get_name());
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto out1_buffer_index = external_function->get_buffer_index(out[1].get_name());
@@ -421,9 +490,11 @@ namespace ngraph
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
shared_ptr<uint8_t> stacked_weights(
    new uint8_t[weight_sizes[0] + weight_sizes[1]],
    std::default_delete<uint8_t[]>());
shared_ptr<uint8_t> stacked_dweights(
    new uint8_t[weight_sizes[0] + weight_sizes[1]],
    std::default_delete<uint8_t[]>());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
@@ -499,7 +570,10 @@ namespace ngraph
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[6], stacked_dweights.get());
cpu::mkldnn_utils::mkldnn_invoke_primitive(
    ctx,
    batchnorm_index,
    deps,
    cpu::mkldnn_utils::OpType::BATCHNORMBACKPROP);
memcpy(ctx->buffer_data[out1_buffer_index],
       stacked_dweights.get(),
@@ -510,6 +584,7 @@ namespace ngraph
};
functors.emplace_back(functor);
}
}
template <>
void Builder::BUILDER_DECL(ngraph::op::BatchNormTrainingRelu)
......
@@ -825,6 +825,26 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNormTrainingBackprop)
{
writer.block_begin();
if (!mkldnn_utils::use_mkldnn_kernel(node))
{
const ngraph::op::BatchNormTrainingBackprop* batchnorm =
static_cast<const ngraph::op::BatchNormTrainingBackprop*>(node);
writer << "reference::batch_norm_backprop(" << batchnorm->get_eps_value()
<< ",\n";
writer << " " << args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << args[2].get_name() << ",\n";
writer << " " << args[3].get_name() << ",\n";
writer << " " << args[4].get_name() << ",\n";
writer << " " << args[5].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " " << out[1].get_name() << ",\n";
writer << " " << out[2].get_name() << ",\n";
writer << " {" << join(args[2].get_shape()) << "});\n";
}
else
{
// define weights
writer << "std::vector<" << args[0].get_element_type().c_type_string()
       << ">bn_weights(2*" << args[0].get_size() << ");\n";
@@ -865,6 +885,7 @@ namespace ngraph
writer << "memcpy(" << out[2].get_name() << ", &bn_dweights[0]+"
       << args[0].get_size() << ", "
       << args[1].get_size() * args[1].get_element_type().size() << ");\n";
}
writer.block_end();
}
......
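To make the emitter's new non-mkldnn branch concrete: with the builder comments above mapping args[0..5] to gamma, beta, input, mean, variance, and delta, and out[0..2] to the input, gamma, and beta gradients, the generated code is a single call of roughly the following shape. The eps value, the shape, and every tensor identifier below are made-up placeholders for illustration, not output captured from the repository:

// Illustrative only: the identifiers are hypothetical codegen names.
reference::batch_norm_backprop(0.001,
                               Parameter_1_0,                 /* gamma    (args[0]) */
                               Parameter_2_0,                 /* beta     (args[1]) */
                               Parameter_0_0,                 /* input    (args[2]) */
                               BatchNormTraining_3_1,         /* mean     (args[3]) */
                               BatchNormTraining_3_2,         /* variance (args[4]) */
                               Delta_4_0,                     /* delta    (args[5]) */
                               BatchNormTrainingBackprop_5_0, /* out[0] */
                               BatchNormTrainingBackprop_5_1, /* out[1] */
                               BatchNormTrainingBackprop_5_2, /* out[2] */
                               {10, 4, 5});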
@@ -66,6 +66,32 @@ namespace ngraph
static_cast<ElementType*>(out0),
arg2_shape);
}
template <typename ElementType>
void batch_norm_backprop(double eps,
const void* arg0,
const void* arg1,
const void* arg2,
const void* arg3,
const void* arg4,
const void* arg5,
void* out0,
void* out1,
void* out2,
const Shape& arg2_shape)
{
reference::batch_norm_backprop(eps,
static_cast<const ElementType*>(arg0),
static_cast<const ElementType*>(arg1),
static_cast<const ElementType*>(arg2),
static_cast<const ElementType*>(arg3),
static_cast<const ElementType*>(arg4),
static_cast<const ElementType*>(arg5),
static_cast<ElementType*>(out0),
static_cast<ElementType*>(out1),
static_cast<ElementType*>(out2),
arg2_shape);
}
}
}
}
......
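A minimal usage sketch of the type-erased wrapper added above, assuming float tensors. The include lines are omitted, and the NCHW shape, values, and buffer roles (taken from the builder comments earlier in this diff) are illustrative rather than drawn from the repository's tests:

// Sketch only: headers for ngraph::Shape and the CPU kernel are assumed to be included.
std::vector<float> gamma(4, 1.0f), beta(4, 0.0f);      // per-channel scale/shift
std::vector<float> input(2 * 4 * 5 * 5), delta(2 * 4 * 5 * 5);
std::vector<float> mean(4), variance(4);                // batch statistics
std::vector<float> d_input(input.size()), d_gamma(4), d_beta(4);

ngraph::Shape input_shape{2, 4, 5, 5};                  // NCHW

// Argument order mirrors the wrapper: eps, then gamma, beta, input, mean,
// variance, delta, then the three outputs, then the input shape.
ngraph::runtime::cpu::kernel::batch_norm_backprop<float>(1e-3,
                                                         gamma.data(),
                                                         beta.data(),
                                                         input.data(),
                                                         mean.data(),
                                                         variance.data(),
                                                         delta.data(),
                                                         d_input.data(),
                                                         d_gamma.data(),
                                                         d_beta.data(),
                                                         input_shape);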
@@ -2326,7 +2326,7 @@ namespace ngraph
}
else
{
    set_native_layouts(external_function, node);
}
}
......
@@ -188,6 +188,7 @@ max_trivial_5d_int32
floor_int32
any_trivial
any_2x2x3_eliminate_dim_0
backwards_batch_norm_training_3d
# unsupported op: `GenerateMask`
generate_mask
......
@@ -1659,7 +1659,7 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2)
ASSERT_TRUE(read_vector<float>(output) == expected);
}

NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training_4d)
{
const Shape input_shape{10, 4, 5, 5};
const Shape channel_shape{input_shape.at(1)};
@@ -1697,6 +1697,44 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training)
autodiff_numeric_compare<T>(backend.get(), make_graph, {input, gamma, beta}, .005, .005));
}
NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training_3d)
{
const Shape input_shape{10, 4, 5};
const Shape channel_shape{input_shape.at(1)};
const double eps = 1e-3;
// Need to keep the output elements for mean and variance from going out of scope
// and getting freed.
NodeVector goes;
auto make_graph = [&input_shape, &channel_shape, &eps, &goes] {
const element::Type& et = element::f32;
auto input = make_shared<op::Parameter>(et, input_shape);
auto gamma = make_shared<op::Parameter>(et, channel_shape);
auto beta = make_shared<op::Parameter>(et, channel_shape);
auto BN = make_shared<op::BatchNormTraining>(input, gamma, beta, eps);
auto normed_input = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 0));
auto mean = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 1));
auto variance = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 2));
goes.push_back(mean);
goes.push_back(variance);
// TODO autodiff testing with more than one result
auto f =
make_shared<Function>(ResultVector{normed_input}, ParameterVector{input, gamma, beta});
return f;
};
auto backend = runtime::Backend::create("${BACKEND_NAME}");
using T = float;
test::Uniform<T> rng(-5.0, 2.0);
auto input = rng.initialize(backend->create_tensor<T>(input_shape));
auto gamma = rng.initialize(backend->create_tensor<T>(channel_shape));
auto beta = rng.initialize(backend->create_tensor<T>(channel_shape));
EXPECT_TRUE(
autodiff_numeric_compare<T>(backend.get(), make_graph, {input, gamma, beta}, .005, .005));
}
NGRAPH_TEST(${BACKEND_NAME}, backwards_reverse_sequence_n3_c2_h3)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
......