Commit f0bc6c12 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Support non-mkldnn fallback for Batchnorm training bprop in CPU backend (#3688)

* Support non-mkldnn fallback for Batchnorm training bprop in CPU backend

* Skip unit test on PlaidML backend
parent 5a1de88e
...@@ -825,46 +825,67 @@ namespace ngraph ...@@ -825,46 +825,67 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNormTrainingBackprop) void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNormTrainingBackprop)
{ {
writer.block_begin(); writer.block_begin();
// define weights if (!mkldnn_utils::use_mkldnn_kernel(node))
writer << "std::vector<" << args[0].get_element_type().c_type_string() {
<< ">bn_weights(2*" << args[0].get_size() << ");\n"; const ngraph::op::BatchNormTrainingBackprop* batchnorm =
writer << "std::vector<" << args[0].get_element_type().c_type_string() static_cast<const ngraph::op::BatchNormTrainingBackprop*>(node);
<< ">bn_dweights(2*" << args[0].get_size() << ");\n";
writer << "memcpy(&bn_weights[0], " << args[0].get_name() << ", " writer << "reference::batch_norm_backprop(" << batchnorm->get_eps_value()
<< args[0].get_size() * args[0].get_element_type().size() << ");\n"; << ",\n";
writer << "memcpy(&bn_weights[0]+" << args[0].get_size() << ", " writer << " " << args[0].get_name() << ",\n";
<< args[1].get_name() << ", " writer << " " << args[1].get_name() << ",\n";
<< args[1].get_size() * args[1].get_element_type().size() << ");\n"; writer << " " << args[2].get_name() << ",\n";
writer << " " << args[3].get_name() << ",\n";
writer << " " << args[4].get_name() << ",\n";
writer << " " << args[5].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " " << out[1].get_name() << ",\n";
writer << " " << out[2].get_name() << ",\n";
writer << " {" << join(args[2].get_shape()) << "});\n";
}
else
{
// define weights
writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">bn_weights(2*" << args[0].get_size() << ");\n";
writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">bn_dweights(2*" << args[0].get_size() << ");\n";
size_t batchnorm_index; writer << "memcpy(&bn_weights[0], " << args[0].get_name() << ", "
std::vector<std::size_t> deps; << args[0].get_size() * args[0].get_element_type().size() << ");\n";
emit_build_primitives(external_function, node, writer, batchnorm_index, deps); writer << "memcpy(&bn_weights[0]+" << args[0].get_size() << ", "
<< args[1].get_name() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[0]) size_t batchnorm_index;
<< ", bn_weights.data());\n"; std::vector<std::size_t> deps;
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[1]) << ", " emit_build_primitives(external_function, node, writer, batchnorm_index, deps);
<< args[2].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[2]) << ", "
<< args[3].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[3]) << ", "
<< args[4].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[4]) << ", "
<< args[5].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[5]) << ", "
<< out[0].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[6])
<< ", bn_dweights.data());\n";
writer << "std::vector<size_t> deps{" << join(deps) << "};\n"; writer << "cg_ctx->set_memory_ptr(" << to_string(deps[0])
writer << "cg_ctx->mkldnn_invoke_primitive(" << to_string(batchnorm_index) << ", bn_weights.data());\n";
<< ", deps, OpType::BATCHNORMBACKPROP);\n"; writer << "cg_ctx->set_memory_ptr(" << to_string(deps[1]) << ", "
<< args[2].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[2]) << ", "
<< args[3].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[3]) << ", "
<< args[4].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[4]) << ", "
<< args[5].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[5]) << ", "
<< out[0].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[6])
<< ", bn_dweights.data());\n";
writer << "memcpy(" << out[1].get_name() << ", &bn_dweights[0], " writer << "std::vector<size_t> deps{" << join(deps) << "};\n";
<< args[0].get_size() * args[0].get_element_type().size() << ");\n"; writer << "cg_ctx->mkldnn_invoke_primitive(" << to_string(batchnorm_index)
writer << "memcpy(" << out[2].get_name() << ", &bn_dweights[0]+" << ", deps, OpType::BATCHNORMBACKPROP);\n";
<< args[0].get_size() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n"; writer << "memcpy(" << out[1].get_name() << ", &bn_dweights[0], "
<< args[0].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(" << out[2].get_name() << ", &bn_dweights[0]+"
<< args[0].get_size() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
}
writer.block_end(); writer.block_end();
} }
......
...@@ -66,6 +66,32 @@ namespace ngraph ...@@ -66,6 +66,32 @@ namespace ngraph
static_cast<ElementType*>(out0), static_cast<ElementType*>(out0),
arg2_shape); arg2_shape);
} }
template <typename ElementType>
void batch_norm_backprop(double eps,
const void* arg0,
const void* arg1,
const void* arg2,
const void* arg3,
const void* arg4,
const void* arg5,
void* out0,
void* out1,
void* out2,
const Shape& arg2_shape)
{
reference::batch_norm_backprop(eps,
static_cast<const ElementType*>(arg0),
static_cast<const ElementType*>(arg1),
static_cast<const ElementType*>(arg2),
static_cast<const ElementType*>(arg3),
static_cast<const ElementType*>(arg4),
static_cast<const ElementType*>(arg5),
static_cast<ElementType*>(out0),
static_cast<ElementType*>(out1),
static_cast<ElementType*>(out2),
arg2_shape);
}
} }
} }
} }
......
...@@ -2326,7 +2326,7 @@ namespace ngraph ...@@ -2326,7 +2326,7 @@ namespace ngraph
} }
else else
{ {
throw ngraph_error("Batchnorm Backprop only supported in MKLDNN for now"); set_native_layouts(external_function, node);
} }
} }
......
...@@ -188,6 +188,7 @@ max_trivial_5d_int32 ...@@ -188,6 +188,7 @@ max_trivial_5d_int32
floor_int32 floor_int32
any_trivial any_trivial
any_2x2x3_eliminate_dim_0 any_2x2x3_eliminate_dim_0
backwards_batch_norm_training_3d
# unsupported op: `GenerateMask` # unsupported op: `GenerateMask`
generate_mask generate_mask
......
...@@ -1659,7 +1659,7 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2) ...@@ -1659,7 +1659,7 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2)
ASSERT_TRUE(read_vector<float>(output) == expected); ASSERT_TRUE(read_vector<float>(output) == expected);
} }
NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training) NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training_4d)
{ {
const Shape input_shape{10, 4, 5, 5}; const Shape input_shape{10, 4, 5, 5};
const Shape channel_shape{input_shape.at(1)}; const Shape channel_shape{input_shape.at(1)};
...@@ -1697,6 +1697,44 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training) ...@@ -1697,6 +1697,44 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training)
autodiff_numeric_compare<T>(backend.get(), make_graph, {input, gamma, beta}, .005, .005)); autodiff_numeric_compare<T>(backend.get(), make_graph, {input, gamma, beta}, .005, .005));
} }
NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training_3d)
{
const Shape input_shape{10, 4, 5};
const Shape channel_shape{input_shape.at(1)};
const double eps = 1e-3;
// Need to keep the output elements for mean and variance from going out of scope
// and getting freed.
NodeVector goes;
auto make_graph = [&input_shape, &channel_shape, &eps, &goes] {
const element::Type& et = element::f32;
auto input = make_shared<op::Parameter>(et, input_shape);
auto gamma = make_shared<op::Parameter>(et, channel_shape);
auto beta = make_shared<op::Parameter>(et, channel_shape);
auto BN = make_shared<op::BatchNormTraining>(input, gamma, beta, eps);
auto normed_input = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 0));
auto mean = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 1));
auto variance = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 2));
goes.push_back(mean);
goes.push_back(variance);
// TODO autodiff testing with more than one result
auto f =
make_shared<Function>(ResultVector{normed_input}, ParameterVector{input, gamma, beta});
return f;
};
auto backend = runtime::Backend::create("${BACKEND_NAME}");
using T = float;
test::Uniform<T> rng(-5.0, 2.0);
auto input = rng.initialize(backend->create_tensor<T>(input_shape));
auto gamma = rng.initialize(backend->create_tensor<T>(channel_shape));
auto beta = rng.initialize(backend->create_tensor<T>(channel_shape));
EXPECT_TRUE(
autodiff_numeric_compare<T>(backend.get(), make_graph, {input, gamma, beta}, .005, .005));
}
NGRAPH_TEST(${BACKEND_NAME}, backwards_reverse_sequence_n3_c2_h3) NGRAPH_TEST(${BACKEND_NAME}, backwards_reverse_sequence_n3_c2_h3)
{ {
auto backend = runtime::Backend::create("${BACKEND_NAME}"); auto backend = runtime::Backend::create("${BACKEND_NAME}");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment