Commit 1c2e5b7a authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

Add TBB flow graphs to DEX. (#1247)

* Add TBB flow graphs to DEX.

* Make edges from dummy start node to head nodes when traversing nodes.

* Use static_cast to cast TBB graph node.
Undefine __TBB_PREVIEW_LIGHTWEIGHT_POLICY.

* Code formatting.

* Remove clang wreserved-id-macro warning.
parent 8ad38f2e
......@@ -23,6 +23,16 @@
#include <typeinfo>
#include <unordered_map>
// Kill clang diagnostics bug
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreserved-id-macro"
#undef __TBB_PREVIEW_LIGHTWEIGHT_POLICY
#define __TBB_PREVIEW_LIGHTWEIGHT_POLICY 1
#pragma clang diagnostic pop
#include <tbb/flow_graph.h>
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/codegen/compiler.hpp"
#include "ngraph/codegen/execution_engine.hpp"
......@@ -388,6 +398,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
writer << "//Generated by the nGraph CPU backend\n";
if (m_use_tbb)
{
writer << "#undef __TBB_PREVIEW_LIGHTWEIGHT_POLICY \n";
writer << "#define __TBB_PREVIEW_LIGHTWEIGHT_POLICY 1\n";
writer << "#include <tbb/flow_graph.h>";
}
......@@ -854,10 +865,8 @@ using namespace ngraph::runtime;
{
writer << "\n";
// Build the flow graph
vector<Node*> dependence_graph_heads;
traverse_nodes(
current_function, [&writer, &dependence_graph_heads](shared_ptr<Node> n) {
traverse_nodes(current_function, [&writer](shared_ptr<Node> n) {
if (!n->is_parameter() && !n->is_constant())
{
bool is_head = true;
......@@ -871,29 +880,19 @@ using namespace ngraph::runtime;
}
}
if (is_head)
{
dependence_graph_heads.emplace_back(n.get());
}
}
});
writer << "\n";
if (!dependence_graph_heads.empty())
{
for (Node* n : dependence_graph_heads)
{
writer << "tbb::flow::make_edge(*flowgraph_node_start"
<< ", *flowgraph_node_" << n->get_name() << ");\n";
}
}
});
writer.indent--;
writer << "}\n";
// Execute the flow graph
writer << "auto start = &(*(ctx->G->begin()));\n";
writer << "((tbb::flow::continue_node<tbb::flow::continue_msg, "
"tbb::flow::lightweight>*)(&(*(ctx->G->begin()))))"
writer << "(static_cast<tbb::flow::continue_node<tbb::flow::continue_msg, "
"tbb::flow::lightweight>*>(&(*(ctx->G->begin()))))"
<< "->try_put(tbb::flow::continue_msg());\n";
writer << "try { ctx->G->wait_for_all(); } catch(...) { throw; }\n";
}
......@@ -1189,6 +1188,7 @@ void runtime::cpu::CPU_ExternalFunction::build()
};
enables.emplace_back(make_pair(enable, functors.size() - functor_count));
enable_nodename_list.emplace_back(make_pair(enable, node->get_name()));
}
executor = [&](CPURuntimeContext* ctx, vector<void*>& inputs, vector<void*>& outputs) {
......@@ -1214,6 +1214,105 @@ void runtime::cpu::CPU_ExternalFunction::build()
}
auto functor = functors.begin();
if (m_use_tbb)
{
// Build the flow graph
if (first_iteration)
{
std::unordered_map<
std::string,
tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight>*>
nodename_tbbnode_map;
tbb::flow::continue_node<tbb::flow::continue_msg,
tbb::flow::lightweight>* flowgraph_node_start =
new tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight>(
*(ctx->G), [&](const tbb::flow::continue_msg& msg) {});
auto it = enable_nodename_list.begin();
for (const auto& p : enables)
{
tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight>*
flowgraph_node = new tbb::flow::continue_node<tbb::flow::continue_msg,
tbb::flow::lightweight>(
*(ctx->G), [&](const tbb::flow::continue_msg& msg) {
if (p.first(ctx) || first_iteration)
{
for (size_t j = 0; j < p.second; j++)
{
if (runtime::cpu::IsTracingEnabled())
{
start_ts = cpu::Clock::now();
}
(*functor)(ctx);
if (runtime::cpu::IsTracingEnabled())
{
ctx->op_durations[profiler_count++] =
(std::chrono::duration_cast<cpu::Timescale>(
cpu::Clock::now() - start_ts))
.count();
}
std::advance(functor, 1);
}
}
else
{
if (runtime::cpu::IsTracingEnabled())
{
for (size_t j = 0; j < p.second; j++)
{
ctx->op_durations[profiler_count++] = 0;
}
}
std::advance(functor, p.second);
}
});
nodename_tbbnode_map.insert({it->second, flowgraph_node});
it++;
}
traverse_nodes(
m_function, [&flowgraph_node_start, &nodename_tbbnode_map](shared_ptr<Node> n) {
if (!n->is_parameter() && !n->is_constant())
{
bool is_head = true;
for (auto arg : n->get_arguments())
{
if (!arg->is_parameter() && !arg->is_constant())
{
is_head = false;
tbb::flow::make_edge(*(nodename_tbbnode_map[arg->get_name()]),
*(nodename_tbbnode_map[n->get_name()]));
}
}
if (is_head)
{
tbb::flow::make_edge(*flowgraph_node_start,
*(nodename_tbbnode_map[n->get_name()]));
}
}
});
if (m_release_function)
{
release_function();
}
}
// Execute the flow graph
(static_cast<
tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight>*>(
&(*(ctx->G->begin()))))
->try_put(tbb::flow::continue_msg());
try
{
ctx->G->wait_for_all();
}
catch (...)
{
throw;
}
}
else
{
for (const auto& p : enables)
{
if (p.first(ctx) || first_iteration)
......@@ -1248,6 +1347,7 @@ void runtime::cpu::CPU_ExternalFunction::build()
std::advance(functor, p.second);
}
}
}
first_iteration = false;
if (runtime::cpu::IsTracingEnabled())
......@@ -1258,7 +1358,7 @@ void runtime::cpu::CPU_ExternalFunction::build()
m_is_built = true;
if (m_release_function)
if (m_release_function && !m_use_tbb)
{
release_function();
}
......
......@@ -166,6 +166,8 @@ namespace ngraph
std::list<std::function<void(CPURuntimeContext*)>> functors;
std::list<std::pair<std::function<bool(CPURuntimeContext*)>, size_t>> enables;
std::list<std::pair<std::function<bool(CPURuntimeContext*)>, std::string>>
enable_nodename_list;
std::function<void(CPURuntimeContext*, std::vector<void*>&, std::vector<void*>&)>
executor;
std::unordered_map<std::string, void*> tensor_data;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment