Commit da24b77c authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

hooking up batchnorm kernels to cpu codegen (#1001)

parent 67078856
...@@ -503,9 +503,38 @@ namespace ngraph ...@@ -503,9 +503,38 @@ namespace ngraph
{ {
if (!mkldnn_utils::use_mkldnn_kernel(node)) if (!mkldnn_utils::use_mkldnn_kernel(node))
{ {
throw ngraph_error("BatchNorm is only supported with 4-D MKLDNN kernel."); const ngraph::op::BatchNorm* batchnorm =
static_cast<const ngraph::op::BatchNorm*>(node);
if (batchnorm->get_training_flag() && args.size() == 3)
{
writer << "reference::batch_norm_three_outputs("
<< batchnorm->get_eps_value() << ",\n";
writer << " " << args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << args[2].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " " << out[1].get_name() << ",\n";
writer << " " << out[2].get_name() << ",\n";
writer << " {" << join(args[2].get_shape()) << "});\n";
}
else
{
writer << "reference::batch_norm_one_output(" << batchnorm->get_eps_value()
<< ",\n";
writer << " " << args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << args[2].get_name() << ",\n";
writer << " " << args[3].get_name() << ",\n";
writer << " " << args[4].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[2].get_shape()) << "});\n";
}
}
else
{
emitBatchNorm(external_function, writer, node, args, out, false);
} }
emitBatchNorm(external_function, writer, node, args, out, false);
} }
template <> template <>
......
...@@ -351,6 +351,7 @@ void runtime::cpu::CPU_ExternalFunction::compile() ...@@ -351,6 +351,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp" #include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/reference/and.hpp" #include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp" #include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
#include "ngraph/runtime/reference/broadcast.hpp" #include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/concat.hpp" #include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/convolution.hpp" #include "ngraph/runtime/reference/convolution.hpp"
......
...@@ -1172,7 +1172,7 @@ namespace ngraph ...@@ -1172,7 +1172,7 @@ namespace ngraph
} }
else else
{ {
throw ngraph_error("Batchnorm only supported in MKLDNN for now"); set_default_layouts(external_function, node);
} }
} }
......
batch_norm_one_output
batch_norm_three_outputs
one_hot_matrix_0 one_hot_matrix_0
one_hot_scalar_0_in_3 one_hot_scalar_0_in_3
one_hot_scalar_1_in_3 one_hot_scalar_1_in_3
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment