Commit 53744aad authored by Jayaram Bobba's avatar Jayaram Bobba

Remove static default allocator and create one on demand in the cpu backend

parent 2cdb76cc
...@@ -32,29 +32,20 @@ runtime::AlignedBuffer::AlignedBuffer() ...@@ -32,29 +32,20 @@ runtime::AlignedBuffer::AlignedBuffer()
{ {
} }
runtime::AlignedBuffer::AlignedBuffer(size_t byte_size, size_t alignment)
: m_allocator(nullptr)
, m_byte_size(byte_size)
{
m_byte_size = std::max<size_t>(1, byte_size);
size_t allocation_size = m_byte_size + alignment;
m_allocated_buffer = static_cast<char*>(malloc(allocation_size));
m_aligned_buffer = m_allocated_buffer;
size_t mod = size_t(m_aligned_buffer) % alignment;
if (mod != 0)
{
m_aligned_buffer += (alignment - mod);
}
}
runtime::AlignedBuffer::AlignedBuffer(size_t byte_size, size_t alignment, Allocator* allocator) runtime::AlignedBuffer::AlignedBuffer(size_t byte_size, size_t alignment, Allocator* allocator)
: m_allocator(allocator) : m_allocator(allocator)
, m_byte_size(byte_size) , m_byte_size(byte_size)
{ {
m_byte_size = std::max<size_t>(1, byte_size); m_byte_size = std::max<size_t>(1, byte_size);
size_t allocation_size = m_byte_size + alignment; size_t allocation_size = m_byte_size + alignment;
m_allocated_buffer = static_cast<char*>(m_allocator->malloc(allocation_size, alignment)); if (allocator)
{
m_allocated_buffer = static_cast<char*>(m_allocator->malloc(allocation_size, alignment));
}
else
{
m_allocated_buffer = static_cast<char*>(malloc(allocation_size));
}
m_aligned_buffer = m_allocated_buffer; m_aligned_buffer = m_allocated_buffer;
size_t mod = size_t(m_aligned_buffer) % alignment; size_t mod = size_t(m_aligned_buffer) % alignment;
......
...@@ -34,11 +34,10 @@ namespace ngraph ...@@ -34,11 +34,10 @@ namespace ngraph
class ngraph::runtime::AlignedBuffer class ngraph::runtime::AlignedBuffer
{ {
public: public:
AlignedBuffer(size_t byte_size, size_t alignment);
// Allocator objects and the allocation interfaces are owned by the // Allocator objects and the allocation interfaces are owned by the
// creators of AlignedBuffers. They need to ensure that the lifetime of // creators of AlignedBuffers. They need to ensure that the lifetime of
// allocator exceeds the lifetime of this AlignedBuffer. // allocator exceeds the lifetime of this AlignedBuffer.
AlignedBuffer(size_t byte_size, size_t alignment, Allocator* allocator); AlignedBuffer(size_t byte_size, size_t alignment, Allocator* allocator = nullptr);
AlignedBuffer(); AlignedBuffer();
~AlignedBuffer(); ~AlignedBuffer();
......
...@@ -49,8 +49,7 @@ public: ...@@ -49,8 +49,7 @@ public:
} }
}; };
ngraph::runtime::Allocator* ngraph::runtime::get_default_allocator() std::unique_ptr<ngraph::runtime::Allocator> ngraph::runtime::create_default_allocator()
{ {
static std::unique_ptr<DefaultAllocator> allocator(new DefaultAllocator()); return std::unique_ptr<DefaultAllocator>(new DefaultAllocator());
return allocator.get();
} }
...@@ -28,9 +28,9 @@ namespace ngraph ...@@ -28,9 +28,9 @@ namespace ngraph
{ {
class Allocator; class Allocator;
class DefaultAllocator; class DefaultAllocator;
/// \brief Returns a pointer to a statically allocated singleton /// \brief Create a default allocator that calls into system
/// allocator that calls into system allocation libraries /// allocation libraries
Allocator* get_default_allocator(); std::unique_ptr<Allocator> create_default_allocator();
} }
} }
......
...@@ -143,11 +143,7 @@ public: ...@@ -143,11 +143,7 @@ public:
virtual std::shared_ptr<ngraph::Node> get_backend_op(const std::string& op_name, ...); virtual std::shared_ptr<ngraph::Node> get_backend_op(const std::string& op_name, ...);
/// \brief Returns memory allocator used by backend for host allocations /// \brief Returns memory allocator used by backend for host allocations
virtual Allocator* get_host_memory_allocator() virtual Allocator* get_host_memory_allocator() { return nullptr; }
{
return ngraph::runtime::get_default_allocator();
}
/// \brief Set the host memory allocator to be used by the backend /// \brief Set the host memory allocator to be used by the backend
/// \param allocator is pointer to host memory allocator object /// \param allocator is pointer to host memory allocator object
virtual void set_host_memory_allocator(std::unique_ptr<Allocator> allocator) {} virtual void set_host_memory_allocator(std::unique_ptr<Allocator> allocator) {}
......
...@@ -59,6 +59,11 @@ namespace ...@@ -59,6 +59,11 @@ namespace
} s_cpu_static_init; } s_cpu_static_init;
} }
runtime::cpu::CPU_Backend::~CPU_Backend()
{
m_exec_map.clear();
}
shared_ptr<runtime::cpu::CPU_CallFrame> runtime::cpu::CPU_Backend::make_call_frame( shared_ptr<runtime::cpu::CPU_CallFrame> runtime::cpu::CPU_Backend::make_call_frame(
const shared_ptr<runtime::cpu::CPU_ExternalFunction>& external_function, const shared_ptr<runtime::cpu::CPU_ExternalFunction>& external_function,
ngraph::pass::PassConfig& pass_config, ngraph::pass::PassConfig& pass_config,
...@@ -159,19 +164,23 @@ void runtime::cpu::CPU_Backend::remove_compiled_function(shared_ptr<Executable> ...@@ -159,19 +164,23 @@ void runtime::cpu::CPU_Backend::remove_compiled_function(shared_ptr<Executable>
runtime::Allocator* runtime::cpu::CPU_Backend::get_host_memory_allocator() runtime::Allocator* runtime::cpu::CPU_Backend::get_host_memory_allocator()
{ {
if (m_allocator) if (!m_allocator)
{
return m_allocator.get();
}
else
{ {
return runtime::get_default_allocator(); m_allocator = std::move(create_default_allocator());
} }
return m_allocator.get();
} }
void runtime::cpu::CPU_Backend::set_host_memory_allocator( void runtime::cpu::CPU_Backend::set_host_memory_allocator(
std::unique_ptr<runtime::Allocator> allocator) std::unique_ptr<runtime::Allocator> allocator)
{ {
if (m_allocator)
{
// Resources allocated with the existing allocator might still be around and expect it
// to be available for freeing. We cannot switch to the new allocator
throw ngraph_error(
"Allocator already exists. Changing allocators mid-execution is not permitted.");
}
m_allocator = std::move(allocator); m_allocator = std::move(allocator);
} }
......
...@@ -36,6 +36,8 @@ namespace ngraph ...@@ -36,6 +36,8 @@ namespace ngraph
class CPU_BACKEND_API CPU_Backend : public runtime::Backend class CPU_BACKEND_API CPU_Backend : public runtime::Backend
{ {
public: public:
~CPU_Backend() override;
std::shared_ptr<CPU_CallFrame> std::shared_ptr<CPU_CallFrame>
make_call_frame(const std::shared_ptr<CPU_ExternalFunction>& external_function, make_call_frame(const std::shared_ptr<CPU_ExternalFunction>& external_function,
ngraph::pass::PassConfig& pass_config, ngraph::pass::PassConfig& pass_config,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment