Unverified Commit 8e1954d0 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Bob/backend list (#1220)

* open only the unversioned library but check that it is built against the correct version of ngraph

* review comments
parent 97b19515
...@@ -178,10 +178,10 @@ add_subdirectory(codegen) ...@@ -178,10 +178,10 @@ add_subdirectory(codegen)
add_subdirectory(runtime) add_subdirectory(runtime)
add_library(ngraph SHARED ${SRC}) add_library(ngraph SHARED ${SRC})
add_definitions("-DSHARED_LIB_EXT=\"${CMAKE_SHARED_LIBRARY_SUFFIX}\"") target_compile_definitions(ngraph PRIVATE SHARED_LIB_EXT="${CMAKE_SHARED_LIBRARY_SUFFIX}")
add_definitions("-DLIBRARY_VERSION=\"${NGRAPH_VERSION}\"")
set_target_properties(ngraph PROPERTIES VERSION ${NGRAPH_VERSION} SOVERSION ${NGRAPH_API_VERSION}) set_target_properties(ngraph PROPERTIES VERSION ${NGRAPH_VERSION} SOVERSION ${NGRAPH_API_VERSION})
target_link_libraries(ngraph PUBLIC libjson) target_link_libraries(ngraph PUBLIC libjson)
target_compile_definitions(ngraph PUBLIC NGRAPH_VERSION="${NGRAPH_VERSION}")
if (NGRAPH_ONNX_IMPORT_ENABLE) if (NGRAPH_ONNX_IMPORT_ENABLE)
add_dependencies(ngraph onnx_import) add_dependencies(ngraph onnx_import)
......
...@@ -213,7 +213,8 @@ string file_util::read_file_to_string(const string& path) ...@@ -213,7 +213,8 @@ string file_util::read_file_to_string(const string& path)
static void iterate_files_worker(const string& path, static void iterate_files_worker(const string& path,
function<void(const string& file, bool is_dir)> func, function<void(const string& file, bool is_dir)> func,
bool recurse) bool recurse,
bool include_links)
{ {
DIR* dir; DIR* dir;
struct dirent* ent; struct dirent* ent;
...@@ -224,26 +225,26 @@ static void iterate_files_worker(const string& path, ...@@ -224,26 +225,26 @@ static void iterate_files_worker(const string& path,
while ((ent = readdir(dir)) != nullptr) while ((ent = readdir(dir)) != nullptr)
{ {
string name = ent->d_name; string name = ent->d_name;
string path_name = file_util::path_join(path, name);
switch (ent->d_type) switch (ent->d_type)
{ {
case DT_DIR: case DT_DIR:
if (name != "." && name != "..") if (name != "." && name != "..")
{ {
string dir_path = file_util::path_join(path, name);
if (recurse) if (recurse)
{ {
file_util::iterate_files(dir_path, func, recurse); file_util::iterate_files(path_name, func, recurse);
} }
func(dir_path, true); func(path_name, true);
} }
break; break;
case DT_LNK: break; case DT_LNK:
case DT_REG: if (include_links)
{ {
string file_name = file_util::path_join(path, name); func(path_name, false);
func(file_name, false); }
break; break;
} case DT_REG: func(path_name, false); break;
default: break; default: break;
} }
} }
...@@ -264,7 +265,8 @@ static void iterate_files_worker(const string& path, ...@@ -264,7 +265,8 @@ static void iterate_files_worker(const string& path,
void file_util::iterate_files(const string& path, void file_util::iterate_files(const string& path,
function<void(const string& file, bool is_dir)> func, function<void(const string& file, bool is_dir)> func,
bool recurse) bool recurse,
bool include_links)
{ {
vector<string> files; vector<string> files;
vector<string> dirs; vector<string> dirs;
...@@ -279,7 +281,8 @@ void file_util::iterate_files(const string& path, ...@@ -279,7 +281,8 @@ void file_util::iterate_files(const string& path,
files.push_back(file); files.push_back(file);
} }
}, },
recurse); recurse,
include_links);
for (auto f : files) for (auto f : files)
{ {
......
...@@ -79,7 +79,8 @@ namespace ngraph ...@@ -79,7 +79,8 @@ namespace ngraph
// @param recurse Optional parameter to enable recursing through path // @param recurse Optional parameter to enable recursing through path
void iterate_files(const std::string& path, void iterate_files(const std::string& path,
std::function<void(const std::string& file, bool is_dir)> func, std::function<void(const std::string& file, bool is_dir)> func,
bool recurse = false); bool recurse = false,
bool include_links = false);
// @brief Create a temporary file // @brief Create a temporary file
// @param extension Optional extension for the temporary file // @param extension Optional extension for the temporary file
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
*******************************************************************************/ *******************************************************************************/
#include <dlfcn.h> #include <dlfcn.h>
#include <regex>
#include <sstream> #include <sstream>
#include "ngraph/file_util.hpp" #include "ngraph/file_util.hpp"
...@@ -37,22 +38,9 @@ static string find_my_file() ...@@ -37,22 +38,9 @@ static string find_my_file()
return dl_info.dli_fname; return dl_info.dli_fname;
} }
// This will be uncommented when we add support for listing all known backends
// static bool is_backend(const string& path)
// {
// bool rc = false;
// string name = file_util::get_file_name(path);
// if (name.find("_backend.") != string::npos)
// {
// NGRAPH_INFO << name;
// }
// return rc;
// }
void* runtime::Backend::open_shared_library(string type) void* runtime::Backend::open_shared_library(string type)
{ {
string ext = SHARED_LIB_EXT; string ext = SHARED_LIB_EXT;
string ver = LIBRARY_VERSION;
void* handle = nullptr; void* handle = nullptr;
...@@ -62,16 +50,12 @@ void* runtime::Backend::open_shared_library(string type) ...@@ -62,16 +50,12 @@ void* runtime::Backend::open_shared_library(string type)
{ {
type = type.substr(0, colon); type = type.substr(0, colon);
} }
string lib_name = "lib" + to_lower(type) + "_backend" + ext;
string library_name = "lib" + to_lower(type) + "_backend" + string(SHARED_LIB_EXT);
string my_directory = file_util::get_directory(find_my_file()); string my_directory = file_util::get_directory(find_my_file());
string full_path = file_util::path_join(my_directory, lib_name); string library_path = file_util::path_join(my_directory, library_name);
handle = dlopen(full_path.c_str(), RTLD_NOW | RTLD_GLOBAL); handle = dlopen(library_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
if (!handle)
{
string err = dlerror();
throw runtime_error("Library open for Backend '" + lib_name + "' failed with error:\n" +
err);
}
return handle; return handle;
} }
...@@ -85,6 +69,15 @@ shared_ptr<runtime::Backend> runtime::Backend::create(const string& type) ...@@ -85,6 +69,15 @@ shared_ptr<runtime::Backend> runtime::Backend::create(const string& type)
} }
else else
{ {
function<const char*()> get_ngraph_version_string =
reinterpret_cast<const char* (*)()>(dlsym(handle, "get_ngraph_version_string"));
if (!get_ngraph_version_string)
{
dlclose(handle);
throw runtime_error("Backend '" + type +
"' does not implement get_ngraph_version_string");
}
function<runtime::Backend*(const char*)> new_backend = function<runtime::Backend*(const char*)> new_backend =
reinterpret_cast<runtime::Backend* (*)(const char*)>(dlsym(handle, "new_backend")); reinterpret_cast<runtime::Backend* (*)(const char*)>(dlsym(handle, "new_backend"));
if (!new_backend) if (!new_backend)
...@@ -110,9 +103,50 @@ shared_ptr<runtime::Backend> runtime::Backend::create(const string& type) ...@@ -110,9 +103,50 @@ shared_ptr<runtime::Backend> runtime::Backend::create(const string& type)
return rc; return rc;
} }
map<string, string> runtime::Backend::get_registered_device_map()
{
map<string, string> rc;
string my_directory = file_util::get_directory(find_my_file());
vector<string> backend_list;
regex reg("^lib(.+)_backend" + string(SHARED_LIB_EXT));
smatch result;
auto f = [&](const string& file, bool is_dir) {
string name = file_util::get_file_name(file);
if (regex_match(name, result, reg))
{
auto handle = dlopen(file.c_str(), RTLD_LAZY | RTLD_LOCAL);
if (handle)
{
if (dlsym(handle, "new_backend") && dlsym(handle, "delete_backend"))
{
function<const char*()> get_ngraph_version_string =
reinterpret_cast<const char* (*)()>(
dlsym(handle, "get_ngraph_version_string"));
if (get_ngraph_version_string &&
get_ngraph_version_string() == string(NGRAPH_VERSION))
{
rc.insert({to_upper(result[1]), file});
}
}
dlclose(handle);
}
}
};
file_util::iterate_files(my_directory, f, false, true);
return rc;
}
vector<string> runtime::Backend::get_registered_devices() vector<string> runtime::Backend::get_registered_devices()
{ {
return vector<string>(); map<string, string> m = get_registered_device_map();
vector<string> rc;
for (const pair<string, string>& p : m)
{
rc.push_back(p.first);
}
return rc;
} }
void runtime::Backend::remove_compiled_function(shared_ptr<Function> func) void runtime::Backend::remove_compiled_function(shared_ptr<Function> func)
......
...@@ -84,6 +84,7 @@ namespace ngraph ...@@ -84,6 +84,7 @@ namespace ngraph
private: private:
static void* open_shared_library(std::string type); static void* open_shared_library(std::string type);
static std::map<std::string, std::string> get_registered_device_map();
}; };
} }
} }
...@@ -26,6 +26,11 @@ ...@@ -26,6 +26,11 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
extern "C" const char* get_ngraph_version_string()
{
return NGRAPH_VERSION;
}
extern "C" runtime::Backend* new_backend(const char* configuration_string) extern "C" runtime::Backend* new_backend(const char* configuration_string)
{ {
// Force TBB to link to the backend // Force TBB to link to the backend
......
...@@ -23,6 +23,11 @@ ...@@ -23,6 +23,11 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
extern "C" const char* get_ngraph_version_string()
{
return NGRAPH_VERSION;
}
extern "C" runtime::Backend* new_backend(const char* configuration_string) extern "C" runtime::Backend* new_backend(const char* configuration_string)
{ {
return new runtime::gpu::GPU_Backend(); return new runtime::gpu::GPU_Backend();
......
...@@ -20,6 +20,11 @@ ...@@ -20,6 +20,11 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
extern "C" const char* get_ngraph_version_string()
{
return NGRAPH_VERSION;
}
extern "C" runtime::Backend* new_backend(const char* configuration_string) extern "C" runtime::Backend* new_backend(const char* configuration_string)
{ {
return new runtime::intelgpu::IntelGPUBackend(); return new runtime::intelgpu::IntelGPUBackend();
......
...@@ -29,6 +29,11 @@ using namespace ngraph; ...@@ -29,6 +29,11 @@ using namespace ngraph;
using descriptor::layout::DenseTensorViewLayout; using descriptor::layout::DenseTensorViewLayout;
extern "C" const char* get_ngraph_version_string()
{
return NGRAPH_VERSION;
}
extern "C" runtime::Backend* new_backend(const char* configuration_string) extern "C" runtime::Backend* new_backend(const char* configuration_string)
{ {
return new runtime::interpreter::INTBackend(); return new runtime::interpreter::INTBackend();
......
...@@ -93,6 +93,13 @@ std::string ngraph::to_lower(const std::string& s) ...@@ -93,6 +93,13 @@ std::string ngraph::to_lower(const std::string& s)
return rc; return rc;
} }
std::string ngraph::to_upper(const std::string& s)
{
std::string rc = s;
std::transform(rc.begin(), rc.end(), rc.begin(), ::toupper);
return rc;
}
string ngraph::trim(const string& s) string ngraph::trim(const string& s)
{ {
string rc = s; string rc = s;
......
...@@ -110,6 +110,7 @@ namespace ngraph ...@@ -110,6 +110,7 @@ namespace ngraph
void dump(std::ostream& out, const void*, size_t); void dump(std::ostream& out, const void*, size_t);
std::string to_lower(const std::string& s); std::string to_lower(const std::string& s);
std::string to_upper(const std::string& s);
std::string trim(const std::string& s); std::string trim(const std::string& s);
std::vector<std::string> split(const std::string& s, char delimiter, bool trim = false); std::vector<std::string> split(const std::string& s, char delimiter, bool trim = false);
......
...@@ -27,7 +27,7 @@ TEST(backend_api, registered_devices) ...@@ -27,7 +27,7 @@ TEST(backend_api, registered_devices)
vector<string> devices = runtime::Backend::get_registered_devices(); vector<string> devices = runtime::Backend::get_registered_devices();
EXPECT_GE(devices.size(), 0); EXPECT_GE(devices.size(), 0);
// EXPECT_TRUE(contains(devices, "INTERPRETER")); EXPECT_TRUE(contains(devices, "INTERPRETER"));
} }
TEST(backend_api, invalid_name) TEST(backend_api, invalid_name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment