Unverified Commit 7b9bf2a8 authored by Robert Kimball, committed by GitHub

Address deferred comments from PR 1676 (#1689)

* Address deferred comments from PR 1676

* use dynamic pointer cast for added error checking
parent 53c4b3ce
......@@ -119,17 +119,37 @@ bool runtime::gpu::GPU_Backend::compile(shared_ptr<Function> func)
instance.m_external_function->m_emit_timing = instance.m_performance_counters_enabled;
instance.m_external_function->compile();
instance.m_compiled_function = instance.m_external_function->m_compiled_function;
instance.m_inputs.resize(func->get_parameters().size());
instance.m_outputs.resize(func->get_output_size());
}
return true;
}
void runtime::gpu::GPU_Backend::initialize_io(void** target,
                                              const vector<shared_ptr<runtime::TensorView>>& source)
{
    // Fill the pre-allocated target array with one device buffer pointer per
    // tensor in source. Each tensor must actually be a GPU tensor; the
    // dynamic_pointer_cast yields null for anything else, which we reject.
    size_t index = 0;
    for (const auto& tensor : source)
    {
        auto gpu_tensor = dynamic_pointer_cast<runtime::gpu::GPU_TensorView>(tensor);
        if (!gpu_tensor)
        {
            throw invalid_argument("Tensors passed to GPU backend must be GPU Tensors");
        }
        // Hand the caller the raw device buffer backing this tensor.
        target[index++] = gpu_tensor->m_allocated_buffer_pool;
    }
}
bool runtime::gpu::GPU_Backend::call(shared_ptr<Function> func,
const vector<shared_ptr<runtime::TensorView>>& output_tvs,
const vector<shared_ptr<runtime::TensorView>>& input_tvs)
const vector<shared_ptr<runtime::TensorView>>& outputs,
const vector<shared_ptr<runtime::TensorView>>& inputs)
{
bool rc = true;
validate_call(func, output_tvs, input_tvs);
validate_call(func, outputs, inputs);
FunctionInstance& instance = m_function_map[func];
if (instance.m_external_function == nullptr)
......@@ -141,24 +161,11 @@ bool runtime::gpu::GPU_Backend::call(shared_ptr<Function> func,
m_context->prepare_runtime_context();
// Device tensors
vector<void*> inputs;
vector<void*> outputs;
for (size_t i = 0; i < input_tvs.size(); i++)
{
shared_ptr<runtime::gpu::GPU_TensorView> tv =
static_pointer_cast<runtime::gpu::GPU_TensorView>(input_tvs[i]);
inputs.push_back(tv->m_allocated_buffer_pool);
}
for (size_t i = 0; i < output_tvs.size(); i++)
{
shared_ptr<runtime::gpu::GPU_TensorView> tv =
static_pointer_cast<runtime::gpu::GPU_TensorView>(output_tvs[i]);
outputs.push_back(tv->m_allocated_buffer_pool);
}
initialize_io(instance.m_inputs.data(), inputs);
initialize_io(instance.m_outputs.data(), outputs);
auto ctx = m_context->m_runtime_context.get();
instance.m_compiled_function(inputs.data(), outputs.data(), ctx);
instance.m_compiled_function(instance.m_inputs.data(), instance.m_outputs.data(), ctx);
return rc;
}
......
......@@ -83,8 +83,19 @@ namespace ngraph
std::shared_ptr<GPU_ExternalFunction> m_external_function;
bool m_performance_counters_enabled = false;
EntryPoint m_compiled_function;
std::vector<void*> m_inputs;
std::vector<void*> m_outputs;
};
/// \brief Convert a vector of TensorView into a vector of void* where each void*
/// points to a TensorView's data buffer.
/// \param target Pointer to a pre-allocated array of void* with
/// size >= source.size()
/// \param source Source vector of TensorViews
static void
initialize_io(void** target,
const std::vector<std::shared_ptr<runtime::TensorView>>& source);
std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
std::shared_ptr<BackendContext> m_context;
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment