Commit 3d41e5c6 authored by pruthvi's avatar pruthvi

Addressed PR comments

parent b171fe57
...@@ -34,20 +34,11 @@ runtime::AlignedBuffer::AlignedBuffer(size_t byte_size, ...@@ -34,20 +34,11 @@ runtime::AlignedBuffer::AlignedBuffer(size_t byte_size,
size_t alignment, size_t alignment,
std::shared_ptr<ngraph::runtime::Allocator> allocator) std::shared_ptr<ngraph::runtime::Allocator> allocator)
{ {
m_allocator = allocator;
m_byte_size = byte_size; m_byte_size = byte_size;
if (m_byte_size > 0) if (m_byte_size > 0)
{ {
size_t allocation_size = m_byte_size + alignment; size_t allocation_size = m_byte_size + alignment;
if (m_allocator) m_allocated_buffer = static_cast<char*>(m_allocator->Malloc(allocation_size, alignment));
{
m_allocated_buffer =
static_cast<char*>(m_allocator->Malloc(nullptr, allocation_size, alignment));
}
else
{
m_allocated_buffer = static_cast<char*>(ngraph_malloc(allocation_size));
}
m_aligned_buffer = m_allocated_buffer; m_aligned_buffer = m_allocated_buffer;
size_t mod = size_t(m_aligned_buffer) % alignment; size_t mod = size_t(m_aligned_buffer) % alignment;
...@@ -60,7 +51,6 @@ runtime::AlignedBuffer::AlignedBuffer(size_t byte_size, ...@@ -60,7 +51,6 @@ runtime::AlignedBuffer::AlignedBuffer(size_t byte_size,
{ {
m_allocated_buffer = nullptr; m_allocated_buffer = nullptr;
m_aligned_buffer = nullptr; m_aligned_buffer = nullptr;
m_allocator = nullptr;
} }
} }
...@@ -68,13 +58,6 @@ runtime::AlignedBuffer::~AlignedBuffer() ...@@ -68,13 +58,6 @@ runtime::AlignedBuffer::~AlignedBuffer()
{ {
if (m_allocated_buffer != nullptr) if (m_allocated_buffer != nullptr)
{ {
if (m_allocator) m_allocator->Free(m_allocated_buffer);
{
m_allocator->Free(nullptr, m_allocated_buffer);
}
else
{
ngraph_free(m_allocated_buffer);
}
} }
} }
...@@ -36,7 +36,8 @@ class ngraph::runtime::AlignedBuffer ...@@ -36,7 +36,8 @@ class ngraph::runtime::AlignedBuffer
public: public:
AlignedBuffer(size_t byte_size, AlignedBuffer(size_t byte_size,
size_t alignment, size_t alignment,
std::shared_ptr<ngraph::runtime::Allocator> allocator = nullptr); std::shared_ptr<ngraph::runtime::Allocator> allocator =
std::make_shared<runtime::Allocator>());
AlignedBuffer(); AlignedBuffer();
~AlignedBuffer(); ~AlignedBuffer();
......
...@@ -16,19 +16,19 @@ ...@@ -16,19 +16,19 @@
#include "ngraph/runtime/allocator.hpp" #include "ngraph/runtime/allocator.hpp"
void* ngraph::runtime::Allocator::Malloc(void* handle, size_t size, size_t alignment) void* ngraph::runtime::Allocator::Malloc(size_t size, size_t alignment)
{ {
void* ptr = ngraph::aligned_alloc(alignment, size); void* ptr = ngraph::aligned_alloc(alignment, size);
// check for exception // check for exception
if (size != 0 && !ptr) if (!ptr)
{ {
throw ngraph_error("malloc failed to allocate memory of size " + std::to_string(size)); throw ngraph_error("malloc failed to allocate memory of size " + std::to_string(size));
} }
return ptr; return ptr;
} }
void ngraph::runtime::Allocator::Free(void* handle, void* ptr) void ngraph::runtime::Allocator::Free(void* ptr)
{ {
if (ptr) if (ptr)
{ {
......
...@@ -29,11 +29,18 @@ namespace ngraph ...@@ -29,11 +29,18 @@ namespace ngraph
class Allocator; class Allocator;
} }
} }
// Abstract class for the allocator, for allocating and deallocating device memory
/// \brief Abstract class for the allocator, for allocating and deallocating device memory
class ngraph::runtime::Allocator class ngraph::runtime::Allocator
{ {
public: public:
virtual ~Allocator() = default; virtual ~Allocator() = default;
virtual void* Malloc(void* handle, size_t size, size_t alignment); /// \brief allocates the memory on the device with the given size and alignment requirement
virtual void Free(void* handle, void* ptr); /// \param exact size of bytes to allocate
/// \param alignment specifies the alignment. Must be a valid alignment supported by the implementation.
virtual void* Malloc(size_t size, size_t alignment);
/// \brief deallocates the memory pointed by ptr
/// \param ptr pointer to the aligned memory to be released
virtual void Free(void* ptr);
}; };
...@@ -71,19 +71,27 @@ std::shared_ptr<ngraph::runtime::Allocator> runtime::Backend::get_framework_memo ...@@ -71,19 +71,27 @@ std::shared_ptr<ngraph::runtime::Allocator> runtime::Backend::get_framework_memo
void runtime::Backend::set_framework_memory_allocator( void runtime::Backend::set_framework_memory_allocator(
const std::shared_ptr<ngraph::runtime::Allocator>& allocator) const std::shared_ptr<ngraph::runtime::Allocator>& allocator)
{ {
// override this method from all supported backends to set its memory allocator to
// framework passed memory allocator
} }
ngraph::runtime::AllocateFunc runtime::Backend::get_device_memory_alloc() ngraph::runtime::AllocateFunc runtime::Backend::get_device_memory_alloc()
{ {
// override this method from all supported backends to return memory allocator
// which allocates device pinned memory
return nullptr; return nullptr;
} }
ngraph::runtime::DestroyFunc runtime::Backend::get_device_memory_dealloc() ngraph::runtime::DestroyFunc runtime::Backend::get_device_memory_dealloc()
{ {
// override this method from all supported backends to return memory de-allocator
// which de-allocates device pinned memory
return nullptr; return nullptr;
} }
bool runtime::Backend::is_device_memory() bool runtime::Backend::is_device_memory(void* ptr)
{ {
// override this method for each supported backend to determine if the passed pointer is in
// device pinned memory or not
return false; return false;
} }
...@@ -117,9 +117,18 @@ public: ...@@ -117,9 +117,18 @@ public:
virtual void remove_compiled_function(std::shared_ptr<Executable> exec); virtual void remove_compiled_function(std::shared_ptr<Executable> exec);
virtual std::shared_ptr<ngraph::runtime::Allocator> get_framework_memory_allocator(); virtual std::shared_ptr<ngraph::runtime::Allocator> get_framework_memory_allocator();
/// \brief method for the framework to pass its memory allocator object to the backend.
virtual void set_framework_memory_allocator( virtual void set_framework_memory_allocator(
const std::shared_ptr<ngraph::runtime::Allocator>& allocator); const std::shared_ptr<ngraph::runtime::Allocator>& allocator);
/// \brief method to return memory de-allocator which de-allocates device pinned memory
virtual ngraph::runtime::AllocateFunc get_device_memory_alloc(); virtual ngraph::runtime::AllocateFunc get_device_memory_alloc();
/// \brief method to return memory allocator which allocates device pinned memory
virtual ngraph::runtime::DestroyFunc get_device_memory_dealloc(); virtual ngraph::runtime::DestroyFunc get_device_memory_dealloc();
virtual bool is_device_memory();
/// \brief method for each supported backend to determine if the passed pointer is in device pinned memory or not
/// \param ptr pointer to the memory to determine if its in device memory or not
virtual bool is_device_memory(void* ptr);
}; };
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment