Commit d676fb77 authored by pruthvi's avatar pruthvi

reseat mkl funnction pointers to ngraph functions

parent ec652d77
......@@ -35,11 +35,12 @@ runtime::cpu::CPUAlignedBuffer::CPUAlignedBuffer(size_t byte_size,
ngraph::runtime::cpu::CPUAllocator& cpu_allocator)
{
m_byte_size = byte_size;
m_cpu_allocator = cpu_allocator;
AllocateFunc allocator = ngraph::runtime::cpu::CPUAllocator::m_framework_allocator;
if (m_byte_size > 0)
{
size_t allocation_size = m_byte_size + alignment;
m_allocated_buffer = static_cast<char*>(m_cpu_allocator.cpu_malloc(allocation_size));
m_allocated_buffer = static_cast<char*>(
ngraph::runtime::cpu::cpu_malloc(allocation_size, alignment, allocator));
m_aligned_buffer = m_allocated_buffer;
size_t mod = size_t(m_aligned_buffer) % alignment;
......@@ -57,8 +58,9 @@ runtime::cpu::CPUAlignedBuffer::CPUAlignedBuffer(size_t byte_size,
runtime::cpu::CPUAlignedBuffer::~CPUAlignedBuffer()
{
DestroyFunc deallocator = ngraph::runtime::cpu::CPUAllocator::m_framework_deallocator;
if (m_allocated_buffer != nullptr)
{
m_cpu_allocator.cpu_free(m_allocated_buffer);
ngraph::runtime::cpu::cpu_free(m_allocated_buffer, deallocator);
}
}
......@@ -19,29 +19,29 @@
#include "ngraph/except.hpp"
ngraph::runtime::cpu::CPUAllocator::CPUAllocator()
: m_framework_allocator(nullptr)
, m_framework_deallocator(nullptr)
, m_alignment(0)
{
}
AllocateFunc ngraph::runtime::cpu::CPUAllocator::m_framework_allocator = nullptr;
DestroyFunc ngraph::runtime::cpu::CPUAllocator::m_framework_deallocator = nullptr;
size_t ngraph::runtime::cpu::CPUAllocator::m_alignment = 4096;
ngraph::runtime::cpu::CPUAllocator::CPUAllocator(AllocateFunc allocator,
DestroyFunc deallocator,
size_t alignment)
: m_framework_allocator(allocator)
, m_framework_deallocator(deallocator)
, m_alignment(alignment)
{
// mkl::i_malloc = MallocHook;
// mkl::i_free = FreeHook;
mkl::i_malloc = MallocHook;
mkl::i_free = FreeHook;
}
void* ngraph::runtime::cpu::CPUAllocator::cpu_malloc(size_t size)
void* ngraph::runtime::cpu::cpu_malloc(size_t size,
size_t alignment,
AllocateFunc framework_allocator)
{
void* ptr;
if (m_framework_allocator != nullptr)
if (framework_allocator != nullptr)
{
ptr = m_framework_allocator(nullptr, m_alignment, size);
ptr = framework_allocator(nullptr, alignment, size);
}
else
{
......@@ -57,11 +57,11 @@ void* ngraph::runtime::cpu::CPUAllocator::cpu_malloc(size_t size)
return ptr;
}
void ngraph::runtime::cpu::CPUAllocator::cpu_free(void* ptr)
void ngraph::runtime::cpu::cpu_free(void* ptr, DestroyFunc framework_deallocator)
{
if (m_framework_deallocator && ptr)
if (framework_deallocator && ptr)
{
m_framework_deallocator(nullptr, ptr);
framework_deallocator(nullptr, ptr);
}
else if (ptr)
{
......
......@@ -47,6 +47,8 @@ namespace ngraph
namespace cpu
{
class CPUAllocator;
void* cpu_malloc(size_t size, size_t alignment, AllocateFunc framework_allocator);
void cpu_free(void* ptr, DestroyFunc framework_deallocator);
}
}
}
......@@ -58,21 +60,18 @@ public:
CPUAllocator();
~CPUAllocator();
void* cpu_malloc(size_t size);
void cpu_free(void* ptr);
static AllocateFunc m_framework_allocator;
static DestroyFunc m_framework_deallocator;
static size_t m_alignment;
private:
size_t m_alignment;
AllocateFunc m_framework_allocator;
DestroyFunc m_framework_deallocator;
/*static inline void* MallocHook(size_t size)
static inline void* MallocHook(size_t size)
{
ngraph::runtime::cpu::GetCPUAllocator().cpu_malloc(size);
ngraph::runtime::cpu::cpu_malloc(size, m_alignment, m_framework_allocator);
}
static inline void FreeHook(void* ptr)
{
ngraph::runtime::cpu::GetCPUAllocator().cpu_free(ptr);
}*/
ngraph::runtime::cpu::cpu_free(ptr, m_framework_deallocator);
}
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment