Commit da24b77c authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

hooking up batchnorm kernels to cpu codegen (#1001)

parent 67078856
......@@ -503,9 +503,38 @@ namespace ngraph
{
if (!mkldnn_utils::use_mkldnn_kernel(node))
{
throw ngraph_error("BatchNorm is only supported with 4-D MKLDNN kernel.");
const ngraph::op::BatchNorm* batchnorm =
static_cast<const ngraph::op::BatchNorm*>(node);
if (batchnorm->get_training_flag() && args.size() == 3)
{
writer << "reference::batch_norm_three_outputs("
<< batchnorm->get_eps_value() << ",\n";
writer << " " << args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << args[2].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " " << out[1].get_name() << ",\n";
writer << " " << out[2].get_name() << ",\n";
writer << " {" << join(args[2].get_shape()) << "});\n";
}
else
{
writer << "reference::batch_norm_one_output(" << batchnorm->get_eps_value()
<< ",\n";
writer << " " << args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << args[2].get_name() << ",\n";
writer << " " << args[3].get_name() << ",\n";
writer << " " << args[4].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[2].get_shape()) << "});\n";
}
}
else
{
emitBatchNorm(external_function, writer, node, args, out, false);
}
emitBatchNorm(external_function, writer, node, args, out, false);
}
template <>
......
......@@ -351,6 +351,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/convolution.hpp"
......
......@@ -1172,7 +1172,7 @@ namespace ngraph
}
else
{
throw ngraph_error("Batchnorm only supported in MKLDNN for now");
set_default_layouts(external_function, node);
}
}
......
batch_norm_one_output
batch_norm_three_outputs
one_hot_matrix_0
one_hot_scalar_0_in_3
one_hot_scalar_1_in_3
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment