Commit 8214cd39 authored by Robert Kimball

update GPU backend

parent 222e0811
...@@ -49,7 +49,6 @@ extern "C" void delete_backend(runtime::Backend* backend) ...@@ -49,7 +49,6 @@ extern "C" void delete_backend(runtime::Backend* backend)
runtime::gpu::GPU_Backend::GPU_Backend() runtime::gpu::GPU_Backend::GPU_Backend()
: runtime::Backend() : runtime::Backend()
, m_context(new BackendContext())
{ {
} }
...@@ -118,24 +117,43 @@ shared_ptr<runtime::Tensor> runtime::gpu::GPU_Backend::create_tensor( ...@@ -118,24 +117,43 @@ shared_ptr<runtime::Tensor> runtime::gpu::GPU_Backend::create_tensor(
return make_shared<runtime::gpu::GPUTensor>(element_type, shape, memory_pointer, this); return make_shared<runtime::gpu::GPUTensor>(element_type, shape, memory_pointer, this);
} }
runtime::Handle runtime::gpu::GPU_Backend::compile(shared_ptr<Function> func) shared_ptr<runtime::Executable> runtime::gpu::GPU_Backend::compile(shared_ptr<Function> func,
bool timing_enable)
{ {
FunctionInstance& instance = m_function_map[func]; shared_ptr<runtime::Executable> rc;
auto it = m_exec_map.find(func);
if (it != m_exec_map.end())
{
rc = it->second;
}
else
{
rc = make_shared<GPU_Executable>(func, timing_enable);
m_exec_map.insert({func, rc});
}
return rc;
}
runtime::gpu::GPU_Executable::GPU_Executable(shared_ptr<Function> func, bool enable_timing)
: m_context(new GPU_Backend::BackendContext())
{
FunctionInstance& instance = m_function_instance;
if (instance.m_compiled_function == nullptr) if (instance.m_compiled_function == nullptr)
{ {
m_context->bind_cuda_context_to_thread(); m_context->bind_cuda_context_to_thread();
instance.m_compiled_function = runtime::gpu::GPUCompiledFunction::make(func, m_context); instance.m_compiled_function = runtime::gpu::GPUCompiledFunction::make(func, m_context);
instance.m_compiled_function->m_emit_timing = instance.m_performance_counters_enabled; instance.m_compiled_function->m_emit_timing = enable_timing;
instance.m_compiled_function->compile(); instance.m_compiled_function->compile();
instance.m_runtime = instance.m_compiled_function->m_runtime; instance.m_runtime = instance.m_compiled_function->m_runtime;
instance.m_inputs.resize(func->get_parameters().size()); instance.m_inputs.resize(func->get_parameters().size());
instance.m_outputs.resize(func->get_output_size()); instance.m_outputs.resize(func->get_output_size());
} }
return func; set_parameters_and_results(*func);
} }
void runtime::gpu::GPU_Backend::initialize_io(void** target, void runtime::gpu::GPU_Executable::initialize_io(void** target,
const vector<shared_ptr<runtime::Tensor>>& source) const vector<shared_ptr<runtime::Tensor>>& source)
{ {
for (size_t i = 0; i < source.size(); i++) for (size_t i = 0; i < source.size(); i++)
{ {
...@@ -152,11 +170,10 @@ void runtime::gpu::GPU_Backend::initialize_io(void** target, ...@@ -152,11 +170,10 @@ void runtime::gpu::GPU_Backend::initialize_io(void** target,
} }
} }
bool runtime::gpu::GPU_Backend::call(shared_ptr<Function> func, bool runtime::gpu::GPU_Executable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
const vector<shared_ptr<runtime::Tensor>>& outputs,
const vector<shared_ptr<runtime::Tensor>>& inputs) const vector<shared_ptr<runtime::Tensor>>& inputs)
{ {
FunctionInstance& instance = m_function_map[func]; FunctionInstance& instance = m_function_instance;
if (instance.m_compiled_function == nullptr) if (instance.m_compiled_function == nullptr)
{ {
throw runtime_error("compile() must be called before call()."); throw runtime_error("compile() must be called before call().");
...@@ -175,33 +192,18 @@ bool runtime::gpu::GPU_Backend::call(shared_ptr<Function> func, ...@@ -175,33 +192,18 @@ bool runtime::gpu::GPU_Backend::call(shared_ptr<Function> func,
return true; return true;
} }
void runtime::gpu::GPU_Backend::remove_compiled_function(shared_ptr<Function> func) // void runtime::gpu::GPU_Backend::remove_compiled_function(shared_ptr<Function> func)
{ // {
m_function_map.erase(func); // m_function_map.erase(func);
} // }
void runtime::gpu::GPU_Backend::enable_performance_data(shared_ptr<Function> func, bool enable)
{
FunctionInstance& instance = m_function_map[func];
if (instance.m_compiled_function != nullptr)
{
throw runtime_error("Performance data collection must be enabled prior to compiling.");
}
instance.m_performance_counters_enabled = enable;
}
vector<runtime::PerformanceCounter> vector<runtime::PerformanceCounter> runtime::gpu::GPU_Executable::get_performance_data() const
runtime::gpu::GPU_Backend::get_performance_data(shared_ptr<Function> func) const
{ {
std::vector<runtime::PerformanceCounter> rc; std::vector<runtime::PerformanceCounter> rc;
auto it = m_function_map.find(func); const FunctionInstance& instance = m_function_instance;
if (it != m_function_map.end()) if (instance.m_compiled_function != nullptr)
{ {
const FunctionInstance& instance = it->second; instance.m_compiled_function->get_performance_data(rc);
if (instance.m_compiled_function != nullptr)
{
instance.m_compiled_function->get_performance_data(rc);
}
} }
return rc; return rc;
} }
......
...@@ -51,16 +51,8 @@ namespace ngraph ...@@ -51,16 +51,8 @@ namespace ngraph
create_tensor(const ngraph::element::Type& element_type, create_tensor(const ngraph::element::Type& element_type,
const Shape& shape) override; const Shape& shape) override;
Handle compile(std::shared_ptr<Function> func) override; std::shared_ptr<runtime::Executable> compile(std::shared_ptr<Function> func,
bool timing_enabled = false) override;
bool call(std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override;
void remove_compiled_function(std::shared_ptr<Function> func) override;
void enable_performance_data(std::shared_ptr<Function> func, bool enable) override;
std::vector<PerformanceCounter>
get_performance_data(std::shared_ptr<Function> func) const override;
bool is_supported(const Node& node) const override; bool is_supported(const Node& node) const override;
...@@ -79,6 +71,21 @@ namespace ngraph ...@@ -79,6 +71,21 @@ namespace ngraph
std::unique_ptr<CudaContextManager> m_cuda_manager; std::unique_ptr<CudaContextManager> m_cuda_manager;
}; };
private:
std::map<std::shared_ptr<Function>, std::shared_ptr<Executable>> m_exec_map;
};
class GPU_Executable : public Executable
{
public:
GPU_Executable(std::shared_ptr<Function> func, bool enable_timing);
bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override;
// void remove_compiled_function(std::shared_ptr<Function> func) override;
std::vector<PerformanceCounter> get_performance_data() const override;
private: private:
class FunctionInstance class FunctionInstance
{ {
...@@ -88,7 +95,7 @@ namespace ngraph ...@@ -88,7 +95,7 @@ namespace ngraph
EntryPoint m_runtime; EntryPoint m_runtime;
std::vector<void*> m_inputs; std::vector<void*> m_inputs;
std::vector<void*> m_outputs; std::vector<void*> m_outputs;
}; } m_function_instance;
/// \brief Convert a vector of Tensor into a vector of void* where each void* /// \brief Convert a vector of Tensor into a vector of void* where each void*
/// points to a Tensor's data buffer. /// points to a Tensor's data buffer.
...@@ -99,8 +106,7 @@ namespace ngraph ...@@ -99,8 +106,7 @@ namespace ngraph
initialize_io(void** target, initialize_io(void** target,
const std::vector<std::shared_ptr<runtime::Tensor>>& source); const std::vector<std::shared_ptr<runtime::Tensor>>& source);
std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map; std::shared_ptr<GPU_Backend::BackendContext> m_context;
std::shared_ptr<BackendContext> m_context;
}; };
} }
} }
......
...@@ -49,6 +49,7 @@ namespace ngraph ...@@ -49,6 +49,7 @@ namespace ngraph
class GPUCompiledFunction class GPUCompiledFunction
{ {
friend class GPU_Backend; friend class GPU_Backend;
friend class GPU_Executable;
public: public:
GPUCompiledFunction( GPUCompiledFunction(
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment