Commit 83a9d252 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon Committed by Scott Cyphers

Fix first_iteration (#1294)

* Fix the first_iteration flag so it works when more than one call-frame exists

Static variables defined in lambda expressions are not private to a lambda so
move this to the runtime context

* Shave off a few microseconds by initializing intermediates exactly once

* Make all execution paths use first_iteration in the runtime context
parent 870ab827
...@@ -110,6 +110,9 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context() ...@@ -110,6 +110,9 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context()
ctx->op_durations = new int64_t[m_external_function->get_op_attrs().size()]; ctx->op_durations = new int64_t[m_external_function->get_op_attrs().size()];
} }
ctx->p_en = new bool[m_external_function->get_parameter_layout_descriptors().size()]; ctx->p_en = new bool[m_external_function->get_parameter_layout_descriptors().size()];
ctx->first_iteration = true;
// Create temporary buffer pools // Create temporary buffer pools
size_t alignment = runtime::cpu::CPU_ExternalFunction::s_memory_pool_alignment; size_t alignment = runtime::cpu::CPU_ExternalFunction::s_memory_pool_alignment;
for (auto buffer_size : m_external_function->get_memory_buffer_sizes()) for (auto buffer_size : m_external_function->get_memory_buffer_sizes())
......
...@@ -597,7 +597,6 @@ using namespace ngraph::runtime; ...@@ -597,7 +597,6 @@ using namespace ngraph::runtime;
} }
writer << "bool " << current_function->get_name() << "_t_en[" << tensor_index << "];\n"; writer << "bool " << current_function->get_name() << "_t_en[" << tensor_index << "];\n";
writer << "bool " << current_function->get_name() << "_init = true;\n";
writer << "extern \"C\" void " << current_function->get_name(); writer << "extern \"C\" void " << current_function->get_name();
writer << "(void** inputs, void** outputs, cpu::CPURuntimeContext* ctx)\n"; writer << "(void** inputs, void** outputs, cpu::CPURuntimeContext* ctx)\n";
...@@ -635,7 +634,7 @@ using namespace ngraph::runtime; ...@@ -635,7 +634,7 @@ using namespace ngraph::runtime;
if (m_use_tbb) if (m_use_tbb)
{ {
writer << "\n"; writer << "\n";
writer << "if (" << current_function->get_name() << "_init) {\n"; writer << "if (ctx->first_iteration) {\n";
writer.indent++; writer.indent++;
writer << "tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight>* " writer << "tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight>* "
"flowgraph_node_start" "flowgraph_node_start"
...@@ -759,7 +758,7 @@ using namespace ngraph::runtime; ...@@ -759,7 +758,7 @@ using namespace ngraph::runtime;
// Op Control // Op Control
if (!node->is_parameter() && !node->is_constant()) if (!node->is_parameter() && !node->is_constant())
{ {
writer << "if (" << current_function->get_name() << "_init "; writer << "if (ctx->first_iteration ";
for (const descriptor::Input& input : node->get_inputs()) for (const descriptor::Input& input : node->get_inputs())
{ {
const descriptor::Output& output = input.get_output(); const descriptor::Output& output = input.get_output();
...@@ -896,7 +895,7 @@ using namespace ngraph::runtime; ...@@ -896,7 +895,7 @@ using namespace ngraph::runtime;
<< "->try_put(tbb::flow::continue_msg());\n"; << "->try_put(tbb::flow::continue_msg());\n";
writer << "try { ctx->G->wait_for_all(); } catch(...) { throw; }\n"; writer << "try { ctx->G->wait_for_all(); } catch(...) { throw; }\n";
} }
writer << current_function->get_name() << "_init = false;\n"; writer << "ctx->first_iteration = false;\n";
writer.indent--; writer.indent--;
// End generated function // End generated function
...@@ -1192,14 +1191,16 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1192,14 +1191,16 @@ void runtime::cpu::CPU_ExternalFunction::build()
} }
executor = [&](CPURuntimeContext* ctx, vector<void*>& inputs, vector<void*>& outputs) { executor = [&](CPURuntimeContext* ctx, vector<void*>& inputs, vector<void*>& outputs) {
static bool first_iteration = true;
cpu::Timestamp start_ts; cpu::Timestamp start_ts;
int profiler_count = 0; int profiler_count = 0;
for (auto& p : intermediates_offsets) if (ctx->first_iteration)
{ {
tensor_data[p.first] = for (auto& p : intermediates_offsets)
static_cast<uint8_t*>(ctx->memory_buffers[0]->get_ptr()) + p.second; {
tensor_data[p.first] =
static_cast<uint8_t*>(ctx->memory_buffers[0]->get_ptr()) + p.second;
}
} }
for (const auto& p : function_input_index) for (const auto& p : function_input_index)
...@@ -1217,7 +1218,7 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1217,7 +1218,7 @@ void runtime::cpu::CPU_ExternalFunction::build()
if (m_use_tbb) if (m_use_tbb)
{ {
// Build the flow graph // Build the flow graph
if (first_iteration) if (ctx->first_iteration)
{ {
std::unordered_map< std::unordered_map<
std::string, std::string,
...@@ -1234,7 +1235,7 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1234,7 +1235,7 @@ void runtime::cpu::CPU_ExternalFunction::build()
flowgraph_node = new tbb::flow::continue_node<tbb::flow::continue_msg, flowgraph_node = new tbb::flow::continue_node<tbb::flow::continue_msg,
tbb::flow::lightweight>( tbb::flow::lightweight>(
*(ctx->G), [&](const tbb::flow::continue_msg& msg) { *(ctx->G), [&](const tbb::flow::continue_msg& msg) {
if (p.first(ctx) || first_iteration) if (p.first(ctx) || ctx->first_iteration)
{ {
for (size_t j = 0; j < p.second; j++) for (size_t j = 0; j < p.second; j++)
{ {
...@@ -1315,7 +1316,7 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1315,7 +1316,7 @@ void runtime::cpu::CPU_ExternalFunction::build()
{ {
for (const auto& p : enables) for (const auto& p : enables)
{ {
if (p.first(ctx) || first_iteration) if (p.first(ctx) || ctx->first_iteration)
{ {
for (size_t j = 0; j < p.second; j++) for (size_t j = 0; j < p.second; j++)
{ {
...@@ -1348,7 +1349,7 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1348,7 +1349,7 @@ void runtime::cpu::CPU_ExternalFunction::build()
} }
} }
} }
first_iteration = false; ctx->first_iteration = false;
if (runtime::cpu::IsTracingEnabled()) if (runtime::cpu::IsTracingEnabled())
{ {
......
...@@ -52,6 +52,7 @@ namespace ngraph ...@@ -52,6 +52,7 @@ namespace ngraph
{ {
int64_t* op_durations; int64_t* op_durations;
bool* p_en; bool* p_en;
bool first_iteration;
mkldnn::primitive* const* mkldnn_primitives; mkldnn::primitive* const* mkldnn_primitives;
std::vector<AlignedBuffer*> memory_buffers; std::vector<AlignedBuffer*> memory_buffers;
char* const* mkldnn_workspaces; char* const* mkldnn_workspaces;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment