Unverified Commit ea4a89ec authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Create backend as shared_ptr (#2793)

* Interpreter working

* cleanup

* cleanup

* add CPU backend

* Update the rest of the backends

* update python API

* update test_case to use new shared_ptr
parent 9244e45b
......@@ -32,7 +32,7 @@ static std::shared_ptr<ngraph::runtime::Executable> compile(ngraph::runtime::Bac
void regclass_pyngraph_runtime_Backend(py::module m)
{
py::class_<ngraph::runtime::Backend, std::unique_ptr<ngraph::runtime::Backend>> backend(
py::class_<ngraph::runtime::Backend, std::shared_ptr<ngraph::runtime::Backend>> backend(
m, "Backend");
backend.doc() = "ngraph.impl.runtime.Backend wraps ngraph::runtime::Backend";
backend.def_static("create", &ngraph::runtime::Backend::create);
......
......@@ -35,7 +35,7 @@ std::shared_ptr<ngraph::Node> runtime::Backend::get_backend_op(const std::string
return dummy_node;
}
unique_ptr<runtime::Backend> runtime::Backend::create(const string& type)
shared_ptr<runtime::Backend> runtime::Backend::create(const string& type)
{
return BackendManager::create_backend(type);
}
......
......@@ -44,9 +44,9 @@ public:
/// \brief Create a new Backend object
/// \param type The name of a registered backend, such as "CPU" or "GPU".
/// To select a subdevice use "GPU:N" where s`N` is the subdevice number.
/// \returns unique_ptr to a new Backend or nullptr if the named backend
/// \returns shared_ptr to a new Backend or nullptr if the named backend
/// does not exist.
static std::unique_ptr<Backend> create(const std::string& type);
static std::shared_ptr<Backend> create(const std::string& type);
/// \brief Query the list of registered devices
/// \returns A vector of all registered devices.
......
......@@ -38,13 +38,13 @@ using namespace ngraph;
#define DLSYM(a, b) dlsym(a, b)
#endif
unordered_map<string, runtime::new_backend_t>& runtime::BackendManager::get_registry()
unordered_map<string, runtime::BackendConstructor*>& runtime::BackendManager::get_registry()
{
static unordered_map<string, new_backend_t> s_registered_backend;
static unordered_map<string, BackendConstructor*> s_registered_backend;
return s_registered_backend;
}
void runtime::BackendManager::register_backend(const string& name, new_backend_t new_backend)
void runtime::BackendManager::register_backend(const string& name, BackendConstructor* new_backend)
{
get_registry()[name] = new_backend;
}
......@@ -66,9 +66,9 @@ vector<string> runtime::BackendManager::get_registered_backends()
return rc;
}
unique_ptr<runtime::Backend> runtime::BackendManager::create_backend(const std::string& config)
shared_ptr<runtime::Backend> runtime::BackendManager::create_backend(const std::string& config)
{
runtime::Backend* backend = nullptr;
shared_ptr<runtime::Backend> backend;
string type = config;
// strip off attributes, IE:CPU becomes IE
......@@ -82,8 +82,8 @@ unique_ptr<runtime::Backend> runtime::BackendManager::create_backend(const std::
auto it = registry.find(type);
if (it != registry.end())
{
new_backend_t new_backend = it->second;
backend = new_backend(config.c_str());
BackendConstructor* new_backend = it->second;
backend = new_backend->create(config);
}
else
{
......@@ -97,26 +97,22 @@ unique_ptr<runtime::Backend> runtime::BackendManager::create_backend(const std::
#endif
throw runtime_error(ss.str());
}
function<const char*()> get_ngraph_version_string =
reinterpret_cast<const char* (*)()>(DLSYM(handle, "get_ngraph_version_string"));
if (!get_ngraph_version_string)
function<runtime::BackendConstructor*()> get_backend_constructor_pointer =
reinterpret_cast<runtime::BackendConstructor* (*)()>(
DLSYM(handle, "get_backend_constructor_pointer"));
if (get_backend_constructor_pointer)
{
CLOSE_LIBRARY(handle);
throw runtime_error("Backend '" + type +
"' does not implement get_ngraph_version_string");
backend = get_backend_constructor_pointer()->create(config);
}
function<runtime::Backend*(const char*)> new_backend =
reinterpret_cast<runtime::Backend* (*)(const char*)>(DLSYM(handle, "new_backend"));
if (!new_backend)
else
{
CLOSE_LIBRARY(handle);
throw runtime_error("Backend '" + type + "' does not implement new_backend");
throw runtime_error("Backend '" + type +
"' does not implement get_backend_constructor_pointer");
}
backend = new_backend(config.c_str());
}
return unique_ptr<runtime::Backend>(backend);
return backend;
}
// This doodad finds the full path of the containing shared library
......
......@@ -36,11 +36,17 @@ namespace ngraph
{
class Backend;
class BackendManager;
using new_backend_t = std::function<Backend*(const char* config)>;
class BackendConstructor;
}
}
class ngraph::runtime::BackendConstructor
{
public:
virtual ~BackendConstructor() {}
virtual std::shared_ptr<Backend> create(const std::string& config) = 0;
};
class ngraph::runtime::BackendManager
{
friend class Backend;
......@@ -49,19 +55,19 @@ public:
/// \brief Used by build-in backends to register their name and constructor.
/// This function is not used if the backend is build as a shared library.
/// \param name The name of the registering backend in UPPER CASE.
/// \param backend_constructor A function of type new_backend_t which will be called to
/// \param backend_constructor A BackendConstructor which will be called to
//// construct an instance of the registered backend.
static void register_backend(const std::string& name, new_backend_t backend_constructor);
static void register_backend(const std::string& name, BackendConstructor* backend_constructor);
/// \brief Query the list of registered devices
/// \returns A vector of all registered devices.
static std::vector<std::string> get_registered_backends();
private:
static std::unique_ptr<runtime::Backend> create_backend(const std::string& type);
static std::unordered_map<std::string, new_backend_t>& get_registry();
static std::shared_ptr<runtime::Backend> create_backend(const std::string& type);
static std::unordered_map<std::string, BackendConstructor*>& get_registry();
static std::unordered_map<std::string, new_backend_t> s_registered_backend;
static std::unordered_map<std::string, BackendConstructor*> s_registered_backend;
static DL_HANDLE open_shared_library(std::string type);
static std::map<std::string, std::string> get_registered_device_map();
......
......@@ -28,16 +28,22 @@
using namespace ngraph;
using namespace std;
extern "C" CPU_BACKEND_API runtime::Backend* new_backend(const char* configuration_string)
extern "C" runtime::BackendConstructor* get_backend_constructor_pointer()
{
class CPU_BackendConstructor : public runtime::BackendConstructor
{
public:
std::shared_ptr<runtime::Backend> create(const std::string& config) override
{
// Force TBB to link to the backend
tbb::TBB_runtime_interface_version();
return new runtime::cpu::CPU_Backend();
}
return make_shared<runtime::cpu::CPU_Backend>();
}
};
extern "C" CPU_BACKEND_API void delete_backend(runtime::Backend* backend)
{
delete backend;
static unique_ptr<runtime::BackendConstructor> s_backend_constructor(
new CPU_BackendConstructor());
return s_backend_constructor.get();
}
namespace
......@@ -45,7 +51,10 @@ namespace
static class CPUStaticInit
{
public:
CPUStaticInit() { runtime::BackendManager::register_backend("CPU", new_backend); }
CPUStaticInit()
{
runtime::BackendManager::register_backend("CPU", get_backend_constructor_pointer());
}
~CPUStaticInit() {}
} s_cpu_static_init;
}
......
......@@ -24,14 +24,20 @@
using namespace std;
using namespace ngraph;
extern "C" const char* get_ngraph_version_string()
extern "C" runtime::BackendConstructor* get_backend_constructor_pointer()
{
return NGRAPH_VERSION;
}
class LocalBackendConstructor : public runtime::BackendConstructor
{
public:
std::shared_ptr<runtime::Backend> create(const std::string& config) override
{
return std::make_shared<runtime::gcpu::GCPUBackend>();
}
};
extern "C" runtime::Backend* new_backend(const char* configuration_string)
{
return new runtime::gcpu::GCPUBackend();
static unique_ptr<runtime::BackendConstructor> s_backend_constructor(
new LocalBackendConstructor());
return s_backend_constructor.get();
}
runtime::gcpu::GCPUBackend::GCPUBackend()
......
......@@ -21,6 +21,7 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_internal_function.hpp"
......@@ -33,19 +34,20 @@
using namespace ngraph;
using namespace std;
extern "C" const char* get_ngraph_version_string()
extern "C" runtime::BackendConstructor* get_backend_constructor_pointer()
{
return NGRAPH_VERSION;
}
extern "C" runtime::Backend* new_backend(const char* configuration_string)
{
return new runtime::gpu::GPU_Backend();
}
class LocalBackendConstructor : public runtime::BackendConstructor
{
public:
std::shared_ptr<runtime::Backend> create(const std::string& config) override
{
return std::make_shared<runtime::gpu::GPU_Backend>();
}
};
extern "C" void delete_backend(runtime::Backend* backend)
{
delete backend;
static unique_ptr<runtime::BackendConstructor> s_backend_constructor(
new LocalBackendConstructor());
return s_backend_constructor.get();
}
runtime::gpu::GPU_Backend::GPU_Backend()
......
......@@ -17,6 +17,7 @@
#include "ngraph/runtime/gpuh/gpuh_backend.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/interpreter/int_backend.hpp"
#include "ngraph/runtime/tensor.hpp"
......@@ -24,19 +25,20 @@
using namespace ngraph;
using namespace std;
extern "C" const char* get_ngraph_version_string()
extern "C" runtime::BackendConstructor* get_backend_constructor_pointer()
{
return NGRAPH_VERSION;
}
class LocalBackendConstructor : public runtime::BackendConstructor
{
public:
std::shared_ptr<runtime::Backend> create(const std::string& config) override
{
return std::make_shared<runtime::gpuh::GPUHBackend>();
}
};
extern "C" runtime::Backend* new_backend(const char* configuration_string)
{
return new runtime::gpuh::GPUHBackend();
}
vector<string> get_excludes()
{
return vector<string>{{"Not"}};
static unique_ptr<runtime::BackendConstructor> s_backend_constructor(
new LocalBackendConstructor());
return s_backend_constructor.get();
}
runtime::gpuh::GPUHBackend::GPUHBackend()
......
......@@ -49,6 +49,7 @@
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_backend.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_executable.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_kernels.hpp"
......@@ -274,19 +275,20 @@ static void do_equal_propagation(cldnn::topology& topology,
topology.add(op_concat);
}
extern "C" const char* get_ngraph_version_string()
extern "C" runtime::BackendConstructor* get_backend_constructor_pointer()
{
return NGRAPH_VERSION;
}
extern "C" runtime::Backend* new_backend(const char* configuration_string)
{
return new runtime::intelgpu::IntelGPUBackend();
}
class IntelGPUBackendConstructor : public runtime::BackendConstructor
{
public:
std::shared_ptr<runtime::Backend> create(const std::string& config) override
{
return std::make_shared<runtime::intelgpu::IntelGPUBackend>();
}
};
extern "C" void delete_backend(runtime::Backend* backend)
{
delete backend;
static unique_ptr<runtime::BackendConstructor> s_backend_constructor(
new IntelGPUBackendConstructor());
return s_backend_constructor.get();
}
runtime::intelgpu::IntelGPUBackend::IntelGPUBackend()
......
......@@ -24,14 +24,20 @@
using namespace std;
using namespace ngraph;
extern "C" const char* get_ngraph_version_string()
extern "C" runtime::BackendConstructor* get_backend_constructor_pointer()
{
return NGRAPH_VERSION;
}
class INTBackendConstructor : public runtime::BackendConstructor
{
public:
std::shared_ptr<runtime::Backend> create(const std::string& config) override
{
return std::make_shared<runtime::interpreter::INTBackend>();
}
};
extern "C" runtime::Backend* new_backend(const char* configuration_string)
{
return new runtime::interpreter::INTBackend();
static unique_ptr<runtime::BackendConstructor> s_backend_constructor(
new INTBackendConstructor());
return s_backend_constructor.get();
}
runtime::interpreter::INTBackend::INTBackend()
......
......@@ -22,6 +22,7 @@
#include <string>
#include <vector>
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/tensor.hpp"
namespace ngraph
......@@ -32,6 +33,7 @@ namespace ngraph
{
class INTBackend;
class INTExecutable;
class INTBackendConstructor;
}
}
}
......
......@@ -24,6 +24,7 @@
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/util.hpp"
using namespace std;
......@@ -31,14 +32,20 @@ using namespace ngraph;
using descriptor::layout::DenseTensorLayout;
extern "C" const char* get_ngraph_version_string()
extern "C" runtime::BackendConstructor* get_backend_constructor_pointer()
{
return NGRAPH_VERSION;
}
class LocalBackendConstructor : public runtime::BackendConstructor
{
public:
std::shared_ptr<runtime::Backend> create(const std::string& config) override
{
return std::make_shared<runtime::nop::NOPBackend>();
}
};
extern "C" runtime::Backend* new_backend(const char* configuration_string)
{
return new runtime::nop::NOPBackend();
static unique_ptr<runtime::BackendConstructor> s_backend_constructor(
new LocalBackendConstructor());
return s_backend_constructor.get();
}
shared_ptr<runtime::Tensor> runtime::nop::NOPBackend::create_tensor(const element::Type& type,
......
......@@ -299,7 +299,7 @@ template <typename T>
class BatchNormInferenceTester
{
public:
BatchNormInferenceTester(const std::unique_ptr<ngraph::runtime::Backend>& backend,
BatchNormInferenceTester(const std::shared_ptr<ngraph::runtime::Backend>& backend,
const Shape& input_shape,
element::Type etype,
double epsilon)
......@@ -343,7 +343,7 @@ public:
}
protected:
const std::unique_ptr<ngraph::runtime::Backend>& m_backend;
const std::shared_ptr<ngraph::runtime::Backend>& m_backend;
std::shared_ptr<Function> m_function;
std::shared_ptr<ngraph::runtime::Tensor> m_input;
std::shared_ptr<ngraph::runtime::Tensor> m_gamma;
......@@ -365,7 +365,7 @@ public:
using Variance = test::NDArray<T, 1>;
using NormedInput = test::NDArray<T, 2>;
BatchNormInferenceTesterZeroEpsilon(const std::unique_ptr<ngraph::runtime::Backend>& backend,
BatchNormInferenceTesterZeroEpsilon(const std::shared_ptr<ngraph::runtime::Backend>& backend,
element::Type etype)
: BatchNormInferenceTester<T>(backend, Shape{2, 3}, etype, 0.0)
{
......@@ -459,7 +459,7 @@ public:
using Variance = test::NDArray<T, 1>;
using NormedInput = test::NDArray<T, 2>;
BatchNormInferenceTesterNonZeroEpsilon(const std::unique_ptr<ngraph::runtime::Backend>& backend,
BatchNormInferenceTesterNonZeroEpsilon(const std::shared_ptr<ngraph::runtime::Backend>& backend,
element::Type etype)
: BatchNormInferenceTester<T>(backend, Shape{2, 3}, etype, 0.25)
{
......@@ -545,7 +545,7 @@ template <typename T>
class BatchNormTrainingTester
{
public:
BatchNormTrainingTester(const std::unique_ptr<ngraph::runtime::Backend>& backend,
BatchNormTrainingTester(const std::shared_ptr<ngraph::runtime::Backend>& backend,
const Shape& input_shape,
element::Type etype,
double epsilon)
......@@ -597,7 +597,7 @@ public:
}
protected:
const std::unique_ptr<ngraph::runtime::Backend>& m_backend;
const std::shared_ptr<ngraph::runtime::Backend>& m_backend;
std::shared_ptr<Function> m_function;
std::shared_ptr<ngraph::runtime::Tensor> m_input;
std::shared_ptr<ngraph::runtime::Tensor> m_gamma;
......@@ -619,7 +619,7 @@ public:
using Mean = test::NDArray<T, 1>;
using Variance = test::NDArray<T, 1>;
BatchNormTrainingTesterZeroEpsilon(const std::unique_ptr<ngraph::runtime::Backend>& backend,
BatchNormTrainingTesterZeroEpsilon(const std::shared_ptr<ngraph::runtime::Backend>& backend,
element::Type etype)
: BatchNormTrainingTester<T>(backend, Shape{10, 3}, etype, 0.0)
{
......
......@@ -926,7 +926,7 @@ TEST(builder, scaled_Q_unsigned)
TEST(builder, dynamic_scaled_Q)
{
auto call_SQ = [](unique_ptr<runtime::Backend>& backend,
auto call_SQ = [](shared_ptr<runtime::Backend>& backend,
element::Type type,
op::Quantize::RoundMode mode,
Shape in_shape,
......@@ -1036,7 +1036,7 @@ TEST(builder, scaled_DQ_signed)
}
template <typename T>
shared_ptr<runtime::Tensor> call_SDQ(unique_ptr<runtime::Backend>& backend,
shared_ptr<runtime::Tensor> call_SDQ(shared_ptr<runtime::Backend>& backend,
element::Type type,
Shape in_shape,
vector<T> in,
......
......@@ -39,15 +39,26 @@
using namespace std;
using namespace ngraph;
static runtime::Backend* hybrid_creator(const char* config)
static runtime::BackendConstructor* hybrid_creator()
{
class HybridBackendConstructor : public runtime::BackendConstructor
{
public:
std::shared_ptr<runtime::Backend> create(const std::string& config) override
{
vector<string> unsupported_0 = {"Add", "Max"};
vector<string> unsupported_1 = {"Multiply"};
vector<shared_ptr<runtime::Backend>> backend_list = {
make_shared<runtime::interpreter::INTBackend>(unsupported_0),
make_shared<runtime::cpu::CPU_Backend>()};
return new runtime::hybrid::HybridBackend(backend_list);
return make_shared<runtime::hybrid::HybridBackend>(backend_list);
}
};
static unique_ptr<runtime::BackendConstructor> s_backend_constructor(
new HybridBackendConstructor());
return s_backend_constructor.get();
}
TEST(HYBRID, function_call)
......@@ -100,7 +111,7 @@ TEST(HYBRID, function_call)
TEST(HYBRID, abc)
{
const string backend_name = "H1";
runtime::BackendManager::register_backend(backend_name, hybrid_creator);
runtime::BackendManager::register_backend(backend_name, hybrid_creator());
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
......@@ -139,7 +150,7 @@ TEST(HYBRID, abc)
TEST(HYBRID, simple)
{
const string backend_name = "H1";
runtime::BackendManager::register_backend(backend_name, hybrid_creator);
runtime::BackendManager::register_backend(backend_name, hybrid_creator());
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::i8, shape);
......
......@@ -150,7 +150,7 @@ namespace ngraph
const std::shared_ptr<ngraph::runtime::Tensor>&)>;
std::shared_ptr<Function> m_function;
std::unique_ptr<runtime::Backend> m_backend;
std::shared_ptr<runtime::Backend> m_backend;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_result_tensors;
std::vector<std::shared_ptr<ngraph::op::Constant>> m_expected_outputs;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment