Unverified Commit ea4a89ec authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Create backend as shared_ptr (#2793)

* Interpreter working

* cleanup

* cleanup

* add CPU backend

* Update the rest of the backends

* update python API

* update test_case to use new shared_ptr
parent 9244e45b
...@@ -32,7 +32,7 @@ static std::shared_ptr<ngraph::runtime::Executable> compile(ngraph::runtime::Bac ...@@ -32,7 +32,7 @@ static std::shared_ptr<ngraph::runtime::Executable> compile(ngraph::runtime::Bac
void regclass_pyngraph_runtime_Backend(py::module m) void regclass_pyngraph_runtime_Backend(py::module m)
{ {
py::class_<ngraph::runtime::Backend, std::unique_ptr<ngraph::runtime::Backend>> backend( py::class_<ngraph::runtime::Backend, std::shared_ptr<ngraph::runtime::Backend>> backend(
m, "Backend"); m, "Backend");
backend.doc() = "ngraph.impl.runtime.Backend wraps ngraph::runtime::Backend"; backend.doc() = "ngraph.impl.runtime.Backend wraps ngraph::runtime::Backend";
backend.def_static("create", &ngraph::runtime::Backend::create); backend.def_static("create", &ngraph::runtime::Backend::create);
......
...@@ -35,7 +35,7 @@ std::shared_ptr<ngraph::Node> runtime::Backend::get_backend_op(const std::string ...@@ -35,7 +35,7 @@ std::shared_ptr<ngraph::Node> runtime::Backend::get_backend_op(const std::string
return dummy_node; return dummy_node;
} }
unique_ptr<runtime::Backend> runtime::Backend::create(const string& type) shared_ptr<runtime::Backend> runtime::Backend::create(const string& type)
{ {
return BackendManager::create_backend(type); return BackendManager::create_backend(type);
} }
......
...@@ -44,9 +44,9 @@ public: ...@@ -44,9 +44,9 @@ public:
/// \brief Create a new Backend object /// \brief Create a new Backend object
/// \param type The name of a registered backend, such as "CPU" or "GPU". /// \param type The name of a registered backend, such as "CPU" or "GPU".
/// To select a subdevice use "GPU:N" where s`N` is the subdevice number. /// To select a subdevice use "GPU:N" where s`N` is the subdevice number.
/// \returns unique_ptr to a new Backend or nullptr if the named backend /// \returns shared_ptr to a new Backend or nullptr if the named backend
/// does not exist. /// does not exist.
static std::unique_ptr<Backend> create(const std::string& type); static std::shared_ptr<Backend> create(const std::string& type);
/// \brief Query the list of registered devices /// \brief Query the list of registered devices
/// \returns A vector of all registered devices. /// \returns A vector of all registered devices.
......
...@@ -38,13 +38,13 @@ using namespace ngraph; ...@@ -38,13 +38,13 @@ using namespace ngraph;
#define DLSYM(a, b) dlsym(a, b) #define DLSYM(a, b) dlsym(a, b)
#endif #endif
unordered_map<string, runtime::new_backend_t>& runtime::BackendManager::get_registry() unordered_map<string, runtime::BackendConstructor*>& runtime::BackendManager::get_registry()
{ {
static unordered_map<string, new_backend_t> s_registered_backend; static unordered_map<string, BackendConstructor*> s_registered_backend;
return s_registered_backend; return s_registered_backend;
} }
void runtime::BackendManager::register_backend(const string& name, new_backend_t new_backend) void runtime::BackendManager::register_backend(const string& name, BackendConstructor* new_backend)
{ {
get_registry()[name] = new_backend; get_registry()[name] = new_backend;
} }
...@@ -66,9 +66,9 @@ vector<string> runtime::BackendManager::get_registered_backends() ...@@ -66,9 +66,9 @@ vector<string> runtime::BackendManager::get_registered_backends()
return rc; return rc;
} }
unique_ptr<runtime::Backend> runtime::BackendManager::create_backend(const std::string& config) shared_ptr<runtime::Backend> runtime::BackendManager::create_backend(const std::string& config)
{ {
runtime::Backend* backend = nullptr; shared_ptr<runtime::Backend> backend;
string type = config; string type = config;
// strip off attributes, IE:CPU becomes IE // strip off attributes, IE:CPU becomes IE
...@@ -82,8 +82,8 @@ unique_ptr<runtime::Backend> runtime::BackendManager::create_backend(const std:: ...@@ -82,8 +82,8 @@ unique_ptr<runtime::Backend> runtime::BackendManager::create_backend(const std::
auto it = registry.find(type); auto it = registry.find(type);
if (it != registry.end()) if (it != registry.end())
{ {
new_backend_t new_backend = it->second; BackendConstructor* new_backend = it->second;
backend = new_backend(config.c_str()); backend = new_backend->create(config);
} }
else else
{ {
...@@ -97,26 +97,22 @@ unique_ptr<runtime::Backend> runtime::BackendManager::create_backend(const std:: ...@@ -97,26 +97,22 @@ unique_ptr<runtime::Backend> runtime::BackendManager::create_backend(const std::
#endif #endif
throw runtime_error(ss.str()); throw runtime_error(ss.str());
} }
function<const char*()> get_ngraph_version_string =
reinterpret_cast<const char* (*)()>(DLSYM(handle, "get_ngraph_version_string")); function<runtime::BackendConstructor*()> get_backend_constructor_pointer =
if (!get_ngraph_version_string) reinterpret_cast<runtime::BackendConstructor* (*)()>(
DLSYM(handle, "get_backend_constructor_pointer"));
if (get_backend_constructor_pointer)
{ {
CLOSE_LIBRARY(handle); backend = get_backend_constructor_pointer()->create(config);
throw runtime_error("Backend '" + type +
"' does not implement get_ngraph_version_string");
} }
else
function<runtime::Backend*(const char*)> new_backend =
reinterpret_cast<runtime::Backend* (*)(const char*)>(DLSYM(handle, "new_backend"));
if (!new_backend)
{ {
CLOSE_LIBRARY(handle); CLOSE_LIBRARY(handle);
throw runtime_error("Backend '" + type + "' does not implement new_backend"); throw runtime_error("Backend '" + type +
"' does not implement get_backend_constructor_pointer");
} }
backend = new_backend(config.c_str());
} }
return unique_ptr<runtime::Backend>(backend); return backend;
} }
// This doodad finds the full path of the containing shared library // This doodad finds the full path of the containing shared library
......
...@@ -36,11 +36,17 @@ namespace ngraph ...@@ -36,11 +36,17 @@ namespace ngraph
{ {
class Backend; class Backend;
class BackendManager; class BackendManager;
class BackendConstructor;
using new_backend_t = std::function<Backend*(const char* config)>;
} }
} }
class ngraph::runtime::BackendConstructor
{
public:
virtual ~BackendConstructor() {}
virtual std::shared_ptr<Backend> create(const std::string& config) = 0;
};
class ngraph::runtime::BackendManager class ngraph::runtime::BackendManager
{ {
friend class Backend; friend class Backend;
...@@ -49,19 +55,19 @@ public: ...@@ -49,19 +55,19 @@ public:
/// \brief Used by build-in backends to register their name and constructor. /// \brief Used by build-in backends to register their name and constructor.
/// This function is not used if the backend is build as a shared library. /// This function is not used if the backend is build as a shared library.
/// \param name The name of the registering backend in UPPER CASE. /// \param name The name of the registering backend in UPPER CASE.
/// \param backend_constructor A function of type new_backend_t which will be called to /// \param backend_constructor A BackendConstructor which will be called to
//// construct an instance of the registered backend. //// construct an instance of the registered backend.
static void register_backend(const std::string& name, new_backend_t backend_constructor); static void register_backend(const std::string& name, BackendConstructor* backend_constructor);
/// \brief Query the list of registered devices /// \brief Query the list of registered devices
/// \returns A vector of all registered devices. /// \returns A vector of all registered devices.
static std::vector<std::string> get_registered_backends(); static std::vector<std::string> get_registered_backends();
private: private:
static std::unique_ptr<runtime::Backend> create_backend(const std::string& type); static std::shared_ptr<runtime::Backend> create_backend(const std::string& type);
static std::unordered_map<std::string, new_backend_t>& get_registry(); static std::unordered_map<std::string, BackendConstructor*>& get_registry();
static std::unordered_map<std::string, new_backend_t> s_registered_backend; static std::unordered_map<std::string, BackendConstructor*> s_registered_backend;
static DL_HANDLE open_shared_library(std::string type); static DL_HANDLE open_shared_library(std::string type);
static std::map<std::string, std::string> get_registered_device_map(); static std::map<std::string, std::string> get_registered_device_map();
......
...@@ -28,16 +28,22 @@ ...@@ -28,16 +28,22 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
extern "C" CPU_BACKEND_API runtime::Backend* new_backend(const char* configuration_string) extern "C" runtime::BackendConstructor* get_backend_constructor_pointer()
{ {
// Force TBB to link to the backend class CPU_BackendConstructor : public runtime::BackendConstructor
tbb::TBB_runtime_interface_version(); {
return new runtime::cpu::CPU_Backend(); public:
} std::shared_ptr<runtime::Backend> create(const std::string& config) override
{
// Force TBB to link to the backend
tbb::TBB_runtime_interface_version();
return make_shared<runtime::cpu::CPU_Backend>();
}
};
extern "C" CPU_BACKEND_API void delete_backend(runtime::Backend* backend) static unique_ptr<runtime::BackendConstructor> s_backend_constructor(
{ new CPU_BackendConstructor());
delete backend; return s_backend_constructor.get();
} }
namespace namespace
...@@ -45,7 +51,10 @@ namespace ...@@ -45,7 +51,10 @@ namespace
static class CPUStaticInit static class CPUStaticInit
{ {
public: public:
CPUStaticInit() { runtime::BackendManager::register_backend("CPU", new_backend); } CPUStaticInit()
{
runtime::BackendManager::register_backend("CPU", get_backend_constructor_pointer());
}
~CPUStaticInit() {} ~CPUStaticInit() {}
} s_cpu_static_init; } s_cpu_static_init;
} }
......
...@@ -24,14 +24,20 @@ ...@@ -24,14 +24,20 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
extern "C" const char* get_ngraph_version_string() extern "C" runtime::BackendConstructor* get_backend_constructor_pointer()
{ {
return NGRAPH_VERSION; class LocalBackendConstructor : public runtime::BackendConstructor
} {
public:
std::shared_ptr<runtime::Backend> create(const std::string& config) override
{
return std::make_shared<runtime::gcpu::GCPUBackend>();
}
};
extern "C" runtime::Backend* new_backend(const char* configuration_string) static unique_ptr<runtime::BackendConstructor> s_backend_constructor(
{ new LocalBackendConstructor());
return new runtime::gcpu::GCPUBackend(); return s_backend_constructor.get();
} }
runtime::gcpu::GCPUBackend::GCPUBackend() runtime::gcpu::GCPUBackend::GCPUBackend()
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/batch_norm.hpp" #include "ngraph/op/batch_norm.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp" #include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp" #include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_internal_function.hpp" #include "ngraph/runtime/gpu/gpu_internal_function.hpp"
...@@ -33,19 +34,20 @@ ...@@ -33,19 +34,20 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
extern "C" const char* get_ngraph_version_string() extern "C" runtime::BackendConstructor* get_backend_constructor_pointer()
{ {
return NGRAPH_VERSION; class LocalBackendConstructor : public runtime::BackendConstructor
} {
public:
extern "C" runtime::Backend* new_backend(const char* configuration_string) std::shared_ptr<runtime::Backend> create(const std::string& config) override
{ {
return new runtime::gpu::GPU_Backend(); return std::make_shared<runtime::gpu::GPU_Backend>();
} }
};
extern "C" void delete_backend(runtime::Backend* backend) static unique_ptr<runtime::BackendConstructor> s_backend_constructor(
{ new LocalBackendConstructor());
delete backend; return s_backend_constructor.get();
} }
runtime::gpu::GPU_Backend::GPU_Backend() runtime::gpu::GPU_Backend::GPU_Backend()
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "ngraph/runtime/gpuh/gpuh_backend.hpp" #include "ngraph/runtime/gpuh/gpuh_backend.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp" #include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/interpreter/int_backend.hpp" #include "ngraph/runtime/interpreter/int_backend.hpp"
#include "ngraph/runtime/tensor.hpp" #include "ngraph/runtime/tensor.hpp"
...@@ -24,19 +25,20 @@ ...@@ -24,19 +25,20 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
extern "C" const char* get_ngraph_version_string() extern "C" runtime::BackendConstructor* get_backend_constructor_pointer()
{ {
return NGRAPH_VERSION; class LocalBackendConstructor : public runtime::BackendConstructor
} {
public:
std::shared_ptr<runtime::Backend> create(const std::string& config) override
{
return std::make_shared<runtime::gpuh::GPUHBackend>();
}
};
extern "C" runtime::Backend* new_backend(const char* configuration_string) static unique_ptr<runtime::BackendConstructor> s_backend_constructor(
{ new LocalBackendConstructor());
return new runtime::gpuh::GPUHBackend(); return s_backend_constructor.get();
}
vector<string> get_excludes()
{
return vector<string>{{"Not"}};
} }
runtime::gpuh::GPUHBackend::GPUHBackend() runtime::gpuh::GPUHBackend::GPUHBackend()
......
...@@ -49,6 +49,7 @@ ...@@ -49,6 +49,7 @@
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/nop_elimination.hpp" #include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/reshape_elimination.hpp" #include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_backend.hpp" #include "ngraph/runtime/intelgpu/intelgpu_backend.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_executable.hpp" #include "ngraph/runtime/intelgpu/intelgpu_executable.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_kernels.hpp" #include "ngraph/runtime/intelgpu/intelgpu_kernels.hpp"
...@@ -274,19 +275,20 @@ static void do_equal_propagation(cldnn::topology& topology, ...@@ -274,19 +275,20 @@ static void do_equal_propagation(cldnn::topology& topology,
topology.add(op_concat); topology.add(op_concat);
} }
extern "C" const char* get_ngraph_version_string() extern "C" runtime::BackendConstructor* get_backend_constructor_pointer()
{ {
return NGRAPH_VERSION; class IntelGPUBackendConstructor : public runtime::BackendConstructor
} {
public:
extern "C" runtime::Backend* new_backend(const char* configuration_string) std::shared_ptr<runtime::Backend> create(const std::string& config) override
{ {
return new runtime::intelgpu::IntelGPUBackend(); return std::make_shared<runtime::intelgpu::IntelGPUBackend>();
} }
};
extern "C" void delete_backend(runtime::Backend* backend) static unique_ptr<runtime::BackendConstructor> s_backend_constructor(
{ new IntelGPUBackendConstructor());
delete backend; return s_backend_constructor.get();
} }
runtime::intelgpu::IntelGPUBackend::IntelGPUBackend() runtime::intelgpu::IntelGPUBackend::IntelGPUBackend()
......
...@@ -24,14 +24,20 @@ ...@@ -24,14 +24,20 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
extern "C" const char* get_ngraph_version_string() extern "C" runtime::BackendConstructor* get_backend_constructor_pointer()
{ {
return NGRAPH_VERSION; class INTBackendConstructor : public runtime::BackendConstructor
} {
public:
std::shared_ptr<runtime::Backend> create(const std::string& config) override
{
return std::make_shared<runtime::interpreter::INTBackend>();
}
};
extern "C" runtime::Backend* new_backend(const char* configuration_string) static unique_ptr<runtime::BackendConstructor> s_backend_constructor(
{ new INTBackendConstructor());
return new runtime::interpreter::INTBackend(); return s_backend_constructor.get();
} }
runtime::interpreter::INTBackend::INTBackend() runtime::interpreter::INTBackend::INTBackend()
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/tensor.hpp" #include "ngraph/runtime/tensor.hpp"
namespace ngraph namespace ngraph
...@@ -32,6 +33,7 @@ namespace ngraph ...@@ -32,6 +33,7 @@ namespace ngraph
{ {
class INTBackend; class INTBackend;
class INTExecutable; class INTExecutable;
class INTBackendConstructor;
} }
} }
} }
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "ngraph/pass/like_replacement.hpp" #include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std; using namespace std;
...@@ -31,14 +32,20 @@ using namespace ngraph; ...@@ -31,14 +32,20 @@ using namespace ngraph;
using descriptor::layout::DenseTensorLayout; using descriptor::layout::DenseTensorLayout;
extern "C" const char* get_ngraph_version_string() extern "C" runtime::BackendConstructor* get_backend_constructor_pointer()
{ {
return NGRAPH_VERSION; class LocalBackendConstructor : public runtime::BackendConstructor
} {
public:
std::shared_ptr<runtime::Backend> create(const std::string& config) override
{
return std::make_shared<runtime::nop::NOPBackend>();
}
};
extern "C" runtime::Backend* new_backend(const char* configuration_string) static unique_ptr<runtime::BackendConstructor> s_backend_constructor(
{ new LocalBackendConstructor());
return new runtime::nop::NOPBackend(); return s_backend_constructor.get();
} }
shared_ptr<runtime::Tensor> runtime::nop::NOPBackend::create_tensor(const element::Type& type, shared_ptr<runtime::Tensor> runtime::nop::NOPBackend::create_tensor(const element::Type& type,
......
...@@ -299,7 +299,7 @@ template <typename T> ...@@ -299,7 +299,7 @@ template <typename T>
class BatchNormInferenceTester class BatchNormInferenceTester
{ {
public: public:
BatchNormInferenceTester(const std::unique_ptr<ngraph::runtime::Backend>& backend, BatchNormInferenceTester(const std::shared_ptr<ngraph::runtime::Backend>& backend,
const Shape& input_shape, const Shape& input_shape,
element::Type etype, element::Type etype,
double epsilon) double epsilon)
...@@ -343,7 +343,7 @@ public: ...@@ -343,7 +343,7 @@ public:
} }
protected: protected:
const std::unique_ptr<ngraph::runtime::Backend>& m_backend; const std::shared_ptr<ngraph::runtime::Backend>& m_backend;
std::shared_ptr<Function> m_function; std::shared_ptr<Function> m_function;
std::shared_ptr<ngraph::runtime::Tensor> m_input; std::shared_ptr<ngraph::runtime::Tensor> m_input;
std::shared_ptr<ngraph::runtime::Tensor> m_gamma; std::shared_ptr<ngraph::runtime::Tensor> m_gamma;
...@@ -365,7 +365,7 @@ public: ...@@ -365,7 +365,7 @@ public:
using Variance = test::NDArray<T, 1>; using Variance = test::NDArray<T, 1>;
using NormedInput = test::NDArray<T, 2>; using NormedInput = test::NDArray<T, 2>;
BatchNormInferenceTesterZeroEpsilon(const std::unique_ptr<ngraph::runtime::Backend>& backend, BatchNormInferenceTesterZeroEpsilon(const std::shared_ptr<ngraph::runtime::Backend>& backend,
element::Type etype) element::Type etype)
: BatchNormInferenceTester<T>(backend, Shape{2, 3}, etype, 0.0) : BatchNormInferenceTester<T>(backend, Shape{2, 3}, etype, 0.0)
{ {
...@@ -459,7 +459,7 @@ public: ...@@ -459,7 +459,7 @@ public:
using Variance = test::NDArray<T, 1>; using Variance = test::NDArray<T, 1>;
using NormedInput = test::NDArray<T, 2>; using NormedInput = test::NDArray<T, 2>;
BatchNormInferenceTesterNonZeroEpsilon(const std::unique_ptr<ngraph::runtime::Backend>& backend, BatchNormInferenceTesterNonZeroEpsilon(const std::shared_ptr<ngraph::runtime::Backend>& backend,
element::Type etype) element::Type etype)
: BatchNormInferenceTester<T>(backend, Shape{2, 3}, etype, 0.25) : BatchNormInferenceTester<T>(backend, Shape{2, 3}, etype, 0.25)
{ {
...@@ -545,7 +545,7 @@ template <typename T> ...@@ -545,7 +545,7 @@ template <typename T>
class BatchNormTrainingTester class BatchNormTrainingTester
{ {
public: public:
BatchNormTrainingTester(const std::unique_ptr<ngraph::runtime::Backend>& backend, BatchNormTrainingTester(const std::shared_ptr<ngraph::runtime::Backend>& backend,
const Shape& input_shape, const Shape& input_shape,
element::Type etype, element::Type etype,
double epsilon) double epsilon)
...@@ -597,7 +597,7 @@ public: ...@@ -597,7 +597,7 @@ public:
} }
protected: protected:
const std::unique_ptr<ngraph::runtime::Backend>& m_backend; const std::shared_ptr<ngraph::runtime::Backend>& m_backend;
std::shared_ptr<Function> m_function; std::shared_ptr<Function> m_function;
std::shared_ptr<ngraph::runtime::Tensor> m_input; std::shared_ptr<ngraph::runtime::Tensor> m_input;
std::shared_ptr<ngraph::runtime::Tensor> m_gamma; std::shared_ptr<ngraph::runtime::Tensor> m_gamma;
...@@ -619,7 +619,7 @@ public: ...@@ -619,7 +619,7 @@ public:
using Mean = test::NDArray<T, 1>; using Mean = test::NDArray<T, 1>;
using Variance = test::NDArray<T, 1>; using Variance = test::NDArray<T, 1>;
BatchNormTrainingTesterZeroEpsilon(const std::unique_ptr<ngraph::runtime::Backend>& backend, BatchNormTrainingTesterZeroEpsilon(const std::shared_ptr<ngraph::runtime::Backend>& backend,
element::Type etype) element::Type etype)
: BatchNormTrainingTester<T>(backend, Shape{10, 3}, etype, 0.0) : BatchNormTrainingTester<T>(backend, Shape{10, 3}, etype, 0.0)
{ {
......
...@@ -926,7 +926,7 @@ TEST(builder, scaled_Q_unsigned) ...@@ -926,7 +926,7 @@ TEST(builder, scaled_Q_unsigned)
TEST(builder, dynamic_scaled_Q) TEST(builder, dynamic_scaled_Q)
{ {
auto call_SQ = [](unique_ptr<runtime::Backend>& backend, auto call_SQ = [](shared_ptr<runtime::Backend>& backend,
element::Type type, element::Type type,
op::Quantize::RoundMode mode, op::Quantize::RoundMode mode,
Shape in_shape, Shape in_shape,
...@@ -1036,7 +1036,7 @@ TEST(builder, scaled_DQ_signed) ...@@ -1036,7 +1036,7 @@ TEST(builder, scaled_DQ_signed)
} }
template <typename T> template <typename T>
shared_ptr<runtime::Tensor> call_SDQ(unique_ptr<runtime::Backend>& backend, shared_ptr<runtime::Tensor> call_SDQ(shared_ptr<runtime::Backend>& backend,
element::Type type, element::Type type,
Shape in_shape, Shape in_shape,
vector<T> in, vector<T> in,
......
...@@ -39,15 +39,26 @@ ...@@ -39,15 +39,26 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
static runtime::Backend* hybrid_creator(const char* config) static runtime::BackendConstructor* hybrid_creator()
{ {
vector<string> unsupported_0 = {"Add", "Max"}; class HybridBackendConstructor : public runtime::BackendConstructor
vector<string> unsupported_1 = {"Multiply"}; {
vector<shared_ptr<runtime::Backend>> backend_list = { public:
make_shared<runtime::interpreter::INTBackend>(unsupported_0), std::shared_ptr<runtime::Backend> create(const std::string& config) override
make_shared<runtime::cpu::CPU_Backend>()}; {
vector<string> unsupported_0 = {"Add", "Max"};
return new runtime::hybrid::HybridBackend(backend_list); vector<string> unsupported_1 = {"Multiply"};
vector<shared_ptr<runtime::Backend>> backend_list = {
make_shared<runtime::interpreter::INTBackend>(unsupported_0),
make_shared<runtime::cpu::CPU_Backend>()};
return make_shared<runtime::hybrid::HybridBackend>(backend_list);
}
};
static unique_ptr<runtime::BackendConstructor> s_backend_constructor(
new HybridBackendConstructor());
return s_backend_constructor.get();
} }
TEST(HYBRID, function_call) TEST(HYBRID, function_call)
...@@ -100,7 +111,7 @@ TEST(HYBRID, function_call) ...@@ -100,7 +111,7 @@ TEST(HYBRID, function_call)
TEST(HYBRID, abc) TEST(HYBRID, abc)
{ {
const string backend_name = "H1"; const string backend_name = "H1";
runtime::BackendManager::register_backend(backend_name, hybrid_creator); runtime::BackendManager::register_backend(backend_name, hybrid_creator());
Shape shape{2, 2}; Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape); auto A = make_shared<op::Parameter>(element::f32, shape);
...@@ -139,7 +150,7 @@ TEST(HYBRID, abc) ...@@ -139,7 +150,7 @@ TEST(HYBRID, abc)
TEST(HYBRID, simple) TEST(HYBRID, simple)
{ {
const string backend_name = "H1"; const string backend_name = "H1";
runtime::BackendManager::register_backend(backend_name, hybrid_creator); runtime::BackendManager::register_backend(backend_name, hybrid_creator());
Shape shape{2, 2}; Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::i8, shape); auto A = make_shared<op::Parameter>(element::i8, shape);
......
...@@ -150,7 +150,7 @@ namespace ngraph ...@@ -150,7 +150,7 @@ namespace ngraph
const std::shared_ptr<ngraph::runtime::Tensor>&)>; const std::shared_ptr<ngraph::runtime::Tensor>&)>;
std::shared_ptr<Function> m_function; std::shared_ptr<Function> m_function;
std::unique_ptr<runtime::Backend> m_backend; std::shared_ptr<runtime::Backend> m_backend;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors; std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_result_tensors; std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_result_tensors;
std::vector<std::shared_ptr<ngraph::op::Constant>> m_expected_outputs; std::vector<std::shared_ptr<ngraph::op::Constant>> m_expected_outputs;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment