Commit b1ce57de authored by Robert Kimball

write overlaps with execute

Hoist the stale input-tensor writes out of do_iteration and the locked region, so each pipeline stage writes its next inputs while the other stage executes. The call itself and the result reads stay inside the stage's turn.

parent f51253ac
@@ -48,36 +48,28 @@ static size_t current_iteration = 0;
 static size_t s_iterations;
 static size_t s_warmup_iterations;
 
-static void do_iteration(runtime::Executable* exec, const TensorCollection& tensors)
+static void
+    thread_entry(runtime::Executable* exec, const TensorCollection& tensors, size_t pipeline_stage)
 {
+    bool data_written = false;
     const vector<shared_ptr<runtime::Tensor>>& args = tensors.input_tensors;
     const vector<shared_ptr<runtime::Tensor>>& results = tensors.output_tensors;
-    for (size_t arg_index = 0; arg_index < args.size(); arg_index++)
+    while (current_iteration < s_iterations + s_warmup_iterations)
     {
-        const shared_ptr<runtime::Tensor>& arg = args[arg_index];
-        if (arg->get_stale())
+        if (!data_written)
         {
-            const shared_ptr<runtime::HostTensor>& data = tensors.parameter_data[arg_index];
-            arg->write(data->get_data_ptr(),
-                       data->get_element_count() * data->get_element_type().size());
+            for (size_t arg_index = 0; arg_index < args.size(); arg_index++)
+            {
+                const shared_ptr<runtime::Tensor>& arg = args[arg_index];
+                if (arg->get_stale())
+                {
+                    const shared_ptr<runtime::HostTensor>& data = tensors.parameter_data[arg_index];
+                    arg->write(data->get_data_ptr(),
+                               data->get_element_count() * data->get_element_type().size());
+                }
+            }
+            data_written = true;
         }
-    }
-    exec->call(results, args);
-    for (size_t result_index = 0; result_index < results.size(); result_index++)
-    {
-        const shared_ptr<runtime::HostTensor>& data = tensors.result_data[result_index];
-        const shared_ptr<runtime::Tensor>& result = results[result_index];
-        result->read(data->get_data_ptr(),
-                     data->get_element_count() * data->get_element_type().size());
-    }
-}
-
-static void
-    thread_entry(runtime::Executable* exec, const TensorCollection& tensors, size_t pipeline_stage)
-// static void thread_entry(size_t pipeline_stage)
-{
-    while (current_iteration < s_iterations + s_warmup_iterations)
-    {
         unique_lock<mutex> lock(s_mutex);
         if ((current_iteration & 1) != pipeline_stage)
         {
@@ -87,8 +79,16 @@ static void
         {
             // our turn to run
             NGRAPH_INFO << "stage " << pipeline_stage << " for iteration " << current_iteration;
-            do_iteration(exec, tensors);
+            exec->call(results, args);
+            for (size_t result_index = 0; result_index < results.size(); result_index++)
+            {
+                const shared_ptr<runtime::HostTensor>& data = tensors.result_data[result_index];
+                const shared_ptr<runtime::Tensor>& result = results[result_index];
+                result->read(data->get_data_ptr(),
+                             data->get_element_count() * data->get_element_type().size());
+            }
             current_iteration++;
+            data_written = false;
             s_condition.notify_all();
         }
     }
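A distilled sketch of the pattern this commit sets up may help: two threads take turns on even/odd iterations, guarded by a mutex and condition variable, and each thread performs its input "write" outside the lock so it overlaps with the other thread's execute. This is illustrative code, not part of the commit: the nGraph tensor writes, exec->call(...), and result reads are replaced by printf stand-ins, and s_total_iterations is a hypothetical name standing in for s_iterations + s_warmup_iterations.

#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <cstdio>
#include <mutex>
#include <thread>

static std::mutex s_mutex;
static std::condition_variable s_condition;
static std::atomic<std::size_t> current_iteration{0};
static const std::size_t s_total_iterations = 8; // stand-in for s_iterations + s_warmup_iterations

static void thread_entry(std::size_t pipeline_stage)
{
    bool data_written = false;
    while (current_iteration < s_total_iterations)
    {
        // Stand-in for the stale input-tensor writes: done outside the lock,
        // so this stage's write overlaps with the other stage's execute.
        if (!data_written)
        {
            std::printf("stage %zu writing inputs\n", pipeline_stage);
            data_written = true;
        }
        std::unique_lock<std::mutex> lock(s_mutex);
        if ((current_iteration & 1) != pipeline_stage)
        {
            // Not our turn; sleep until the other stage advances the counter.
            s_condition.wait(lock);
        }
        else if (current_iteration < s_total_iterations) // re-check under the lock
        {
            // Stand-in for exec->call(results, args) and the result reads.
            std::printf("stage %zu executing iteration %zu\n",
                        pipeline_stage,
                        current_iteration.load());
            ++current_iteration;
            data_written = false; // the next iteration needs fresh inputs
            s_condition.notify_all();
        }
    }
}

int main()
{
    std::thread stage0(thread_entry, 0);
    std::thread stage1(thread_entry, 1);
    stage0.join();
    stage1.join();
    return 0;
}

The bound re-check under the lock is a small divergence from the commit: it closes the window where the counter could pass the limit between the unlocked while test and taking the mutex, so neither thread runs a stray extra iteration.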