Commit 5190ee90 authored by Amy Zhuang's avatar Amy Zhuang Committed by Robert Kimball

Modify TBB graph nodes creation and deletion (#1226)

* Modify TBB graph nodes creation and deletion
* Add a graph* member to CPURuntimeContext.
* Create nodes the first time a function is called, all the following calls only exectue the computation.
* Delete nodes when cleanup_runtime_context is called.

* Add TBB global_control* and task_scheduler_init* members to CPURuntimeContext.

* Remove one comment.
Do not write two TBB header files and one #define to generated c++ source code.

* Move TBB header file and #define before other header files in generated c++ source code.

* Move one comment to the top in generated c++ source code.
parent 36473a8a
...@@ -120,6 +120,15 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context() ...@@ -120,6 +120,15 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context()
const auto& mkldnn_emitter = m_external_function->get_mkldnn_emitter(); const auto& mkldnn_emitter = m_external_function->get_mkldnn_emitter();
ctx->mkldnn_primitives = mkldnn_emitter->get_mkldnn_primitives().data(); ctx->mkldnn_primitives = mkldnn_emitter->get_mkldnn_primitives().data();
ctx->mkldnn_workspaces = mkldnn_emitter->get_mkldnn_workspaces().data(); ctx->mkldnn_workspaces = mkldnn_emitter->get_mkldnn_workspaces().data();
if (std::getenv("NGRAPH_CPU_USE_TBB") != nullptr)
{
ctx->G = new tbb::flow::graph;
const auto envParallelism = std::getenv("NGRAPH_INTER_OP_PARALLELISM");
const auto parallelism = envParallelism == nullptr ? 1 : std::atoi(envParallelism);
ctx->c = new tbb::global_control(tbb::global_control::max_allowed_parallelism, parallelism);
ctx->init = new tbb::task_scheduler_init(parallelism);
}
} }
void runtime::cpu::CPU_CallFrame::cleanup_runtime_context() void runtime::cpu::CPU_CallFrame::cleanup_runtime_context()
...@@ -130,5 +139,22 @@ void runtime::cpu::CPU_CallFrame::cleanup_runtime_context() ...@@ -130,5 +139,22 @@ void runtime::cpu::CPU_CallFrame::cleanup_runtime_context()
{ {
delete buffer; delete buffer;
} }
if (std::getenv("NGRAPH_CPU_USE_TBB") != nullptr)
{
// delete graph G and nodes in G
ctx->G->wait_for_all();
std::vector<tbb::flow::graph_node*> to_be_deleted;
for (auto it = ctx->G->begin(); it != ctx->G->end(); it++)
{
to_be_deleted.push_back(&(*it));
}
delete ctx->G;
for (auto node : to_be_deleted)
{
delete node;
}
delete ctx->c;
delete ctx->init;
}
delete ctx; delete ctx;
} }
...@@ -385,8 +385,15 @@ void runtime::cpu::CPU_ExternalFunction::compile() ...@@ -385,8 +385,15 @@ void runtime::cpu::CPU_ExternalFunction::compile()
codegen::CodeWriter writer; codegen::CodeWriter writer;
writer << "//Generated by the nGraph CPU backend\n";
if (m_use_tbb)
{
writer << "#define __TBB_PREVIEW_LIGHTWEIGHT_POLICY 1\n";
writer << "#include <tbb/flow_graph.h>";
}
writer += writer +=
R"(// Generated by the nGraph CPU backend R"(
#include <cmath> #include <cmath>
#include "ngraph/except.hpp" #include "ngraph/except.hpp"
#include "ngraph/runtime/aligned_buffer.hpp" #include "ngraph/runtime/aligned_buffer.hpp"
...@@ -433,11 +440,6 @@ using namespace ngraph::runtime; ...@@ -433,11 +440,6 @@ using namespace ngraph::runtime;
writer << "#include <mpi.h>\n\n"; writer << "#include <mpi.h>\n\n";
#endif #endif
if (m_use_tbb)
{
writer << "#include <tbb/flow_graph.h>\n";
}
string pch_header_source = writer.get_code(); string pch_header_source = writer.get_code();
// The "dso_handle" symbol is required by __cxa_atexit() // The "dso_handle" symbol is required by __cxa_atexit()
...@@ -591,12 +593,6 @@ using namespace ngraph::runtime; ...@@ -591,12 +593,6 @@ using namespace ngraph::runtime;
writer << "{\n"; writer << "{\n";
writer.indent++; writer.indent++;
if (m_use_tbb)
{
// TODO: This should be static but we don't codegen statics correctly yet
writer << "tbb::flow::graph G;\n\n";
}
// Execution tracing support // Execution tracing support
if (runtime::cpu::IsTracingEnabled() && current_function->get_name() == m_function_name) if (runtime::cpu::IsTracingEnabled() && current_function->get_name() == m_function_name)
{ {
...@@ -625,6 +621,18 @@ using namespace ngraph::runtime; ...@@ -625,6 +621,18 @@ using namespace ngraph::runtime;
writer << "bool* t_en = (bool*)" << current_function->get_name() << "_t_en;\n"; writer << "bool* t_en = (bool*)" << current_function->get_name() << "_t_en;\n";
if (m_use_tbb)
{
writer << "\n";
writer << "if (" << current_function->get_name() << "_init) {\n";
writer.indent++;
writer << "tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight>* "
"flowgraph_node_start"
<< " = new tbb::flow::continue_node<tbb::flow::continue_msg, "
"tbb::flow::lightweight>"
"(*(ctx->G), [&](const tbb::flow::continue_msg &msg)\n{});\n";
}
// Add inputs to the variable name map // Add inputs to the variable name map
size_t arg_index = 0; size_t arg_index = 0;
for (shared_ptr<ngraph::op::Parameter> param : current_function->get_parameters()) for (shared_ptr<ngraph::op::Parameter> param : current_function->get_parameters())
...@@ -705,10 +713,13 @@ using namespace ngraph::runtime; ...@@ -705,10 +713,13 @@ using namespace ngraph::runtime;
} }
if (m_use_tbb) if (m_use_tbb)
{ {
writer << "tbb::flow::continue_node<tbb::flow::continue_msg> " writer << "tbb::flow::continue_node<tbb::flow::continue_msg, "
"tbb::flow::lightweight>* "
"flowgraph_node_" "flowgraph_node_"
<< node->get_name() << node->get_name()
<< "(G, [&](const tbb::flow::continue_msg &msg)\n{\n"; << " = new tbb::flow::continue_node<tbb::flow::continue_msg, "
"tbb::flow::lightweight>"
"(*(ctx->G), [&](const tbb::flow::continue_msg &msg)\n{\n";
writer.indent++; writer.indent++;
} }
if (runtime::cpu::IsTracingEnabled() && if (runtime::cpu::IsTracingEnabled() &&
...@@ -855,8 +866,8 @@ using namespace ngraph::runtime; ...@@ -855,8 +866,8 @@ using namespace ngraph::runtime;
if (!arg->is_parameter() && !arg->is_constant()) if (!arg->is_parameter() && !arg->is_constant())
{ {
is_head = false; is_head = false;
writer << "tbb::flow::make_edge(flowgraph_node_" << arg->get_name() writer << "tbb::flow::make_edge(*flowgraph_node_" << arg->get_name()
<< ", flowgraph_node_" << n->get_name() << ");\n"; << ", *flowgraph_node_" << n->get_name() << ");\n";
} }
} }
if (is_head) if (is_head)
...@@ -867,17 +878,24 @@ using namespace ngraph::runtime; ...@@ -867,17 +878,24 @@ using namespace ngraph::runtime;
}); });
writer << "\n"; writer << "\n";
// Execute the flow graph
if (!dependence_graph_heads.empty()) if (!dependence_graph_heads.empty())
{ {
for (Node* n : dependence_graph_heads) for (Node* n : dependence_graph_heads)
{ {
writer << "flowgraph_node_" << n->get_name() writer << "tbb::flow::make_edge(*flowgraph_node_start"
<< ".try_put(tbb::flow::continue_msg());\n"; << ", *flowgraph_node_" << n->get_name() << ");\n";
} }
writer << "try { G.wait_for_all(); } catch(...) { throw; }\n";
} }
writer.indent--;
writer << "}\n";
// Execute the flow graph
writer << "auto start = &(*(ctx->G->begin()));\n";
writer << "((tbb::flow::continue_node<tbb::flow::continue_msg, "
"tbb::flow::lightweight>*)(&(*(ctx->G->begin()))))"
<< "->try_put(tbb::flow::continue_msg());\n";
writer << "try { ctx->G->wait_for_all(); } catch(...) { throw; }\n";
} }
writer << current_function->get_name() << "_init = false;\n"; writer << current_function->get_name() << "_init = false;\n";
......
...@@ -19,6 +19,11 @@ ...@@ -19,6 +19,11 @@
#include <chrono> #include <chrono>
#include <cstdint> #include <cstdint>
#define TBB_PREVIEW_GLOBAL_CONTROL 1
#include <tbb/flow_graph.h>
#include <tbb/global_control.h>
#include <tbb/task_scheduler_init.h>
namespace mkldnn namespace mkldnn
{ {
class primitive; class primitive;
...@@ -50,6 +55,9 @@ namespace ngraph ...@@ -50,6 +55,9 @@ namespace ngraph
mkldnn::primitive* const* mkldnn_primitives; mkldnn::primitive* const* mkldnn_primitives;
std::vector<AlignedBuffer*> memory_buffers; std::vector<AlignedBuffer*> memory_buffers;
char* const* mkldnn_workspaces; char* const* mkldnn_workspaces;
tbb::flow::graph* G;
tbb::global_control* c;
tbb::task_scheduler_init* init;
}; };
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment