Unverified Commit 8e1954d0 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Bob/backend list (#1220)

* open only the unversioned library but check that it is built against the correct version of ngraph

* review comments
parent 97b19515
......@@ -178,10 +178,10 @@ add_subdirectory(codegen)
add_subdirectory(runtime)
add_library(ngraph SHARED ${SRC})
add_definitions("-DSHARED_LIB_EXT=\"${CMAKE_SHARED_LIBRARY_SUFFIX}\"")
add_definitions("-DLIBRARY_VERSION=\"${NGRAPH_VERSION}\"")
target_compile_definitions(ngraph PRIVATE SHARED_LIB_EXT="${CMAKE_SHARED_LIBRARY_SUFFIX}")
set_target_properties(ngraph PROPERTIES VERSION ${NGRAPH_VERSION} SOVERSION ${NGRAPH_API_VERSION})
target_link_libraries(ngraph PUBLIC libjson)
target_compile_definitions(ngraph PUBLIC NGRAPH_VERSION="${NGRAPH_VERSION}")
if (NGRAPH_ONNX_IMPORT_ENABLE)
add_dependencies(ngraph onnx_import)
......
......@@ -213,7 +213,8 @@ string file_util::read_file_to_string(const string& path)
static void iterate_files_worker(const string& path,
function<void(const string& file, bool is_dir)> func,
bool recurse)
bool recurse,
bool include_links)
{
DIR* dir;
struct dirent* ent;
......@@ -224,26 +225,26 @@ static void iterate_files_worker(const string& path,
while ((ent = readdir(dir)) != nullptr)
{
string name = ent->d_name;
string path_name = file_util::path_join(path, name);
switch (ent->d_type)
{
case DT_DIR:
if (name != "." && name != "..")
{
string dir_path = file_util::path_join(path, name);
if (recurse)
{
file_util::iterate_files(dir_path, func, recurse);
file_util::iterate_files(path_name, func, recurse);
}
func(dir_path, true);
func(path_name, true);
}
break;
case DT_LNK: break;
case DT_REG:
{
string file_name = file_util::path_join(path, name);
func(file_name, false);
case DT_LNK:
if (include_links)
{
func(path_name, false);
}
break;
}
case DT_REG: func(path_name, false); break;
default: break;
}
}
......@@ -264,7 +265,8 @@ static void iterate_files_worker(const string& path,
void file_util::iterate_files(const string& path,
function<void(const string& file, bool is_dir)> func,
bool recurse)
bool recurse,
bool include_links)
{
vector<string> files;
vector<string> dirs;
......@@ -279,7 +281,8 @@ void file_util::iterate_files(const string& path,
files.push_back(file);
}
},
recurse);
recurse,
include_links);
for (auto f : files)
{
......
......@@ -79,7 +79,8 @@ namespace ngraph
// @param recurse Optional parameter to enable recursing through path
void iterate_files(const std::string& path,
std::function<void(const std::string& file, bool is_dir)> func,
bool recurse = false);
bool recurse = false,
bool include_links = false);
// @brief Create a temporary file
// @param extension Optional extension for the temporary file
......
......@@ -15,6 +15,7 @@
*******************************************************************************/
#include <dlfcn.h>
#include <regex>
#include <sstream>
#include "ngraph/file_util.hpp"
......@@ -37,22 +38,9 @@ static string find_my_file()
return dl_info.dli_fname;
}
// This will be uncommented when we add support for listing all known backends
// static bool is_backend(const string& path)
// {
// bool rc = false;
// string name = file_util::get_file_name(path);
// if (name.find("_backend.") != string::npos)
// {
// NGRAPH_INFO << name;
// }
// return rc;
// }
void* runtime::Backend::open_shared_library(string type)
{
string ext = SHARED_LIB_EXT;
string ver = LIBRARY_VERSION;
void* handle = nullptr;
......@@ -62,16 +50,12 @@ void* runtime::Backend::open_shared_library(string type)
{
type = type.substr(0, colon);
}
string lib_name = "lib" + to_lower(type) + "_backend" + ext;
string library_name = "lib" + to_lower(type) + "_backend" + string(SHARED_LIB_EXT);
string my_directory = file_util::get_directory(find_my_file());
string full_path = file_util::path_join(my_directory, lib_name);
handle = dlopen(full_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
if (!handle)
{
string err = dlerror();
throw runtime_error("Library open for Backend '" + lib_name + "' failed with error:\n" +
err);
}
string library_path = file_util::path_join(my_directory, library_name);
handle = dlopen(library_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
return handle;
}
......@@ -85,6 +69,15 @@ shared_ptr<runtime::Backend> runtime::Backend::create(const string& type)
}
else
{
function<const char*()> get_ngraph_version_string =
reinterpret_cast<const char* (*)()>(dlsym(handle, "get_ngraph_version_string"));
if (!get_ngraph_version_string)
{
dlclose(handle);
throw runtime_error("Backend '" + type +
"' does not implement get_ngraph_version_string");
}
function<runtime::Backend*(const char*)> new_backend =
reinterpret_cast<runtime::Backend* (*)(const char*)>(dlsym(handle, "new_backend"));
if (!new_backend)
......@@ -110,9 +103,50 @@ shared_ptr<runtime::Backend> runtime::Backend::create(const string& type)
return rc;
}
map<string, string> runtime::Backend::get_registered_device_map()
{
map<string, string> rc;
string my_directory = file_util::get_directory(find_my_file());
vector<string> backend_list;
regex reg("^lib(.+)_backend" + string(SHARED_LIB_EXT));
smatch result;
auto f = [&](const string& file, bool is_dir) {
string name = file_util::get_file_name(file);
if (regex_match(name, result, reg))
{
auto handle = dlopen(file.c_str(), RTLD_LAZY | RTLD_LOCAL);
if (handle)
{
if (dlsym(handle, "new_backend") && dlsym(handle, "delete_backend"))
{
function<const char*()> get_ngraph_version_string =
reinterpret_cast<const char* (*)()>(
dlsym(handle, "get_ngraph_version_string"));
if (get_ngraph_version_string &&
get_ngraph_version_string() == string(NGRAPH_VERSION))
{
rc.insert({to_upper(result[1]), file});
}
}
dlclose(handle);
}
}
};
file_util::iterate_files(my_directory, f, false, true);
return rc;
}
vector<string> runtime::Backend::get_registered_devices()
{
return vector<string>();
map<string, string> m = get_registered_device_map();
vector<string> rc;
for (const pair<string, string>& p : m)
{
rc.push_back(p.first);
}
return rc;
}
void runtime::Backend::remove_compiled_function(shared_ptr<Function> func)
......
......@@ -84,6 +84,7 @@ namespace ngraph
private:
static void* open_shared_library(std::string type);
static std::map<std::string, std::string> get_registered_device_map();
};
}
}
......@@ -26,6 +26,11 @@
using namespace ngraph;
using namespace std;
extern "C" const char* get_ngraph_version_string()
{
return NGRAPH_VERSION;
}
extern "C" runtime::Backend* new_backend(const char* configuration_string)
{
// Force TBB to link to the backend
......
......@@ -23,6 +23,11 @@
using namespace ngraph;
using namespace std;
extern "C" const char* get_ngraph_version_string()
{
return NGRAPH_VERSION;
}
extern "C" runtime::Backend* new_backend(const char* configuration_string)
{
return new runtime::gpu::GPU_Backend();
......
......@@ -20,6 +20,11 @@
using namespace std;
using namespace ngraph;
extern "C" const char* get_ngraph_version_string()
{
return NGRAPH_VERSION;
}
extern "C" runtime::Backend* new_backend(const char* configuration_string)
{
return new runtime::intelgpu::IntelGPUBackend();
......
......@@ -29,6 +29,11 @@ using namespace ngraph;
using descriptor::layout::DenseTensorViewLayout;
extern "C" const char* get_ngraph_version_string()
{
return NGRAPH_VERSION;
}
extern "C" runtime::Backend* new_backend(const char* configuration_string)
{
return new runtime::interpreter::INTBackend();
......
......@@ -93,6 +93,13 @@ std::string ngraph::to_lower(const std::string& s)
return rc;
}
std::string ngraph::to_upper(const std::string& s)
{
std::string rc = s;
std::transform(rc.begin(), rc.end(), rc.begin(), ::toupper);
return rc;
}
string ngraph::trim(const string& s)
{
string rc = s;
......
......@@ -110,6 +110,7 @@ namespace ngraph
void dump(std::ostream& out, const void*, size_t);
std::string to_lower(const std::string& s);
std::string to_upper(const std::string& s);
std::string trim(const std::string& s);
std::vector<std::string> split(const std::string& s, char delimiter, bool trim = false);
......
......@@ -27,7 +27,7 @@ TEST(backend_api, registered_devices)
vector<string> devices = runtime::Backend::get_registered_devices();
EXPECT_GE(devices.size(), 0);
// EXPECT_TRUE(contains(devices, "INTERPRETER"));
EXPECT_TRUE(contains(devices, "INTERPRETER"));
}
TEST(backend_api, invalid_name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment