Commit 1c2e5b7a authored by Amy Zhuang, committed by Scott Cyphers

Add TBB flow graphs to DEX. (#1247)

* Add TBB flow graphs to DEX.

* Make edges from dummy start node to head nodes when traversing nodes.

* Use static_cast to cast TBB graph node.
Undefine __TBB_PREVIEW_LIGHTWEIGHT_POLICY.

* Code formatting.

* Remove clang -Wreserved-id-macro warning.
parent 8ad38f2e
...@@ -23,6 +23,16 @@ ...@@ -23,6 +23,16 @@
#include <typeinfo> #include <typeinfo>
#include <unordered_map> #include <unordered_map>
// Kill clang diagnostics bug
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreserved-id-macro"
#undef __TBB_PREVIEW_LIGHTWEIGHT_POLICY
#define __TBB_PREVIEW_LIGHTWEIGHT_POLICY 1
#pragma clang diagnostic pop
#include <tbb/flow_graph.h>
#include "ngraph/codegen/code_writer.hpp" #include "ngraph/codegen/code_writer.hpp"
#include "ngraph/codegen/compiler.hpp" #include "ngraph/codegen/compiler.hpp"
#include "ngraph/codegen/execution_engine.hpp" #include "ngraph/codegen/execution_engine.hpp"
...@@ -388,6 +398,7 @@ void runtime::cpu::CPU_ExternalFunction::compile() ...@@ -388,6 +398,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
writer << "//Generated by the nGraph CPU backend\n"; writer << "//Generated by the nGraph CPU backend\n";
if (m_use_tbb) if (m_use_tbb)
{ {
writer << "#undef __TBB_PREVIEW_LIGHTWEIGHT_POLICY \n";
writer << "#define __TBB_PREVIEW_LIGHTWEIGHT_POLICY 1\n"; writer << "#define __TBB_PREVIEW_LIGHTWEIGHT_POLICY 1\n";
writer << "#include <tbb/flow_graph.h>"; writer << "#include <tbb/flow_graph.h>";
} }
...@@ -854,10 +865,8 @@ using namespace ngraph::runtime; ...@@ -854,10 +865,8 @@ using namespace ngraph::runtime;
{ {
writer << "\n"; writer << "\n";
// Build the flow graph // Build the flow graph
vector<Node*> dependence_graph_heads;
traverse_nodes( traverse_nodes(current_function, [&writer](shared_ptr<Node> n) {
current_function, [&writer, &dependence_graph_heads](shared_ptr<Node> n) {
if (!n->is_parameter() && !n->is_constant()) if (!n->is_parameter() && !n->is_constant())
{ {
bool is_head = true; bool is_head = true;
...@@ -871,29 +880,19 @@ using namespace ngraph::runtime; ...@@ -871,29 +880,19 @@ using namespace ngraph::runtime;
} }
} }
if (is_head) if (is_head)
{
dependence_graph_heads.emplace_back(n.get());
}
}
});
writer << "\n";
if (!dependence_graph_heads.empty())
{
for (Node* n : dependence_graph_heads)
{ {
writer << "tbb::flow::make_edge(*flowgraph_node_start" writer << "tbb::flow::make_edge(*flowgraph_node_start"
<< ", *flowgraph_node_" << n->get_name() << ");\n"; << ", *flowgraph_node_" << n->get_name() << ");\n";
} }
} }
});
writer.indent--; writer.indent--;
writer << "}\n"; writer << "}\n";
// Execute the flow graph // Execute the flow graph
writer << "auto start = &(*(ctx->G->begin()));\n"; writer << "(static_cast<tbb::flow::continue_node<tbb::flow::continue_msg, "
writer << "((tbb::flow::continue_node<tbb::flow::continue_msg, " "tbb::flow::lightweight>*>(&(*(ctx->G->begin()))))"
"tbb::flow::lightweight>*)(&(*(ctx->G->begin()))))"
<< "->try_put(tbb::flow::continue_msg());\n"; << "->try_put(tbb::flow::continue_msg());\n";
writer << "try { ctx->G->wait_for_all(); } catch(...) { throw; }\n"; writer << "try { ctx->G->wait_for_all(); } catch(...) { throw; }\n";
} }
...@@ -1189,6 +1188,7 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1189,6 +1188,7 @@ void runtime::cpu::CPU_ExternalFunction::build()
}; };
enables.emplace_back(make_pair(enable, functors.size() - functor_count)); enables.emplace_back(make_pair(enable, functors.size() - functor_count));
enable_nodename_list.emplace_back(make_pair(enable, node->get_name()));
} }
executor = [&](CPURuntimeContext* ctx, vector<void*>& inputs, vector<void*>& outputs) { executor = [&](CPURuntimeContext* ctx, vector<void*>& inputs, vector<void*>& outputs) {
...@@ -1214,6 +1214,105 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1214,6 +1214,105 @@ void runtime::cpu::CPU_ExternalFunction::build()
} }
auto functor = functors.begin(); auto functor = functors.begin();
if (m_use_tbb)
{
// Build the flow graph
if (first_iteration)
{
std::unordered_map<
std::string,
tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight>*>
nodename_tbbnode_map;
tbb::flow::continue_node<tbb::flow::continue_msg,
tbb::flow::lightweight>* flowgraph_node_start =
new tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight>(
*(ctx->G), [&](const tbb::flow::continue_msg& msg) {});
auto it = enable_nodename_list.begin();
for (const auto& p : enables)
{
tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight>*
flowgraph_node = new tbb::flow::continue_node<tbb::flow::continue_msg,
tbb::flow::lightweight>(
*(ctx->G), [&](const tbb::flow::continue_msg& msg) {
if (p.first(ctx) || first_iteration)
{
for (size_t j = 0; j < p.second; j++)
{
if (runtime::cpu::IsTracingEnabled())
{
start_ts = cpu::Clock::now();
}
(*functor)(ctx);
if (runtime::cpu::IsTracingEnabled())
{
ctx->op_durations[profiler_count++] =
(std::chrono::duration_cast<cpu::Timescale>(
cpu::Clock::now() - start_ts))
.count();
}
std::advance(functor, 1);
}
}
else
{
if (runtime::cpu::IsTracingEnabled())
{
for (size_t j = 0; j < p.second; j++)
{
ctx->op_durations[profiler_count++] = 0;
}
}
std::advance(functor, p.second);
}
});
nodename_tbbnode_map.insert({it->second, flowgraph_node});
it++;
}
traverse_nodes(
m_function, [&flowgraph_node_start, &nodename_tbbnode_map](shared_ptr<Node> n) {
if (!n->is_parameter() && !n->is_constant())
{
bool is_head = true;
for (auto arg : n->get_arguments())
{
if (!arg->is_parameter() && !arg->is_constant())
{
is_head = false;
tbb::flow::make_edge(*(nodename_tbbnode_map[arg->get_name()]),
*(nodename_tbbnode_map[n->get_name()]));
}
}
if (is_head)
{
tbb::flow::make_edge(*flowgraph_node_start,
*(nodename_tbbnode_map[n->get_name()]));
}
}
});
if (m_release_function)
{
release_function();
}
}
// Execute the flow graph
(static_cast<
tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight>*>(
&(*(ctx->G->begin()))))
->try_put(tbb::flow::continue_msg());
try
{
ctx->G->wait_for_all();
}
catch (...)
{
throw;
}
}
else
{
for (const auto& p : enables) for (const auto& p : enables)
{ {
if (p.first(ctx) || first_iteration) if (p.first(ctx) || first_iteration)
...@@ -1248,6 +1347,7 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1248,6 +1347,7 @@ void runtime::cpu::CPU_ExternalFunction::build()
std::advance(functor, p.second); std::advance(functor, p.second);
} }
} }
}
first_iteration = false; first_iteration = false;
if (runtime::cpu::IsTracingEnabled()) if (runtime::cpu::IsTracingEnabled())
...@@ -1258,7 +1358,7 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1258,7 +1358,7 @@ void runtime::cpu::CPU_ExternalFunction::build()
m_is_built = true; m_is_built = true;
if (m_release_function) if (m_release_function && !m_use_tbb)
{ {
release_function(); release_function();
} }
......
...@@ -166,6 +166,8 @@ namespace ngraph ...@@ -166,6 +166,8 @@ namespace ngraph
std::list<std::function<void(CPURuntimeContext*)>> functors; std::list<std::function<void(CPURuntimeContext*)>> functors;
std::list<std::pair<std::function<bool(CPURuntimeContext*)>, size_t>> enables; std::list<std::pair<std::function<bool(CPURuntimeContext*)>, size_t>> enables;
std::list<std::pair<std::function<bool(CPURuntimeContext*)>, std::string>>
enable_nodename_list;
std::function<void(CPURuntimeContext*, std::vector<void*>&, std::vector<void*>&)> std::function<void(CPURuntimeContext*, std::vector<void*>&, std::vector<void*>&)>
executor; executor;
std::unordered_map<std::string, void*> tensor_data; std::unordered_map<std::string, void*> tensor_data;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment