Commit d676fb77 authored by pruthvi's avatar pruthvi

reseat mkl funnction pointers to ngraph functions

parent ec652d77
...@@ -35,11 +35,12 @@ runtime::cpu::CPUAlignedBuffer::CPUAlignedBuffer(size_t byte_size, ...@@ -35,11 +35,12 @@ runtime::cpu::CPUAlignedBuffer::CPUAlignedBuffer(size_t byte_size,
ngraph::runtime::cpu::CPUAllocator& cpu_allocator) ngraph::runtime::cpu::CPUAllocator& cpu_allocator)
{ {
m_byte_size = byte_size; m_byte_size = byte_size;
m_cpu_allocator = cpu_allocator; AllocateFunc allocator = ngraph::runtime::cpu::CPUAllocator::m_framework_allocator;
if (m_byte_size > 0) if (m_byte_size > 0)
{ {
size_t allocation_size = m_byte_size + alignment; size_t allocation_size = m_byte_size + alignment;
m_allocated_buffer = static_cast<char*>(m_cpu_allocator.cpu_malloc(allocation_size)); m_allocated_buffer = static_cast<char*>(
ngraph::runtime::cpu::cpu_malloc(allocation_size, alignment, allocator));
m_aligned_buffer = m_allocated_buffer; m_aligned_buffer = m_allocated_buffer;
size_t mod = size_t(m_aligned_buffer) % alignment; size_t mod = size_t(m_aligned_buffer) % alignment;
...@@ -57,8 +58,9 @@ runtime::cpu::CPUAlignedBuffer::CPUAlignedBuffer(size_t byte_size, ...@@ -57,8 +58,9 @@ runtime::cpu::CPUAlignedBuffer::CPUAlignedBuffer(size_t byte_size,
runtime::cpu::CPUAlignedBuffer::~CPUAlignedBuffer() runtime::cpu::CPUAlignedBuffer::~CPUAlignedBuffer()
{ {
DestroyFunc deallocator = ngraph::runtime::cpu::CPUAllocator::m_framework_deallocator;
if (m_allocated_buffer != nullptr) if (m_allocated_buffer != nullptr)
{ {
m_cpu_allocator.cpu_free(m_allocated_buffer); ngraph::runtime::cpu::cpu_free(m_allocated_buffer, deallocator);
} }
} }
...@@ -19,29 +19,29 @@ ...@@ -19,29 +19,29 @@
#include "ngraph/except.hpp" #include "ngraph/except.hpp"
ngraph::runtime::cpu::CPUAllocator::CPUAllocator() ngraph::runtime::cpu::CPUAllocator::CPUAllocator()
: m_framework_allocator(nullptr)
, m_framework_deallocator(nullptr)
, m_alignment(0)
{ {
} }
AllocateFunc ngraph::runtime::cpu::CPUAllocator::m_framework_allocator = nullptr;
DestroyFunc ngraph::runtime::cpu::CPUAllocator::m_framework_deallocator = nullptr;
size_t ngraph::runtime::cpu::CPUAllocator::m_alignment = 4096;
ngraph::runtime::cpu::CPUAllocator::CPUAllocator(AllocateFunc allocator, ngraph::runtime::cpu::CPUAllocator::CPUAllocator(AllocateFunc allocator,
DestroyFunc deallocator, DestroyFunc deallocator,
size_t alignment) size_t alignment)
: m_framework_allocator(allocator)
, m_framework_deallocator(deallocator)
, m_alignment(alignment)
{ {
// mkl::i_malloc = MallocHook; mkl::i_malloc = MallocHook;
// mkl::i_free = FreeHook; mkl::i_free = FreeHook;
} }
void* ngraph::runtime::cpu::CPUAllocator::cpu_malloc(size_t size) void* ngraph::runtime::cpu::cpu_malloc(size_t size,
size_t alignment,
AllocateFunc framework_allocator)
{ {
void* ptr; void* ptr;
if (m_framework_allocator != nullptr) if (framework_allocator != nullptr)
{ {
ptr = m_framework_allocator(nullptr, m_alignment, size); ptr = framework_allocator(nullptr, alignment, size);
} }
else else
{ {
...@@ -57,11 +57,11 @@ void* ngraph::runtime::cpu::CPUAllocator::cpu_malloc(size_t size) ...@@ -57,11 +57,11 @@ void* ngraph::runtime::cpu::CPUAllocator::cpu_malloc(size_t size)
return ptr; return ptr;
} }
void ngraph::runtime::cpu::CPUAllocator::cpu_free(void* ptr) void ngraph::runtime::cpu::cpu_free(void* ptr, DestroyFunc framework_deallocator)
{ {
if (m_framework_deallocator && ptr) if (framework_deallocator && ptr)
{ {
m_framework_deallocator(nullptr, ptr); framework_deallocator(nullptr, ptr);
} }
else if (ptr) else if (ptr)
{ {
......
...@@ -47,6 +47,8 @@ namespace ngraph ...@@ -47,6 +47,8 @@ namespace ngraph
namespace cpu namespace cpu
{ {
class CPUAllocator; class CPUAllocator;
void* cpu_malloc(size_t size, size_t alignment, AllocateFunc framework_allocator);
void cpu_free(void* ptr, DestroyFunc framework_deallocator);
} }
} }
} }
...@@ -58,21 +60,18 @@ public: ...@@ -58,21 +60,18 @@ public:
CPUAllocator(); CPUAllocator();
~CPUAllocator(); ~CPUAllocator();
void* cpu_malloc(size_t size); static AllocateFunc m_framework_allocator;
void cpu_free(void* ptr); static DestroyFunc m_framework_deallocator;
static size_t m_alignment;
private: private:
size_t m_alignment; static inline void* MallocHook(size_t size)
AllocateFunc m_framework_allocator;
DestroyFunc m_framework_deallocator;
/*static inline void* MallocHook(size_t size)
{ {
ngraph::runtime::cpu::GetCPUAllocator().cpu_malloc(size); ngraph::runtime::cpu::cpu_malloc(size, m_alignment, m_framework_allocator);
} }
static inline void FreeHook(void* ptr) static inline void FreeHook(void* ptr)
{ {
ngraph::runtime::cpu::GetCPUAllocator().cpu_free(ptr); ngraph::runtime::cpu::cpu_free(ptr, m_framework_deallocator);
}*/ }
}; };
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment