Commit b1ce57de authored by Robert Kimball

write overlaps with execute

parent f51253ac
@@ -48,36 +48,28 @@ static size_t current_iteration = 0;
 static size_t s_iterations;
 static size_t s_warmup_iterations;
-static void do_iteration(runtime::Executable* exec, const TensorCollection& tensors)
+static void
+    thread_entry(runtime::Executable* exec, const TensorCollection& tensors, size_t pipeline_stage)
 {
+    bool data_written = false;
     const vector<shared_ptr<runtime::Tensor>>& args = tensors.input_tensors;
     const vector<shared_ptr<runtime::Tensor>>& results = tensors.output_tensors;
-    for (size_t arg_index = 0; arg_index < args.size(); arg_index++)
+    while (current_iteration < s_iterations + s_warmup_iterations)
     {
-        const shared_ptr<runtime::Tensor>& arg = args[arg_index];
-        if (arg->get_stale())
+        if (!data_written)
         {
-            const shared_ptr<runtime::HostTensor>& data = tensors.parameter_data[arg_index];
-            arg->write(data->get_data_ptr(),
-                       data->get_element_count() * data->get_element_type().size());
+            for (size_t arg_index = 0; arg_index < args.size(); arg_index++)
+            {
+                const shared_ptr<runtime::Tensor>& arg = args[arg_index];
+                if (arg->get_stale())
+                {
+                    const shared_ptr<runtime::HostTensor>& data = tensors.parameter_data[arg_index];
+                    arg->write(data->get_data_ptr(),
+                               data->get_element_count() * data->get_element_type().size());
+                }
+            }
+            data_written = true;
         }
-    }
-    exec->call(results, args);
-    for (size_t result_index = 0; result_index < results.size(); result_index++)
-    {
-        const shared_ptr<runtime::HostTensor>& data = tensors.result_data[result_index];
-        const shared_ptr<runtime::Tensor>& result = results[result_index];
-        result->read(data->get_data_ptr(),
-                     data->get_element_count() * data->get_element_type().size());
-    }
-}
-static void
-    thread_entry(runtime::Executable* exec, const TensorCollection& tensors, size_t pipeline_stage)
-// static void thread_entry(size_t pipeline_stage)
-{
-    while (current_iteration < s_iterations + s_warmup_iterations)
-    {
         unique_lock<mutex> lock(s_mutex);
         if ((current_iteration & 1) != pipeline_stage)
         {
@@ -87,8 +79,16 @@ static void
         {
             // our turn to run
             NGRAPH_INFO << "stage " << pipeline_stage << " for iteration " << current_iteration;
-            do_iteration(exec, tensors);
+            exec->call(results, args);
+            for (size_t result_index = 0; result_index < results.size(); result_index++)
+            {
+                const shared_ptr<runtime::HostTensor>& data = tensors.result_data[result_index];
+                const shared_ptr<runtime::Tensor>& result = results[result_index];
+                result->read(data->get_data_ptr(),
+                             data->get_element_count() * data->get_element_type().size());
+            }
             current_iteration++;
+            data_written = false;
             s_condition.notify_all();
         }
     }
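
The change folds the old do_iteration helper into thread_entry so that each stage performs its input-tensor writes before taking s_mutex: while one stage holds the lock and runs exec->call, the other stage is copying its next batch of input data, which is what the commit title "write overlaps with execute" describes. Below is a minimal standalone sketch of the same turn-taking pattern using only the standard library; write_inputs and execute_and_read are placeholder names standing in for the nGraph tensor write/read and executable call (they are not nGraph APIs), and current_iteration is made atomic here so the unlocked loop-condition reads are well defined.

    // Two pipeline stages alternate iterations by parity. Input "writes" happen
    // outside the lock, so they overlap with the other stage's "execute".
    #include <atomic>
    #include <condition_variable>
    #include <cstdio>
    #include <mutex>
    #include <thread>

    static std::mutex s_mutex;
    static std::condition_variable s_condition;
    static std::atomic<size_t> current_iteration{0};
    static const size_t s_total_iterations = 8;

    // Placeholders for the per-stage tensor write and the execute + result read.
    static void write_inputs(size_t stage) { std::printf("stage %zu: writing inputs\n", stage); }
    static void execute_and_read(size_t stage, size_t iteration)
    {
        std::printf("stage %zu: executing iteration %zu\n", stage, iteration);
    }

    static void thread_entry(size_t pipeline_stage)
    {
        bool data_written = false;
        while (current_iteration < s_total_iterations)
        {
            if (!data_written)
            {
                // Done before acquiring the lock: overlaps with the other stage's execute.
                write_inputs(pipeline_stage);
                data_written = true;
            }
            std::unique_lock<std::mutex> lock(s_mutex);
            if ((current_iteration & 1) != pipeline_stage)
            {
                // Not our turn; wait for the other stage to finish its iteration.
                s_condition.wait(lock);
            }
            else
            {
                // Recheck under the lock: the other stage may have completed the
                // final iteration between our loop-condition check and locking.
                if (current_iteration >= s_total_iterations)
                {
                    break;
                }
                execute_and_read(pipeline_stage, current_iteration);
                current_iteration++;
                data_written = false;
                s_condition.notify_all();
            }
        }
    }

    int main()
    {
        std::thread t0(thread_entry, 0);
        std::thread t1(thread_entry, 1);
        t0.join();
        t1.join();
        return 0;
    }

With two stages, even iterations run on stage 0 and odd iterations on stage 1, so each stage's next input copy is hidden behind the other stage's call, which is the overlap the commit is after.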