Commit f0bc6c12 authored by Jayaram Bobba, committed by Scott Cyphers

Support non-mkldnn fallback for Batchnorm training bprop in CPU backend (#3688)

* Support non-mkldnn fallback for Batchnorm training bprop in CPU backend

* Skip unit test on PlaidML backend
parent 5a1de88e
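
For context on what the new fallback has to compute: the reference kernel receives gamma, beta, the input, the batch mean, the batch variance, and the incoming delta (args[0] through args[5] in the emitter below) and produces the input, gamma, and beta gradients (out[0] through out[2]). The following is a minimal scalar sketch of that math for a float tensor flattened to [N, C, HW]; it illustrates the formulas only and is not nGraph's reference::batch_norm_backprop implementation — the function name and layout are assumptions. Note that beta does not enter the gradient formulas.

#include <cmath>
#include <cstddef>
#include <vector>

// Scalar sketch of batch-norm training backprop. Per channel c, with
// m = N*HW and xhat = (x - mean) / sqrt(var + eps):
//   dbeta  = sum(dy),  dgamma = sum(dy * xhat)
//   dx     = gamma * inv_std / m * (m * dy - dbeta - xhat * dgamma)
void batch_norm_backprop_sketch(double eps,
                                const std::vector<float>& gamma,    // [C]
                                const std::vector<float>& x,        // [N,C,HW]
                                const std::vector<float>& mean,     // [C]
                                const std::vector<float>& variance, // [C]
                                const std::vector<float>& dy,       // [N,C,HW]
                                std::vector<float>& dx,             // [N,C,HW]
                                std::vector<float>& dgamma,         // [C]
                                std::vector<float>& dbeta,          // [C]
                                std::size_t N, std::size_t C, std::size_t HW)
{
    for (std::size_t c = 0; c < C; ++c)
    {
        const double m = static_cast<double>(N * HW);
        const double inv_std = 1.0 / std::sqrt(variance[c] + eps);
        double sum_dy = 0, sum_dy_xhat = 0;
        // First pass: channel-wise reductions over dy and dy * xhat.
        for (std::size_t n = 0; n < N; ++n)
        {
            for (std::size_t i = 0; i < HW; ++i)
            {
                const std::size_t k = (n * C + c) * HW + i;
                const double xhat = (x[k] - mean[c]) * inv_std;
                sum_dy += dy[k];
                sum_dy_xhat += dy[k] * xhat;
            }
        }
        dbeta[c] = static_cast<float>(sum_dy);
        dgamma[c] = static_cast<float>(sum_dy_xhat);
        // Second pass: input gradient from the simplified closed form.
        for (std::size_t n = 0; n < N; ++n)
        {
            for (std::size_t i = 0; i < HW; ++i)
            {
                const std::size_t k = (n * C + c) * HW + i;
                const double xhat = (x[k] - mean[c]) * inv_std;
                dx[k] = static_cast<float>(gamma[c] * inv_std / m *
                                           (m * dy[k] - sum_dy - xhat * sum_dy_xhat));
            }
        }
    }
}
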
@@ -825,46 +825,67 @@ namespace ngraph
 void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNormTrainingBackprop)
 {
     writer.block_begin();
-    // define weights
-    writer << "std::vector<" << args[0].get_element_type().c_type_string()
-           << ">bn_weights(2*" << args[0].get_size() << ");\n";
-    writer << "std::vector<" << args[0].get_element_type().c_type_string()
-           << ">bn_dweights(2*" << args[0].get_size() << ");\n";
-    writer << "memcpy(&bn_weights[0], " << args[0].get_name() << ", "
-           << args[0].get_size() * args[0].get_element_type().size() << ");\n";
-    writer << "memcpy(&bn_weights[0]+" << args[0].get_size() << ", "
-           << args[1].get_name() << ", "
-           << args[1].get_size() * args[1].get_element_type().size() << ");\n";
-    size_t batchnorm_index;
-    std::vector<std::size_t> deps;
-    emit_build_primitives(external_function, node, writer, batchnorm_index, deps);
-    writer << "cg_ctx->set_memory_ptr(" << to_string(deps[0])
-           << ", bn_weights.data());\n";
-    writer << "cg_ctx->set_memory_ptr(" << to_string(deps[1]) << ", "
-           << args[2].get_name() << ");\n";
-    writer << "cg_ctx->set_memory_ptr(" << to_string(deps[2]) << ", "
-           << args[3].get_name() << ");\n";
-    writer << "cg_ctx->set_memory_ptr(" << to_string(deps[3]) << ", "
-           << args[4].get_name() << ");\n";
-    writer << "cg_ctx->set_memory_ptr(" << to_string(deps[4]) << ", "
-           << args[5].get_name() << ");\n";
-    writer << "cg_ctx->set_memory_ptr(" << to_string(deps[5]) << ", "
-           << out[0].get_name() << ");\n";
-    writer << "cg_ctx->set_memory_ptr(" << to_string(deps[6])
-           << ", bn_dweights.data());\n";
-    writer << "std::vector<size_t> deps{" << join(deps) << "};\n";
-    writer << "cg_ctx->mkldnn_invoke_primitive(" << to_string(batchnorm_index)
-           << ", deps, OpType::BATCHNORMBACKPROP);\n";
-    writer << "memcpy(" << out[1].get_name() << ", &bn_dweights[0], "
-           << args[0].get_size() * args[0].get_element_type().size() << ");\n";
-    writer << "memcpy(" << out[2].get_name() << ", &bn_dweights[0]+"
-           << args[0].get_size() << ", "
-           << args[1].get_size() * args[1].get_element_type().size() << ");\n";
+    if (!mkldnn_utils::use_mkldnn_kernel(node))
+    {
+        const ngraph::op::BatchNormTrainingBackprop* batchnorm =
+            static_cast<const ngraph::op::BatchNormTrainingBackprop*>(node);
+        writer << "reference::batch_norm_backprop(" << batchnorm->get_eps_value()
+               << ",\n";
+        writer << "            " << args[0].get_name() << ",\n";
+        writer << "            " << args[1].get_name() << ",\n";
+        writer << "            " << args[2].get_name() << ",\n";
+        writer << "            " << args[3].get_name() << ",\n";
+        writer << "            " << args[4].get_name() << ",\n";
+        writer << "            " << args[5].get_name() << ",\n";
+        writer << "            " << out[0].get_name() << ",\n";
+        writer << "            " << out[1].get_name() << ",\n";
+        writer << "            " << out[2].get_name() << ",\n";
+        writer << "            {" << join(args[2].get_shape()) << "});\n";
+    }
+    else
+    {
+        // define weights
+        writer << "std::vector<" << args[0].get_element_type().c_type_string()
+               << ">bn_weights(2*" << args[0].get_size() << ");\n";
+        writer << "std::vector<" << args[0].get_element_type().c_type_string()
+               << ">bn_dweights(2*" << args[0].get_size() << ");\n";
+        writer << "memcpy(&bn_weights[0], " << args[0].get_name() << ", "
+               << args[0].get_size() * args[0].get_element_type().size() << ");\n";
+        writer << "memcpy(&bn_weights[0]+" << args[0].get_size() << ", "
+               << args[1].get_name() << ", "
+               << args[1].get_size() * args[1].get_element_type().size() << ");\n";
+        size_t batchnorm_index;
+        std::vector<std::size_t> deps;
+        emit_build_primitives(external_function, node, writer, batchnorm_index, deps);
+        writer << "cg_ctx->set_memory_ptr(" << to_string(deps[0])
+               << ", bn_weights.data());\n";
+        writer << "cg_ctx->set_memory_ptr(" << to_string(deps[1]) << ", "
+               << args[2].get_name() << ");\n";
+        writer << "cg_ctx->set_memory_ptr(" << to_string(deps[2]) << ", "
+               << args[3].get_name() << ");\n";
+        writer << "cg_ctx->set_memory_ptr(" << to_string(deps[3]) << ", "
+               << args[4].get_name() << ");\n";
+        writer << "cg_ctx->set_memory_ptr(" << to_string(deps[4]) << ", "
+               << args[5].get_name() << ");\n";
+        writer << "cg_ctx->set_memory_ptr(" << to_string(deps[5]) << ", "
+               << out[0].get_name() << ");\n";
+        writer << "cg_ctx->set_memory_ptr(" << to_string(deps[6])
+               << ", bn_dweights.data());\n";
+        writer << "std::vector<size_t> deps{" << join(deps) << "};\n";
+        writer << "cg_ctx->mkldnn_invoke_primitive(" << to_string(batchnorm_index)
+               << ", deps, OpType::BATCHNORMBACKPROP);\n";
+        writer << "memcpy(" << out[1].get_name() << ", &bn_dweights[0], "
+               << args[0].get_size() * args[0].get_element_type().size() << ");\n";
+        writer << "memcpy(" << out[2].get_name() << ", &bn_dweights[0]+"
+               << args[0].get_size() << ", "
+               << args[1].get_size() * args[1].get_element_type().size() << ");\n";
+    }
     writer.block_end();
 }
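
To make the emitter above concrete: it does not execute batch norm itself, it writes C++ source into the generated function. For the MKL-DNN branch, the emitted source looks roughly like the block below for a float op with four channels; the buffer names and literal indices are placeholders standing in for the values the emitter interpolates (args[i].get_name(), to_string(deps[i]), and batchnorm_index), so this is an illustration rather than verbatim output.

// Illustrative emitted source for the MKL-DNN branch (4 channels assumed).
std::vector<float> bn_weights(2 * 4);   // packed {gamma, beta}
std::vector<float> bn_dweights(2 * 4);  // packed {dgamma, dbeta}
memcpy(&bn_weights[0], gamma, 4 * sizeof(float));
memcpy(&bn_weights[0] + 4, beta, 4 * sizeof(float));
cg_ctx->set_memory_ptr(0, bn_weights.data()); // deps[0]: fused weights
cg_ctx->set_memory_ptr(1, input);             // deps[1]: args[2]
cg_ctx->set_memory_ptr(2, mean);              // deps[2]: args[3]
cg_ctx->set_memory_ptr(3, variance);          // deps[3]: args[4]
cg_ctx->set_memory_ptr(4, delta);             // deps[4]: args[5]
cg_ctx->set_memory_ptr(5, d_input);           // deps[5]: out[0]
cg_ctx->set_memory_ptr(6, bn_dweights.data()); // deps[6]: fused weight grads
std::vector<size_t> deps{0, 1, 2, 3, 4, 5, 6};
cg_ctx->mkldnn_invoke_primitive(7, deps, OpType::BATCHNORMBACKPROP);
memcpy(d_gamma, &bn_dweights[0], 4 * sizeof(float));     // unpack out[1]
memcpy(d_beta, &bn_dweights[0] + 4, 4 * sizeof(float));  // unpack out[2]
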
@@ -66,6 +66,32 @@ namespace ngraph
                     static_cast<ElementType*>(out0),
                     arg2_shape);
             }
+
+            template <typename ElementType>
+            void batch_norm_backprop(double eps,
+                                     const void* arg0,
+                                     const void* arg1,
+                                     const void* arg2,
+                                     const void* arg3,
+                                     const void* arg4,
+                                     const void* arg5,
+                                     void* out0,
+                                     void* out1,
+                                     void* out2,
+                                     const Shape& arg2_shape)
+            {
+                reference::batch_norm_backprop(eps,
+                                               static_cast<const ElementType*>(arg0),
+                                               static_cast<const ElementType*>(arg1),
+                                               static_cast<const ElementType*>(arg2),
+                                               static_cast<const ElementType*>(arg3),
+                                               static_cast<const ElementType*>(arg4),
+                                               static_cast<const ElementType*>(arg5),
+                                               static_cast<ElementType*>(out0),
+                                               static_cast<ElementType*>(out1),
+                                               static_cast<ElementType*>(out2),
+                                               arg2_shape);
+            }
         }
     }
 }
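
The new batch_norm_backprop entry point mirrors the type-erased pattern of the existing kernels in this header: the public signature takes void* buffers, and the template parameter supplies the concrete element type, so generated code can forward raw pointers without knowing the tensor type at the call site. A self-contained sketch of that pattern, simplified and using assumed names rather than the repo's code:

#include <cstddef>
#include <iostream>

// Type-erased kernel entry point: void* in the signature, the template
// argument fixes the real element type inside.
template <typename ElementType>
void scale(const void* in, void* out, std::size_t count, double factor)
{
    const ElementType* src = static_cast<const ElementType*>(in);
    ElementType* dst = static_cast<ElementType*>(out);
    for (std::size_t i = 0; i < count; ++i)
    {
        dst[i] = static_cast<ElementType>(src[i] * factor);
    }
}

int main()
{
    float in[3] = {1.0f, 2.0f, 3.0f};
    float out[3];
    scale<float>(in, out, 3, 0.5); // the caller selects the instantiation
    std::cout << out[0] << " " << out[1] << " " << out[2] << "\n";
    return 0;
}
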
@@ -2326,7 +2326,7 @@ namespace ngraph
                 }
                 else
                 {
-                    throw ngraph_error("Batchnorm Backprop only supported in MKLDNN for now");
+                    set_native_layouts(external_function, node);
                 }
             }
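
With this change, a BatchNormTrainingBackprop node that is not assigned to MKL-DNN receives default row-major layouts instead of aborting compilation. Schematically, the layout handler now dispatches as sketched below; the MKL-DNN branch body is elided, and only set_native_layouts is taken from the diff:

// Sketch of the layout-pass dispatch after this change (structure assumed).
if (mkldnn_utils::use_mkldnn_kernel(node.get()))
{
    // ... assign MKL-DNN memory descriptors to inputs/outputs (elided) ...
}
else
{
    set_native_layouts(external_function, node); // row-major fallback
}
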
@@ -188,6 +188,7 @@ max_trivial_5d_int32
 floor_int32
 any_trivial
 any_2x2x3_eliminate_dim_0
+backwards_batch_norm_training_3d
 
 # unsupported op: `GenerateMask`
 generate_mask
@@ -1659,7 +1659,7 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2)
     ASSERT_TRUE(read_vector<float>(output) == expected);
 }
 
-NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training)
+NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training_4d)
 {
     const Shape input_shape{10, 4, 5, 5};
     const Shape channel_shape{input_shape.at(1)};
@@ -1697,6 +1697,44 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training)
         autodiff_numeric_compare<T>(backend.get(), make_graph, {input, gamma, beta}, .005, .005));
 }
 
+NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training_3d)
+{
+    const Shape input_shape{10, 4, 5};
+    const Shape channel_shape{input_shape.at(1)};
+    const double eps = 1e-3;
+
+    // Need to keep the output elements for mean and variance from going out of scope
+    // and getting freed.
+    NodeVector goes;
+
+    auto make_graph = [&input_shape, &channel_shape, &eps, &goes] {
+        const element::Type& et = element::f32;
+        auto input = make_shared<op::Parameter>(et, input_shape);
+        auto gamma = make_shared<op::Parameter>(et, channel_shape);
+        auto beta = make_shared<op::Parameter>(et, channel_shape);
+        auto BN = make_shared<op::BatchNormTraining>(input, gamma, beta, eps);
+        auto normed_input = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 0));
+        auto mean = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 1));
+        auto variance = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 2));
+        goes.push_back(mean);
+        goes.push_back(variance);
+        // TODO autodiff testing with more than one result
+        auto f =
+            make_shared<Function>(ResultVector{normed_input}, ParameterVector{input, gamma, beta});
+        return f;
+    };
+
+    auto backend = runtime::Backend::create("${BACKEND_NAME}");
+    using T = float;
+    test::Uniform<T> rng(-5.0, 2.0);
+    auto input = rng.initialize(backend->create_tensor<T>(input_shape));
+    auto gamma = rng.initialize(backend->create_tensor<T>(channel_shape));
+    auto beta = rng.initialize(backend->create_tensor<T>(channel_shape));
+
+    EXPECT_TRUE(
+        autodiff_numeric_compare<T>(backend.get(), make_graph, {input, gamma, beta}, .005, .005));
+}
+
 NGRAPH_TEST(${BACKEND_NAME}, backwards_reverse_sequence_n3_c2_h3)
 {
     auto backend = runtime::Backend::create("${BACKEND_NAME}");
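
Both tests lean on autodiff_numeric_compare, which runs the backend's autodiff gradient and checks it against a numerically estimated gradient within the given relative/absolute tolerances (.005 here). The following standalone sketch shows the central-difference idea behind such a check; it is an assumption-level illustration, not the test harness's implementation:

#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Compare an analytic gradient against a central-difference estimate of f
// at x, element by element, within rtol/atol.
bool numeric_gradient_matches(const std::function<double(std::vector<double>&)>& f,
                              std::vector<double> x,
                              const std::vector<double>& analytic_grad,
                              double rtol, double atol)
{
    const double h = 1e-3;
    for (std::size_t i = 0; i < x.size(); ++i)
    {
        const double saved = x[i];
        x[i] = saved + h;
        const double fp = f(x);
        x[i] = saved - h;
        const double fm = f(x);
        x[i] = saved; // restore before the next coordinate
        const double numeric = (fp - fm) / (2 * h); // central difference
        if (std::fabs(numeric - analytic_grad[i]) >
            atol + rtol * std::fabs(analytic_grad[i]))
        {
            return false;
        }
    }
    return true;
}
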