Commit a8fb4fe0 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Backend support for static and dynamically linked backends (#1361)

* separate backend updates for shared and static libs into separate PR

* add search for shared libs to statically linked libs

* update per review comments

* update per review comments
parent 87eee4d3
......@@ -139,6 +139,7 @@ set (SRC
pattern/matcher.cpp
runtime/aligned_buffer.cpp
runtime/backend.cpp
runtime/backend_manager.cpp
runtime/host_tensor_view.cpp
runtime/tensor_view.cpp
serializer.cpp
......
......@@ -14,164 +14,29 @@
* limitations under the License.
*******************************************************************************/
#ifdef WIN32
#include <windows.h>
#else
#include <dlfcn.h>
#endif
#include <sstream>
#include "ngraph/file_util.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
#ifdef WIN32
#define OPEN_LIBRARY(a, b) LoadLibrary(a)
#define CLOSE_LIBRARY(a) FreeLibrary(a)
#define DLSYM(a, b) GetProcAddress(a, b)
#else
// #define OPEN_LIBRARY(a, b) dlopen(a, b)
#define CLOSE_LIBRARY(a) dlclose(a)
#define DLSYM(a, b) dlsym(a, b)
#endif
// Defaulted out-of-line destructor; equivalent to the original empty body.
runtime::Backend::~Backend() = default;
// This doodad finds the full path of the containing shared library
static std::string find_my_file()
{
#ifdef WIN32
    return ".";
#else
    // Ask the dynamic linker which loaded object contains this function's
    // address; dli_fname is then the path of the library/executable we live in.
    Dl_info info;
    dladdr(reinterpret_cast<void*>(find_my_file), &info);
    return std::string(info.dli_fname);
#endif
}
// Load the shared library implementing backend `type`.
// @param type Backend name, optionally carrying attributes ("GPU:0"); the
//     attribute part after ':' is stripped before forming the library name.
// @return Library handle, or null if the library could not be loaded.
DL_HANDLE runtime::Backend::open_shared_library(string type)
{
    // strip off attributes, IE:CPU becomes IE
    auto colon = type.find(":");
    if (colon != type.npos)
    {
        type = type.substr(0, colon);
    }
    // Backend libraries follow the convention lib<type>_backend<ext> and are
    // expected to sit in the same directory as this shared library.
    string library_name = "lib" + to_lower(type) + "_backend" + string(SHARED_LIB_EXT);
    string my_directory = file_util::get_directory(find_my_file());
    string library_path = file_util::path_join(my_directory, library_name);
    // Initialized so a failed load reports null on every platform path.
    DL_HANDLE handle = nullptr;
#ifdef WIN32
    handle = LoadLibrary(library_path.c_str());
#else
    handle = dlopen(library_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
#endif
    return handle;
}
// Create a backend instance by loading its shared library.
// @param type Registered backend name, e.g. "CPU" or "GPU", optionally with
//     attributes ("GPU:0"); open_shared_library strips the attribute part.
// @return shared_ptr owning the backend; its deleter calls the library's own
//     delete_backend so the allocating module frees the object.
// @throws runtime_error if the library is missing or lacks a required symbol.
shared_ptr<runtime::Backend> runtime::Backend::create(const string& type)
{
    shared_ptr<runtime::Backend> rc;
    DL_HANDLE handle = open_shared_library(type);
    if (!handle)
    {
        throw runtime_error("Backend '" + type + "' not found");
    }
    else
    {
        // A backend library must export three C entry points:
        // get_ngraph_version_string, new_backend and delete_backend.
        // Note: the version symbol's presence is required here but its value
        // is not compared against NGRAPH_VERSION in this path.
        function<const char*()> get_ngraph_version_string =
            reinterpret_cast<const char* (*)()>(DLSYM(handle, "get_ngraph_version_string"));
        if (!get_ngraph_version_string)
        {
            CLOSE_LIBRARY(handle);
            throw runtime_error("Backend '" + type +
                                "' does not implement get_ngraph_version_string");
        }
        function<runtime::Backend*(const char*)> new_backend =
            reinterpret_cast<runtime::Backend* (*)(const char*)>(DLSYM(handle, "new_backend"));
        if (!new_backend)
        {
            CLOSE_LIBRARY(handle);
            throw runtime_error("Backend '" + type + "' does not implement new_backend");
        }
        function<void(runtime::Backend*)> delete_backend =
            reinterpret_cast<void (*)(runtime::Backend*)>(DLSYM(handle, "delete_backend"));
        if (!delete_backend)
        {
            CLOSE_LIBRARY(handle);
            throw runtime_error("Backend '" + type + "' does not implement delete_backend");
        }
        runtime::Backend* backend = new_backend(type.c_str());
        // NOTE(review): the library handle is never closed (CLOSE_LIBRARY is
        // commented out below) — presumably to keep the backend's code mapped
        // for the process lifetime; confirm before re-enabling the close.
        rc = shared_ptr<runtime::Backend>(backend, [=](runtime::Backend* b) {
            delete_backend(b);
            // CLOSE_LIBRARY(handle);
        });
    }
    return rc;
}
// Scan the directory containing this library for loadable backend shared
// libraries and return a map of UPPER-CASE backend name -> library path.
// Each candidate is probed: it must export new_backend/delete_backend and
// report the same ngraph version this library was built with.
map<string, string> runtime::Backend::get_registered_device_map()
{
    map<string, string> rc;
    string my_directory = file_util::get_directory(find_my_file());
    auto f = [&](const string& file, bool is_dir) {
        string name = file_util::get_file_name(file);
        string backend_name;
        if (is_backend_name(name, backend_name))
        {
            DL_HANDLE handle;
#ifdef WIN32
            handle = LoadLibrary(file.c_str());
#else
            // RTLD_LAZY | RTLD_LOCAL: cheap probe only; the library is not
            // published globally and is closed again below.
            handle = dlopen(file.c_str(), RTLD_LAZY | RTLD_LOCAL);
#endif
            if (handle)
            {
                if (DLSYM(handle, "new_backend") && DLSYM(handle, "delete_backend"))
                {
                    function<const char*()> get_ngraph_version_string =
                        reinterpret_cast<const char* (*)()>(
                            DLSYM(handle, "get_ngraph_version_string"));
                    if (get_ngraph_version_string &&
                        get_ngraph_version_string() == string(NGRAPH_VERSION))
                    {
                        rc.insert({to_upper(backend_name), file});
                    }
                }
                CLOSE_LIBRARY(handle);
            }
        }
    };
    file_util::iterate_files(my_directory, f, false, true);
    return rc;
}
// Query the list of registered devices.
// @return The names (map keys) from get_registered_device_map, i.e. the
//     UPPER-CASE backend names discovered next to this library.
vector<string> runtime::Backend::get_registered_devices()
{
    map<string, string> m = get_registered_device_map();
    vector<string> rc;
    rc.reserve(m.size());
    // const auto& binds to the map's pair<const string, string> directly;
    // the original pair<string, string>& forced a copy per element.
    for (const auto& p : m)
    {
        rc.push_back(p.first);
    }
    return rc;
}
void runtime::Backend::remove_compiled_function(shared_ptr<Function> func)
......@@ -244,23 +109,3 @@ void runtime::Backend::validate_call(shared_ptr<const Function> function,
}
}
}
// Test whether `file` names a backend shared library (lib<name>_backend<ext>).
// @param file Path or file name of the candidate.
// @param backend_name Out: the <name> portion, set only when true is returned.
// @return true iff the name matches the backend library naming convention.
bool runtime::Backend::is_backend_name(const string& file, string& backend_name)
{
    string name = file_util::get_file_name(file);
    string ext = SHARED_LIB_EXT;
    bool rc = false;
    // Length guard: without it, name.size() - ext.size() underflows (unsigned)
    // for short names and compare() throws std::out_of_range.
    if (name.size() >= ext.size() + 3 && !name.compare(0, 3, "lib"))
    {
        if (!name.compare(name.size() - ext.size(), ext.size(), ext))
        {
            auto pos = name.find("_backend");
            if (pos != name.npos)
            {
                backend_name = name.substr(3, pos - 3);
                rc = true;
            }
        }
    }
    return rc;
}
......@@ -17,83 +17,97 @@
#pragma once
#include <memory>
#include <string>
#include "ngraph/function.hpp"
#include "ngraph/runtime/performance_counter.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#ifdef WIN32
#include <windows.h>
#define DL_HANDLE HMODULE
#else
#define DL_HANDLE void*
#endif
namespace ngraph
{
namespace runtime
{
class ExternalFunction;
class TensorView;
/// @brief Interface to a generic backend.
///
/// Backends are responsible for function execution and value allocation.
/// @brief Interface to a generic backend.
///
/// Backends are responsible for function execution and value allocation.
class Backend
{
public:
    virtual ~Backend();

    /// @brief Create a new Backend object
    /// @param type The name of a registered backend, such as "CPU" or "GPU".
    ///     To select a subdevice use "GPU:N" where `N` is the subdevice number.
    /// @returns shared_ptr to a new Backend or nullptr if the named backend
    ///     does not exist.
    static std::shared_ptr<Backend> create(const std::string& type);

    /// @brief Query the list of registered devices
    /// @returns A vector of all registered devices.
    static std::vector<std::string> get_registered_devices();

    /// @brief Create a tensor of the given element type and shape on this backend.
    virtual std::shared_ptr<ngraph::runtime::TensorView>
        create_tensor(const ngraph::element::Type& element_type, const Shape& shape) = 0;

    /// @brief Return a handle for a tensor for given mem on backend device
    virtual std::shared_ptr<ngraph::runtime::TensorView>
        create_tensor(const ngraph::element::Type& element_type,
                      const Shape& shape,
                      void* memory_pointer) = 0;

    /// @brief Convenience overload: element type is deduced from C type T.
    template <typename T>
    std::shared_ptr<ngraph::runtime::TensorView> create_tensor(const Shape& shape)
    {
        return create_tensor(element::from<T>(), shape);
    }

    /// @brief Compile a Function; returns true on success.
    virtual bool compile(std::shared_ptr<Function> func) = 0;

    /// @brief Execute one iteration of `func` with the given tensors.
    virtual bool call(std::shared_ptr<Function> func,
                      const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
                      const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) = 0;

    /// @brief Remove a compiled function from any internal cache.
    virtual void remove_compiled_function(std::shared_ptr<Function> func);

    /// @brief Enable/disable per-op performance data collection (no-op by default).
    virtual void enable_performance_data(std::shared_ptr<Function> func, bool enable) {}

    /// @brief Return performance data collected for `func`.
    virtual std::vector<PerformanceCounter>
        get_performance_data(std::shared_ptr<Function> func) const;

    // NOTE(review): declared here but no definition visible in this file.
    static bool register_backend(const std::string& name, std::shared_ptr<Backend>);

protected:
    /// @brief Validate that tensor counts/types match the function's signature.
    void validate_call(std::shared_ptr<const Function> func,
                       const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
                       const std::vector<std::shared_ptr<runtime::TensorView>>& inputs);

private:
    // Shared-library discovery helpers (implemented in backend.cpp).
    static DL_HANDLE open_shared_library(std::string type);
    static std::map<std::string, std::string> get_registered_device_map();
    static bool is_backend_name(const std::string& file, std::string& backend_name);
};
class Backend;
}
}
/// @brief Interface to a generic backend.
///
/// Backends are responsible for function execution and value allocation.
/// @brief Interface to a generic backend.
///
/// Backends are responsible for function execution and value allocation.
class ngraph::runtime::Backend
{
public:
    virtual ~Backend();

    /// @brief Create a new Backend object
    /// @param type The name of a registered backend, such as "CPU" or "GPU".
    ///     To select a subdevice use "GPU:N" where `N` is the subdevice number.
    /// @returns shared_ptr to a new Backend or nullptr if the named backend
    ///     does not exist.
    static std::shared_ptr<Backend> create(const std::string& type);

    /// @brief Query the list of registered devices
    /// @returns A vector of all registered devices.
    static std::vector<std::string> get_registered_devices();

    /// @brief Create a tensor specific to this backend
    /// @param element_type The type of the tensor element
    /// @param shape The shape of the tensor
    /// @returns shared_ptr to a new backend specific tensor
    virtual std::shared_ptr<ngraph::runtime::TensorView>
        create_tensor(const ngraph::element::Type& element_type, const Shape& shape) = 0;

    /// @brief Create a tensor specific to this backend
    /// @param element_type The type of the tensor element
    /// @param shape The shape of the tensor
    /// @param memory_pointer A pointer to a buffer used for this tensor. The size of the buffer
    ///     must be sufficient to contain the tensor. The lifetime of the buffer is the
    ///     responsibility of the caller.
    /// @returns shared_ptr to a new backend specific tensor
    virtual std::shared_ptr<ngraph::runtime::TensorView> create_tensor(
        const ngraph::element::Type& element_type, const Shape& shape, void* memory_pointer) = 0;

    /// @brief Create a tensor of C type T specific to this backend
    /// @param shape The shape of the tensor
    /// @returns shared_ptr to a new backend specific tensor
    template <typename T>
    std::shared_ptr<ngraph::runtime::TensorView> create_tensor(const Shape& shape)
    {
        return create_tensor(element::from<T>(), shape);
    }

    /// @brief Compiles a Function.
    /// @param func The function to compile
    /// @returns true if compile is successful, false otherwise
    virtual bool compile(std::shared_ptr<Function> func) = 0;

    /// @brief Executes a single iteration of a Function. If func is not compiled the call will
    ///     compile it.
    /// @param func The function to execute
    /// @returns true if iteration is successful, false otherwise
    virtual bool call(std::shared_ptr<Function> func,
                      const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
                      const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) = 0;

    /// @brief Compiled functions may be cached. This function removes a compiled function
    ///     from the cache.
    /// @param func The function whose cached compilation is removed
    virtual void remove_compiled_function(std::shared_ptr<Function> func);

    /// @brief Enable the collection of per op performance information on a specified Function.
    ///     Data is collected via the `get_performance_data` method.
    /// @param func The function to collect performance data on.
    /// @param enable Set to true to enable or false to disable data collection
    virtual void enable_performance_data(std::shared_ptr<Function> func, bool enable) {}

    /// @brief Collect performance information gathered on a Function.
    /// @param func The function to get collected data for.
    /// @returns Vector of PerformanceCounter information.
    virtual std::vector<PerformanceCounter>
        get_performance_data(std::shared_ptr<Function> func) const;

protected:
    /// @brief Validate that tensor counts/types match the function's signature.
    void validate_call(std::shared_ptr<const Function> func,
                       const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
                       const std::vector<std::shared_ptr<runtime::TensorView>>& inputs);
};
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifdef WIN32
#include <windows.h>
#else
#include <dlfcn.h>
#endif
#include <sstream>
#include "ngraph/file_util.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
#ifdef WIN32
#define CLOSE_LIBRARY(a) FreeLibrary(a)
#define DLSYM(a, b) GetProcAddress(a, b)
#else
#define CLOSE_LIBRARY(a) dlclose(a)
#define DLSYM(a, b) dlsym(a, b)
#endif
// Accessor for the process-wide backend factory registry. The function-local
// static is constructed on first use, so it is safe to call from static
// initializers in other translation units.
unordered_map<string, runtime::new_backend_t>& runtime::BackendManager::get_registry()
{
    static unordered_map<string, new_backend_t> registry;
    return registry;
}
void runtime::BackendManager::register_backend(const string& name, new_backend_t new_backend)
{
get_registry()[name] = new_backend;
}
// List every available backend name: statically registered backends first,
// then any discoverable shared-library backends not already listed.
vector<string> runtime::BackendManager::get_registered_backends()
{
    vector<string> names;
    for (const auto& entry : get_registry())
    {
        names.push_back(entry.first);
    }
    for (const auto& device : get_registered_device_map())
    {
        const string& candidate = device.first;
        bool already_listed = find(names.begin(), names.end(), candidate) != names.end();
        if (!already_listed)
        {
            names.push_back(candidate);
        }
    }
    return names;
}
// Create a backend instance for `config` (e.g. "CPU" or "GPU:0").
// Statically registered backends are preferred; otherwise the matching shared
// library is loaded and its exported C entry points are used.
// @throws runtime_error if no backend matches or a required symbol is missing.
shared_ptr<runtime::Backend> runtime::BackendManager::create_backend(const std::string& config)
{
    shared_ptr<runtime::Backend> rc;
    string type = config;
    // strip off attributes, IE:CPU becomes IE
    auto colon = type.find(":");
    if (colon != type.npos)
    {
        type = type.substr(0, colon);
    }
    // Bind by reference: the original `auto registry = get_registry();`
    // deep-copied the whole registry map on every call.
    auto& registry = get_registry();
    auto it = registry.find(type);
    if (it != registry.end())
    {
        new_backend_t new_backend = it->second;
        // The full config string (including attributes) is forwarded to the factory.
        rc = shared_ptr<runtime::Backend>(new_backend(config.c_str()));
    }
    else
    {
        DL_HANDLE handle = open_shared_library(type);
        if (!handle)
        {
            stringstream ss;
            ss << "Backend '" << type << "' not registered";
            throw runtime_error(ss.str());
        }
        // Dynamically loaded backends must export these three C entry points.
        function<const char*()> get_ngraph_version_string =
            reinterpret_cast<const char* (*)()>(DLSYM(handle, "get_ngraph_version_string"));
        if (!get_ngraph_version_string)
        {
            CLOSE_LIBRARY(handle);
            throw runtime_error("Backend '" + type +
                                "' does not implement get_ngraph_version_string");
        }
        function<runtime::Backend*(const char*)> new_backend =
            reinterpret_cast<runtime::Backend* (*)(const char*)>(DLSYM(handle, "new_backend"));
        if (!new_backend)
        {
            CLOSE_LIBRARY(handle);
            throw runtime_error("Backend '" + type + "' does not implement new_backend");
        }
        function<void(runtime::Backend*)> delete_backend =
            reinterpret_cast<void (*)(runtime::Backend*)>(DLSYM(handle, "delete_backend"));
        if (!delete_backend)
        {
            CLOSE_LIBRARY(handle);
            throw runtime_error("Backend '" + type + "' does not implement delete_backend");
        }
        runtime::Backend* backend = new_backend(config.c_str());
        // NOTE(review): the handle is never closed (CLOSE_LIBRARY commented
        // out) — presumably to keep backend code mapped for process lifetime;
        // confirm before re-enabling the close.
        rc = shared_ptr<runtime::Backend>(backend, [=](runtime::Backend* b) {
            delete_backend(b);
            // CLOSE_LIBRARY(handle);
        });
    }
    return rc;
}
// This doodad finds the full path of the containing shared library
static std::string find_my_file()
{
#ifdef WIN32
    return ".";
#else
    // dladdr() reports which loaded object owns the given address; the
    // address of this function identifies the library this code lives in.
    Dl_info dl_info;
    dladdr(reinterpret_cast<void*>(&find_my_file), &dl_info);
    std::string path = dl_info.dli_fname;
    return path;
#endif
}
// Load the shared library implementing backend `type`.
// @param type Backend name, optionally with attributes ("GPU:0"); the part
//     after ':' is stripped before forming the library name.
// @return Library handle, or null if the library could not be loaded.
DL_HANDLE runtime::BackendManager::open_shared_library(string type)
{
    DL_HANDLE handle = nullptr;
    // strip off attributes, IE:CPU becomes IE
    auto colon = type.find(":");
    if (colon != type.npos)
    {
        type = type.substr(0, colon);
    }
    // Backend libraries follow the convention lib<type>_backend<ext> and are
    // expected to sit in the same directory as this shared library.
    string library_name = "lib" + to_lower(type) + "_backend" + string(SHARED_LIB_EXT);
    string my_directory = file_util::get_directory(find_my_file());
    string library_path = file_util::path_join(my_directory, library_name);
#ifdef WIN32
    handle = LoadLibrary(library_path.c_str());
#else
    handle = dlopen(library_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
#endif
    return handle;
}
// Scan the directory containing this library for loadable backend shared
// libraries and return a map of UPPER-CASE backend name -> library path.
// Each candidate must export new_backend/delete_backend and report the same
// ngraph version this library was built with.
map<string, string> runtime::BackendManager::get_registered_device_map()
{
    map<string, string> rc;
    string my_directory = file_util::get_directory(find_my_file());
    auto f = [&](const string& file, bool is_dir) {
        string name = file_util::get_file_name(file);
        string backend_name;
        if (is_backend_name(name, backend_name))
        {
            DL_HANDLE handle;
#ifdef WIN32
            handle = LoadLibrary(file.c_str());
#else
            // RTLD_LAZY | RTLD_LOCAL: cheap probe only; the library is not
            // published globally and is closed again below.
            handle = dlopen(file.c_str(), RTLD_LAZY | RTLD_LOCAL);
#endif
            if (handle)
            {
                if (DLSYM(handle, "new_backend") && DLSYM(handle, "delete_backend"))
                {
                    function<const char*()> get_ngraph_version_string =
                        reinterpret_cast<const char* (*)()>(
                            DLSYM(handle, "get_ngraph_version_string"));
                    if (get_ngraph_version_string &&
                        get_ngraph_version_string() == string(NGRAPH_VERSION))
                    {
                        rc.insert({to_upper(backend_name), file});
                    }
                }
                CLOSE_LIBRARY(handle);
            }
        }
    };
    file_util::iterate_files(my_directory, f, false, true);
    return rc;
}
// Test whether `file` names a backend shared library (lib<name>_backend<ext>).
// @param file Path or file name of the candidate.
// @param backend_name Out: the <name> portion, set only when true is returned.
// @return true iff the name matches the backend library naming convention.
bool runtime::BackendManager::is_backend_name(const string& file, string& backend_name)
{
    string name = file_util::get_file_name(file);
    string ext = SHARED_LIB_EXT;
    bool rc = false;
    // Length guard: without it, name.size() - ext.size() underflows (unsigned)
    // for short names and compare() throws std::out_of_range.
    if (name.size() >= ext.size() + 3 && !name.compare(0, 3, "lib"))
    {
        if (!name.compare(name.size() - ext.size(), ext.size(), ext))
        {
            auto pos = name.find("_backend");
            if (pos != name.npos)
            {
                backend_name = name.substr(3, pos - 3);
                rc = true;
            }
        }
    }
    return rc;
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#ifdef WIN32
#include <windows.h>
#define DL_HANDLE HMODULE
#else
#define DL_HANDLE void*
#endif
namespace ngraph
{
    namespace runtime
    {
        class Backend;
        class BackendManager;

        // Factory function type: constructs a backend instance from a
        // configuration string (e.g. "CPU" or "GPU:0"); returns a raw pointer
        // whose ownership is taken by the caller.
        using new_backend_t = std::function<Backend*(const char* config)>;
    }
}
/// @brief Registry and factory for Backend implementations, both statically
///     linked (registered via register_backend) and dynamically discovered
///     as shared libraries.
class ngraph::runtime::BackendManager
{
    friend class Backend;

public:
    /// @brief Used by built-in backends to register their name and constructor.
    ///     This function is not used if the backend is built as a shared library.
    /// @param name The name of the registering backend in UPPER CASE.
    /// @param backend_constructor A function of type new_backend_t which will be called to
    ///     construct an instance of the registered backend.
    static void register_backend(const std::string& name, new_backend_t backend_constructor);

    /// @brief Query the list of registered devices
    /// @returns A vector of all registered devices.
    static std::vector<std::string> get_registered_backends();

private:
    /// @brief Create a backend for `type`, preferring statically registered
    ///     backends and falling back to loading lib<type>_backend.
    static std::shared_ptr<runtime::Backend> create_backend(const std::string& type);

    /// @brief Accessor for the registry. The storage is a function-local
    ///     static inside this accessor (see the .cpp); the previously declared
    ///     static data member s_registered_backend was never defined and has
    ///     been removed.
    static std::unordered_map<std::string, new_backend_t>& get_registry();

    /// @brief Load the shared library implementing backend `type`.
    static DL_HANDLE open_shared_library(std::string type);

    /// @brief Discover backend shared libraries next to this library;
    ///     maps UPPER-CASE backend name to library path.
    static std::map<std::string, std::string> get_registered_device_map();

    /// @brief Test whether a file name matches lib<name>_backend<ext>.
    static bool is_backend_name(const std::string& file, std::string& backend_name);
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment