Commit 5190ee90 authored by Amy Zhuang, committed by Robert Kimball

Modify TBB graph nodes creation and deletion (#1226)

* Modify TBB graph nodes creation and deletion
* Add a graph* member to CPURuntimeContext.
* Create nodes the first time a function is called; all subsequent calls only execute the computation.
* Delete nodes when cleanup_runtime_context is called.

* Add TBB global_control* and task_scheduler_init* members to CPURuntimeContext.

* Remove one comment.
Do not write two TBB header files and one #define to generated c++ source code.

* Move TBB header file and #define before other header files in generated c++ source code.

* Move one comment to the top in generated c++ source code.
parent 36473a8a
......@@ -120,6 +120,15 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context()
const auto& mkldnn_emitter = m_external_function->get_mkldnn_emitter();
ctx->mkldnn_primitives = mkldnn_emitter->get_mkldnn_primitives().data();
ctx->mkldnn_workspaces = mkldnn_emitter->get_mkldnn_workspaces().data();
if (std::getenv("NGRAPH_CPU_USE_TBB") != nullptr)
{
ctx->G = new tbb::flow::graph;
const auto envParallelism = std::getenv("NGRAPH_INTER_OP_PARALLELISM");
const auto parallelism = envParallelism == nullptr ? 1 : std::atoi(envParallelism);
ctx->c = new tbb::global_control(tbb::global_control::max_allowed_parallelism, parallelism);
ctx->init = new tbb::task_scheduler_init(parallelism);
}
}
void runtime::cpu::CPU_CallFrame::cleanup_runtime_context()
......@@ -130,5 +139,22 @@ void runtime::cpu::CPU_CallFrame::cleanup_runtime_context()
{
delete buffer;
}
if (std::getenv("NGRAPH_CPU_USE_TBB") != nullptr)
{
// delete graph G and nodes in G
ctx->G->wait_for_all();
std::vector<tbb::flow::graph_node*> to_be_deleted;
for (auto it = ctx->G->begin(); it != ctx->G->end(); it++)
{
to_be_deleted.push_back(&(*it));
}
delete ctx->G;
for (auto node : to_be_deleted)
{
delete node;
}
delete ctx->c;
delete ctx->init;
}
delete ctx;
}
......@@ -385,8 +385,15 @@ void runtime::cpu::CPU_ExternalFunction::compile()
codegen::CodeWriter writer;
writer << "//Generated by the nGraph CPU backend\n";
if (m_use_tbb)
{
writer << "#define __TBB_PREVIEW_LIGHTWEIGHT_POLICY 1\n";
writer << "#include <tbb/flow_graph.h>";
}
writer +=
R"(// Generated by the nGraph CPU backend
R"(
#include <cmath>
#include "ngraph/except.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
......@@ -433,11 +440,6 @@ using namespace ngraph::runtime;
writer << "#include <mpi.h>\n\n";
#endif
if (m_use_tbb)
{
writer << "#include <tbb/flow_graph.h>\n";
}
string pch_header_source = writer.get_code();
// The "dso_handle" symbol is required by __cxa_atexit()
......@@ -591,12 +593,6 @@ using namespace ngraph::runtime;
writer << "{\n";
writer.indent++;
if (m_use_tbb)
{
// TODO: This should be static but we don't codegen statics correctly yet
writer << "tbb::flow::graph G;\n\n";
}
// Execution tracing support
if (runtime::cpu::IsTracingEnabled() && current_function->get_name() == m_function_name)
{
......@@ -625,6 +621,18 @@ using namespace ngraph::runtime;
writer << "bool* t_en = (bool*)" << current_function->get_name() << "_t_en;\n";
if (m_use_tbb)
{
writer << "\n";
writer << "if (" << current_function->get_name() << "_init) {\n";
writer.indent++;
writer << "tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight>* "
"flowgraph_node_start"
<< " = new tbb::flow::continue_node<tbb::flow::continue_msg, "
"tbb::flow::lightweight>"
"(*(ctx->G), [&](const tbb::flow::continue_msg &msg)\n{});\n";
}
// Add inputs to the variable name map
size_t arg_index = 0;
for (shared_ptr<ngraph::op::Parameter> param : current_function->get_parameters())
......@@ -705,10 +713,13 @@ using namespace ngraph::runtime;
}
if (m_use_tbb)
{
writer << "tbb::flow::continue_node<tbb::flow::continue_msg> "
writer << "tbb::flow::continue_node<tbb::flow::continue_msg, "
"tbb::flow::lightweight>* "
"flowgraph_node_"
<< node->get_name()
<< "(G, [&](const tbb::flow::continue_msg &msg)\n{\n";
<< " = new tbb::flow::continue_node<tbb::flow::continue_msg, "
"tbb::flow::lightweight>"
"(*(ctx->G), [&](const tbb::flow::continue_msg &msg)\n{\n";
writer.indent++;
}
if (runtime::cpu::IsTracingEnabled() &&
......@@ -855,8 +866,8 @@ using namespace ngraph::runtime;
if (!arg->is_parameter() && !arg->is_constant())
{
is_head = false;
writer << "tbb::flow::make_edge(flowgraph_node_" << arg->get_name()
<< ", flowgraph_node_" << n->get_name() << ");\n";
writer << "tbb::flow::make_edge(*flowgraph_node_" << arg->get_name()
<< ", *flowgraph_node_" << n->get_name() << ");\n";
}
}
if (is_head)
......@@ -867,17 +878,24 @@ using namespace ngraph::runtime;
});
writer << "\n";
// Execute the flow graph
if (!dependence_graph_heads.empty())
{
for (Node* n : dependence_graph_heads)
{
writer << "flowgraph_node_" << n->get_name()
<< ".try_put(tbb::flow::continue_msg());\n";
writer << "tbb::flow::make_edge(*flowgraph_node_start"
<< ", *flowgraph_node_" << n->get_name() << ");\n";
}
writer << "try { G.wait_for_all(); } catch(...) { throw; }\n";
}
writer.indent--;
writer << "}\n";
// Execute the flow graph
writer << "auto start = &(*(ctx->G->begin()));\n";
writer << "((tbb::flow::continue_node<tbb::flow::continue_msg, "
"tbb::flow::lightweight>*)(&(*(ctx->G->begin()))))"
<< "->try_put(tbb::flow::continue_msg());\n";
writer << "try { ctx->G->wait_for_all(); } catch(...) { throw; }\n";
}
writer << current_function->get_name() << "_init = false;\n";
......
......@@ -19,6 +19,11 @@
#include <chrono>
#include <cstdint>
#define TBB_PREVIEW_GLOBAL_CONTROL 1
#include <tbb/flow_graph.h>
#include <tbb/global_control.h>
#include <tbb/task_scheduler_init.h>
namespace mkldnn
{
class primitive;
......@@ -50,6 +55,9 @@ namespace ngraph
mkldnn::primitive* const* mkldnn_primitives;
std::vector<AlignedBuffer*> memory_buffers;
char* const* mkldnn_workspaces;
tbb::flow::graph* G;
tbb::global_control* c;
tbb::task_scheduler_init* init;
};
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment