Unverified commit 7b9bf2a8, authored by Robert Kimball, committed by GitHub

Address deferred comments from PR 1676 (#1689)

* Address deferred comments from PR 1676

* use dynamic pointer cast for added error checking
parent 53c4b3ce
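The second bullet in the commit message refers to the cast inside the new initialize_io() helper: static_pointer_cast assumes the handle really is a GPU_TensorView, while dynamic_pointer_cast lets the backend detect a mismatched tensor type and fail loudly. Below is a minimal standalone sketch of that distinction; the classes here are illustrative stand-ins, not the actual nGraph types.

#include <iostream>
#include <memory>
#include <stdexcept>

// Illustrative stand-ins for the real tensor classes.
struct TensorView { virtual ~TensorView() = default; };
struct GPU_TensorView : TensorView { void* m_allocated_buffer_pool = nullptr; };
struct CPU_TensorView : TensorView {};

void* device_buffer(const std::shared_ptr<TensorView>& tv)
{
    // static_pointer_cast would "succeed" even when tv is not a GPU_TensorView,
    // yielding a meaningless pointer. dynamic_pointer_cast returns nullptr instead,
    // which lets the caller report the misuse explicitly.
    auto gpu_tv = std::dynamic_pointer_cast<GPU_TensorView>(tv);
    if (!gpu_tv)
    {
        throw std::invalid_argument("Tensors passed to GPU backend must be GPU Tensors");
    }
    return gpu_tv->m_allocated_buffer_pool;
}

int main()
{
    try
    {
        device_buffer(std::make_shared<CPU_TensorView>()); // wrong tensor type
    }
    catch (const std::invalid_argument& e)
    {
        std::cout << e.what() << "\n"; // misuse is caught instead of reading a bogus buffer
    }
}

With static_pointer_cast the same call would silently treat the CPU tensor as a GPU one and hand back whatever happens to sit at the buffer member's offset.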
@@ -119,17 +119,37 @@ bool runtime::gpu::GPU_Backend::compile(shared_ptr<Function> func)
         instance.m_external_function->m_emit_timing = instance.m_performance_counters_enabled;
         instance.m_external_function->compile();
         instance.m_compiled_function = instance.m_external_function->m_compiled_function;
+        instance.m_inputs.resize(func->get_parameters().size());
+        instance.m_outputs.resize(func->get_output_size());
     }
     return true;
 }
 
+void runtime::gpu::GPU_Backend::initialize_io(void** target,
+                                              const vector<shared_ptr<runtime::TensorView>>& source)
+{
+    for (size_t i = 0; i < source.size(); i++)
+    {
+        shared_ptr<runtime::gpu::GPU_TensorView> tv =
+            dynamic_pointer_cast<runtime::gpu::GPU_TensorView>(source[i]);
+        if (tv)
+        {
+            target[i] = tv->m_allocated_buffer_pool;
+        }
+        else
+        {
+            throw invalid_argument("Tensors passed to GPU backend must be GPU Tensors");
+        }
+    }
+}
+
 bool runtime::gpu::GPU_Backend::call(shared_ptr<Function> func,
-                                     const vector<shared_ptr<runtime::TensorView>>& output_tvs,
-                                     const vector<shared_ptr<runtime::TensorView>>& input_tvs)
+                                     const vector<shared_ptr<runtime::TensorView>>& outputs,
+                                     const vector<shared_ptr<runtime::TensorView>>& inputs)
 {
     bool rc = true;
-    validate_call(func, output_tvs, input_tvs);
+    validate_call(func, outputs, inputs);
     FunctionInstance& instance = m_function_map[func];
     if (instance.m_external_function == nullptr)
@@ -141,24 +161,11 @@ bool runtime::gpu::GPU_Backend::call(shared_ptr<Function> func,
     m_context->prepare_runtime_context();
 
     // Device tensors
-    vector<void*> inputs;
-    vector<void*> outputs;
-
-    for (size_t i = 0; i < input_tvs.size(); i++)
-    {
-        shared_ptr<runtime::gpu::GPU_TensorView> tv =
-            static_pointer_cast<runtime::gpu::GPU_TensorView>(input_tvs[i]);
-        inputs.push_back(tv->m_allocated_buffer_pool);
-    }
-    for (size_t i = 0; i < output_tvs.size(); i++)
-    {
-        shared_ptr<runtime::gpu::GPU_TensorView> tv =
-            static_pointer_cast<runtime::gpu::GPU_TensorView>(output_tvs[i]);
-        outputs.push_back(tv->m_allocated_buffer_pool);
-    }
+    initialize_io(instance.m_inputs.data(), inputs);
+    initialize_io(instance.m_outputs.data(), outputs);
 
     auto ctx = m_context->m_runtime_context.get();
-    instance.m_compiled_function(inputs.data(), outputs.data(), ctx);
+    instance.m_compiled_function(instance.m_inputs.data(), instance.m_outputs.data(), ctx);
 
     return rc;
 }
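Taken together, the two hunks above change the call path as follows: m_inputs and m_outputs are sized once in compile(), and call() fills them in place via initialize_io() instead of rebuilding vector<void*> temporaries with push_back on every invocation. The following is a self-contained sketch of that resize-once, fill-per-call pattern using hypothetical stand-in types rather than the real backend classes.

#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-ins: a tensor handle exposing its device buffer,
// and a per-function instance that caches the raw-pointer arrays.
struct DeviceTensor { void* buffer = nullptr; };

struct FunctionInstance
{
    std::vector<void*> m_inputs;   // sized once, at "compile" time
    std::vector<void*> m_outputs;
};

// Mirrors the shape of initialize_io: fill a pre-allocated void* array
// from a vector of tensor handles. target must have room for source.size() entries.
void initialize_io(void** target, const std::vector<std::shared_ptr<DeviceTensor>>& source)
{
    for (size_t i = 0; i < source.size(); i++)
    {
        target[i] = source[i]->buffer;
    }
}

int main()
{
    FunctionInstance instance;
    instance.m_inputs.resize(2);   // done once when the function is compiled
    instance.m_outputs.resize(1);

    std::vector<std::shared_ptr<DeviceTensor>> inputs = {std::make_shared<DeviceTensor>(),
                                                         std::make_shared<DeviceTensor>()};
    std::vector<std::shared_ptr<DeviceTensor>> outputs = {std::make_shared<DeviceTensor>()};

    // Every call reuses the same storage; no per-call vector allocation.
    initialize_io(instance.m_inputs.data(), inputs);
    initialize_io(instance.m_outputs.data(), outputs);
    assert(instance.m_inputs[0] == inputs[0]->buffer);
}

In the real backend the cached arrays are then handed straight to the compiled function as instance.m_inputs.data() and instance.m_outputs.data().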
@@ -83,8 +83,19 @@ namespace ngraph
                 std::shared_ptr<GPU_ExternalFunction> m_external_function;
                 bool m_performance_counters_enabled = false;
                 EntryPoint m_compiled_function;
+                std::vector<void*> m_inputs;
+                std::vector<void*> m_outputs;
             };
 
+            /// \brief Convert a vector of TensorView into a vector of void* where each void*
+            ///        points to a TensorView's data buffer.
+            /// \param target Pointer to a pre-allocated array of void* with
+            ///        size >= source.size()
+            /// \param source Source vector of TensorViews
+            static void
+                initialize_io(void** target,
+                              const std::vector<std::shared_ptr<runtime::TensorView>>& source);
+
             std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
             std::shared_ptr<BackendContext> m_context;
         };