Commit b1ce57de authored by Robert Kimball

write overlaps with execute

parent f51253ac
@@ -48,10 +48,16 @@ static size_t current_iteration = 0;
 static size_t s_iterations;
 static size_t s_warmup_iterations;
 
-static void do_iteration(runtime::Executable* exec, const TensorCollection& tensors)
+static void
+    thread_entry(runtime::Executable* exec, const TensorCollection& tensors, size_t pipeline_stage)
 {
+    bool data_written = false;
     const vector<shared_ptr<runtime::Tensor>>& args = tensors.input_tensors;
     const vector<shared_ptr<runtime::Tensor>>& results = tensors.output_tensors;
+    while (current_iteration < s_iterations + s_warmup_iterations)
+    {
+        if (!data_written)
+        {
     for (size_t arg_index = 0; arg_index < args.size(); arg_index++)
     {
         const shared_ptr<runtime::Tensor>& arg = args[arg_index];
@@ -62,22 +68,8 @@ static void do_iteration(runtime::Executable* exec, const TensorCollection& tensors)
                        data->get_element_count() * data->get_element_type().size());
         }
     }
-    exec->call(results, args);
-    for (size_t result_index = 0; result_index < results.size(); result_index++)
-    {
-        const shared_ptr<runtime::HostTensor>& data = tensors.result_data[result_index];
-        const shared_ptr<runtime::Tensor>& result = results[result_index];
-        result->read(data->get_data_ptr(),
-                     data->get_element_count() * data->get_element_type().size());
-    }
-}
-
-static void
-    thread_entry(runtime::Executable* exec, const TensorCollection& tensors, size_t pipeline_stage)
-// static void thread_entry(size_t pipeline_stage)
-{
-    while (current_iteration < s_iterations + s_warmup_iterations)
-    {
+            data_written = true;
+        }
         unique_lock<mutex> lock(s_mutex);
         if ((current_iteration & 1) != pipeline_stage)
         {
@@ -87,8 +79,16 @@ static void
         {
             // our turn to run
             NGRAPH_INFO << "stage " << pipeline_stage << " for iteration " << current_iteration;
-            do_iteration(exec, tensors);
+            exec->call(results, args);
+            for (size_t result_index = 0; result_index < results.size(); result_index++)
+            {
+                const shared_ptr<runtime::HostTensor>& data = tensors.result_data[result_index];
+                const shared_ptr<runtime::Tensor>& result = results[result_index];
+                result->read(data->get_data_ptr(),
+                             data->get_element_count() * data->get_element_type().size());
+            }
             current_iteration++;
+            data_written = false;
             s_condition.notify_all();
         }
     }
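Note on the change: the commit title, "write overlaps with execute", describes the intent. The input-tensor writes move out of the per-iteration critical section, so each pipeline stage copies its next batch of inputs while the other stage is inside exec->call(), and only the call plus the result read happen under the lock. Below is a minimal standalone sketch of that hand-off pattern; write_inputs(), execute(), and read_outputs() are hypothetical placeholders for the nGraph tensor and Executable calls, and warmup iterations are omitted for brevity.

#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>

static std::mutex s_mutex;
static std::condition_variable s_condition;
static size_t current_iteration = 0;
static const size_t s_iterations = 8;

// Hypothetical stand-ins for the nGraph tensor write / call / read steps.
static void write_inputs(size_t stage) { std::printf("stage %zu: write inputs\n", stage); }
static void execute(size_t stage) { std::printf("stage %zu: execute\n", stage); }
static void read_outputs(size_t stage) { std::printf("stage %zu: read outputs\n", stage); }

static void thread_entry(size_t pipeline_stage)
{
    bool data_written = false;
    while (true)
    {
        // Copy input data before taking the lock so the copy overlaps with
        // the other stage's execute(), which runs while the lock is held.
        if (!data_written)
        {
            write_inputs(pipeline_stage);
            data_written = true;
        }
        std::unique_lock<std::mutex> lock(s_mutex);
        if (current_iteration >= s_iterations)
        {
            break;
        }
        if ((current_iteration & 1) != pipeline_stage)
        {
            // Not our turn; sleep until the other stage finishes an iteration.
            s_condition.wait(lock);
        }
        else
        {
            execute(pipeline_stage);
            read_outputs(pipeline_stage);
            current_iteration++;
            data_written = false; // next pass writes the next batch of inputs
            s_condition.notify_all();
        }
    }
    s_condition.notify_all(); // let the other stage observe termination
}

int main()
{
    std::thread t0(thread_entry, 0);
    std::thread t1(thread_entry, 1);
    t0.join();
    t1.join();
    return 0;
}

With two stages, iteration parity ((current_iteration & 1)) is enough to alternate ownership; the gain is that one stage's input copy for its next iteration runs concurrently with the other stage's execute, instead of being serialized inside a single do_iteration() as before.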