Commit 83a9d252 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon Committed by Scott Cyphers

Fix first_iteration (#1294)

* Fix the first_iteration flag so it works when more than one call-frame exists

Static variables defined in lambda expressions are not private to a lambda so
move this to the runtime context

* Shave off a few microseconds by initializing intermediates exactly once

* Make all execution paths use first_iteration in the runtime context
parent 870ab827
......@@ -110,6 +110,9 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context()
ctx->op_durations = new int64_t[m_external_function->get_op_attrs().size()];
}
ctx->p_en = new bool[m_external_function->get_parameter_layout_descriptors().size()];
ctx->first_iteration = true;
// Create temporary buffer pools
size_t alignment = runtime::cpu::CPU_ExternalFunction::s_memory_pool_alignment;
for (auto buffer_size : m_external_function->get_memory_buffer_sizes())
......
......@@ -597,7 +597,6 @@ using namespace ngraph::runtime;
}
writer << "bool " << current_function->get_name() << "_t_en[" << tensor_index << "];\n";
writer << "bool " << current_function->get_name() << "_init = true;\n";
writer << "extern \"C\" void " << current_function->get_name();
writer << "(void** inputs, void** outputs, cpu::CPURuntimeContext* ctx)\n";
......@@ -635,7 +634,7 @@ using namespace ngraph::runtime;
if (m_use_tbb)
{
writer << "\n";
writer << "if (" << current_function->get_name() << "_init) {\n";
writer << "if (ctx->first_iteration) {\n";
writer.indent++;
writer << "tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight>* "
"flowgraph_node_start"
......@@ -759,7 +758,7 @@ using namespace ngraph::runtime;
// Op Control
if (!node->is_parameter() && !node->is_constant())
{
writer << "if (" << current_function->get_name() << "_init ";
writer << "if (ctx->first_iteration ";
for (const descriptor::Input& input : node->get_inputs())
{
const descriptor::Output& output = input.get_output();
......@@ -896,7 +895,7 @@ using namespace ngraph::runtime;
<< "->try_put(tbb::flow::continue_msg());\n";
writer << "try { ctx->G->wait_for_all(); } catch(...) { throw; }\n";
}
writer << current_function->get_name() << "_init = false;\n";
writer << "ctx->first_iteration = false;\n";
writer.indent--;
// End generated function
......@@ -1192,15 +1191,17 @@ void runtime::cpu::CPU_ExternalFunction::build()
}
executor = [&](CPURuntimeContext* ctx, vector<void*>& inputs, vector<void*>& outputs) {
static bool first_iteration = true;
cpu::Timestamp start_ts;
int profiler_count = 0;
if (ctx->first_iteration)
{
for (auto& p : intermediates_offsets)
{
tensor_data[p.first] =
static_cast<uint8_t*>(ctx->memory_buffers[0]->get_ptr()) + p.second;
}
}
for (const auto& p : function_input_index)
{
......@@ -1217,7 +1218,7 @@ void runtime::cpu::CPU_ExternalFunction::build()
if (m_use_tbb)
{
// Build the flow graph
if (first_iteration)
if (ctx->first_iteration)
{
std::unordered_map<
std::string,
......@@ -1234,7 +1235,7 @@ void runtime::cpu::CPU_ExternalFunction::build()
flowgraph_node = new tbb::flow::continue_node<tbb::flow::continue_msg,
tbb::flow::lightweight>(
*(ctx->G), [&](const tbb::flow::continue_msg& msg) {
if (p.first(ctx) || first_iteration)
if (p.first(ctx) || ctx->first_iteration)
{
for (size_t j = 0; j < p.second; j++)
{
......@@ -1315,7 +1316,7 @@ void runtime::cpu::CPU_ExternalFunction::build()
{
for (const auto& p : enables)
{
if (p.first(ctx) || first_iteration)
if (p.first(ctx) || ctx->first_iteration)
{
for (size_t j = 0; j < p.second; j++)
{
......@@ -1348,7 +1349,7 @@ void runtime::cpu::CPU_ExternalFunction::build()
}
}
}
first_iteration = false;
ctx->first_iteration = false;
if (runtime::cpu::IsTracingEnabled())
{
......
......@@ -52,6 +52,7 @@ namespace ngraph
{
int64_t* op_durations;
bool* p_en;
bool first_iteration;
mkldnn::primitive* const* mkldnn_primitives;
std::vector<AlignedBuffer*> memory_buffers;
char* const* mkldnn_workspaces;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment