Unverified Commit 49c2059c authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Force backends to register by calling dlopen (#938)

* if a requested backend is not registered then try dlopen to force it to register

* call 'extern C create_backend()' in the opened shared object to register the backend

* use a single name to test for backend
parent 5648af4e
...@@ -333,6 +333,7 @@ if (NGRAPH_NNP_ENABLE) ...@@ -333,6 +333,7 @@ if (NGRAPH_NNP_ENABLE)
endif() endif()
target_include_directories(ngraph PUBLIC "${NGRAPH_INCLUDE_PATH}") target_include_directories(ngraph PUBLIC "${NGRAPH_INCLUDE_PATH}")
target_link_libraries(ngraph PRIVATE dl)
if((NGRAPH_CPU_ENABLE OR NGRAPH_GPU_ENABLE) AND LLVM_LINK_LIBS) if((NGRAPH_CPU_ENABLE OR NGRAPH_GPU_ENABLE) AND LLVM_LINK_LIBS)
target_link_libraries(ngraph PRIVATE ${LLVM_LINK_LIBS}) target_link_libraries(ngraph PRIVATE ${LLVM_LINK_LIBS})
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*******************************************************************************/ *******************************************************************************/
#include <dlfcn.h>
#include <sstream> #include <sstream>
#include "ngraph/runtime/backend.hpp" #include "ngraph/runtime/backend.hpp"
...@@ -39,12 +40,33 @@ runtime::Backend::~Backend() ...@@ -39,12 +40,33 @@ runtime::Backend::~Backend()
{ {
} }
void* runtime::Backend::open_shared_library(const string& type)
{
void* handle = nullptr;
string name = "lib" + type + "_backend.so";
handle = dlopen(name.c_str(), RTLD_NOW);
if (handle)
{
function<void()> create = reinterpret_cast<void (*)()>(dlsym(handle, "create_backend"));
if (create)
{
create();
}
}
return handle;
}
shared_ptr<runtime::Backend> runtime::Backend::create(const string& type) shared_ptr<runtime::Backend> runtime::Backend::create(const string& type)
{ {
auto it = get_backend_map().find(type); auto it = get_backend_map().find(type);
if (it == get_backend_map().end()) if (it == get_backend_map().end())
{ {
throw runtime_error("Backend '" + type + "' not found in registered backends."); open_shared_library(type);
it = get_backend_map().find(type);
if (it == get_backend_map().end())
{
throw runtime_error("Backend '" + type + "' not found in registered backends.");
}
} }
return it->second; return it->second;
} }
......
...@@ -83,6 +83,7 @@ namespace ngraph ...@@ -83,6 +83,7 @@ namespace ngraph
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs); const std::vector<std::shared_ptr<runtime::TensorView>>& inputs);
private: private:
static void* open_shared_library(const std::string& name);
static std::unordered_map<std::string, std::shared_ptr<Backend>>& get_backend_map(); static std::unordered_map<std::string, std::shared_ptr<Backend>>& get_backend_map();
}; };
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment