Commit a5b75e36 authored by Jayaram Bobba's avatar Jayaram Bobba

Conditionally emit mkldnn headers

parent 924ec039
...@@ -90,6 +90,8 @@ static const string& get_mkldnn_data_type(const string& type) ...@@ -90,6 +90,8 @@ static const string& get_mkldnn_data_type(const string& type)
void runtime::cpu::CPU_Emitter::EmitMKLDNNPreamble(codegen::CodeWriter& writer) void runtime::cpu::CPU_Emitter::EmitMKLDNNPreamble(codegen::CodeWriter& writer)
{ {
writer << "// MKLDNN Preamble\n";
writer << "#include <mkldnn.hpp>;\n";
writer << "using namespace mkldnn;\n\n"; writer << "using namespace mkldnn;\n\n";
} }
......
...@@ -230,15 +230,25 @@ void runtime::cpu::CPU_ExternalFunction::compile() ...@@ -230,15 +230,25 @@ void runtime::cpu::CPU_ExternalFunction::compile()
codegen::CodeWriter writer; codegen::CodeWriter writer;
bool include_mkldnn_headers = false;
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{
for (shared_ptr<Node> node : current_function->get_ordered_ops())
{
if (dynamic_cast<op::Convolution*>(node.get()) ||
dynamic_cast<op::AvgPool*>(node.get()) || dynamic_cast<op::MaxPool*>(node.get()))
{
include_mkldnn_headers = true;
}
}
}
writer += writer +=
R"(// Generated by the NGraph CPU backend R"(// Generated by the NGraph CPU backend
#include <cmath> #include <cmath>
#include <tbb/flow_graph.h>
#include <Eigen/Dense> #include <Eigen/Dense>
#include <mkldnn.hpp>
#include "ngraph/runtime/aligned_buffer.hpp" #include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/cpu/cpu_eigen_utils.hpp" #include "ngraph/runtime/cpu/cpu_eigen_utils.hpp"
#include "ngraph/runtime/cpu/cpu_kernels.hpp" #include "ngraph/runtime/cpu/cpu_kernels.hpp"
...@@ -264,6 +274,17 @@ using namespace ngraph::runtime::cpu::eigen; ...@@ -264,6 +274,17 @@ using namespace ngraph::runtime::cpu::eigen;
using namespace ngraph::runtime; using namespace ngraph::runtime;
)"; )";
if (m_use_tbb)
{
writer << "#include <tbb/flow_graph.h>\n";
}
if (include_mkldnn_headers)
{
runtime::cpu::CPU_Emitter::EmitMKLDNNPreamble(writer);
}
string pch_header_source = writer.get_code(); string pch_header_source = writer.get_code();
// The "dso_handle" symbol is required by __cxa_atexit() // The "dso_handle" symbol is required by __cxa_atexit()
...@@ -364,8 +385,6 @@ using namespace ngraph::runtime; ...@@ -364,8 +385,6 @@ using namespace ngraph::runtime;
} }
} }
runtime::cpu::CPU_Emitter::EmitMKLDNNPreamble(writer);
writer << "// Declare all functions\n"; writer << "// Declare all functions\n";
for (shared_ptr<Function> f : pass_manager.get_state().get_functions()) for (shared_ptr<Function> f : pass_manager.get_state().get_functions())
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment