Commit 3d41e5c6 authored by pruthvi's avatar pruthvi

Addressed PR comments

parent b171fe57
......@@ -34,20 +34,11 @@ runtime::AlignedBuffer::AlignedBuffer(size_t byte_size,
size_t alignment,
std::shared_ptr<ngraph::runtime::Allocator> allocator)
{
m_allocator = allocator;
m_byte_size = byte_size;
if (m_byte_size > 0)
{
size_t allocation_size = m_byte_size + alignment;
if (m_allocator)
{
m_allocated_buffer =
static_cast<char*>(m_allocator->Malloc(nullptr, allocation_size, alignment));
}
else
{
m_allocated_buffer = static_cast<char*>(ngraph_malloc(allocation_size));
}
m_allocated_buffer = static_cast<char*>(m_allocator->Malloc(allocation_size, alignment));
m_aligned_buffer = m_allocated_buffer;
size_t mod = size_t(m_aligned_buffer) % alignment;
......@@ -60,7 +51,6 @@ runtime::AlignedBuffer::AlignedBuffer(size_t byte_size,
{
m_allocated_buffer = nullptr;
m_aligned_buffer = nullptr;
m_allocator = nullptr;
}
}
......@@ -68,13 +58,6 @@ runtime::AlignedBuffer::~AlignedBuffer()
{
if (m_allocated_buffer != nullptr)
{
if (m_allocator)
{
m_allocator->Free(nullptr, m_allocated_buffer);
}
else
{
ngraph_free(m_allocated_buffer);
}
m_allocator->Free(m_allocated_buffer);
}
}
......@@ -36,7 +36,8 @@ class ngraph::runtime::AlignedBuffer
public:
AlignedBuffer(size_t byte_size,
size_t alignment,
std::shared_ptr<ngraph::runtime::Allocator> allocator = nullptr);
std::shared_ptr<ngraph::runtime::Allocator> allocator =
std::make_shared<runtime::Allocator>());
AlignedBuffer();
~AlignedBuffer();
......
......@@ -16,19 +16,19 @@
#include "ngraph/runtime/allocator.hpp"
void* ngraph::runtime::Allocator::Malloc(void* handle, size_t size, size_t alignment)
void* ngraph::runtime::Allocator::Malloc(size_t size, size_t alignment)
{
void* ptr = ngraph::aligned_alloc(alignment, size);
// check for exception
if (size != 0 && !ptr)
if (!ptr)
{
throw ngraph_error("malloc failed to allocate memory of size " + std::to_string(size));
}
return ptr;
}
void ngraph::runtime::Allocator::Free(void* handle, void* ptr)
void ngraph::runtime::Allocator::Free(void* ptr)
{
if (ptr)
{
......
......@@ -29,11 +29,18 @@ namespace ngraph
class Allocator;
}
}
// Abstract class for the allocator, for allocating and deallocating device memory
/// \brief Abstract class for the allocator, for allocating and deallocating device memory
class ngraph::runtime::Allocator
{
public:
virtual ~Allocator() = default;
virtual void* Malloc(void* handle, size_t size, size_t alignment);
virtual void Free(void* handle, void* ptr);
/// \brief allocates the memory on the device with the given size and alignment requirement
/// \param exact size of bytes to allocate
/// \param alignment specifies the alignment. Must be a valid alignment supported by the implementation.
virtual void* Malloc(size_t size, size_t alignment);
/// \brief deallocates the memory pointed by ptr
/// \param ptr pointer to the aligned memory to be released
virtual void Free(void* ptr);
};
......@@ -71,19 +71,27 @@ std::shared_ptr<ngraph::runtime::Allocator> runtime::Backend::get_framework_memo
void runtime::Backend::set_framework_memory_allocator(
const std::shared_ptr<ngraph::runtime::Allocator>& allocator)
{
// override this method from all supported backends to set its memory allocator to
// framework passed memory allocator
}
ngraph::runtime::AllocateFunc runtime::Backend::get_device_memory_alloc()
{
// override this method from all supported backends to return memory allocator
// which allocates device pinned memory
return nullptr;
}
ngraph::runtime::DestroyFunc runtime::Backend::get_device_memory_dealloc()
{
// override this method from all supported backends to return memory de-allocator
// which de-allocates device pinned memory
return nullptr;
}
bool runtime::Backend::is_device_memory()
bool runtime::Backend::is_device_memory(void* ptr)
{
// override this method for each supported backend to determine if the passed pointer is in
// device pinned memory or not
return false;
}
......@@ -117,9 +117,18 @@ public:
virtual void remove_compiled_function(std::shared_ptr<Executable> exec);
virtual std::shared_ptr<ngraph::runtime::Allocator> get_framework_memory_allocator();
/// \brief method for the framework to pass its memory allocator object to the backend.
virtual void set_framework_memory_allocator(
const std::shared_ptr<ngraph::runtime::Allocator>& allocator);
/// \brief method to return memory de-allocator which de-allocates device pinned memory
virtual ngraph::runtime::AllocateFunc get_device_memory_alloc();
/// \brief method to return memory allocator which allocates device pinned memory
virtual ngraph::runtime::DestroyFunc get_device_memory_dealloc();
virtual bool is_device_memory();
/// \brief method for each supported backend to determine if the passed pointer is in device pinned memory or not
/// \param ptr pointer to the memory to determine if its in device memory or not
virtual bool is_device_memory(void* ptr);
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment