Commit 273f9aed authored by pruthvi's avatar pruthvi

- add support for the backend api to accept framework allocators and deallocators

- changes in the call frame to assign framework allocators to cpu_allocator
parent afaa43fd
......@@ -52,9 +52,12 @@ namespace
shared_ptr<runtime::cpu::CPU_CallFrame> runtime::cpu::CPU_Backend::make_call_frame(
const shared_ptr<runtime::cpu::CPU_ExternalFunction>& external_function,
ngraph::pass::PassConfig& pass_config)
ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator,
DestroyFunc& framework_deallocator)
{
return external_function->make_call_frame(pass_config);
return external_function->make_call_frame(
pass_config, framework_allocator, framework_deallocator);
}
shared_ptr<runtime::Tensor>
......@@ -79,7 +82,9 @@ shared_ptr<runtime::Executable>
shared_ptr<runtime::Executable>
runtime::cpu::CPU_Backend::compile(shared_ptr<Function> func,
ngraph::pass::PassConfig& pass_config,
bool performance_counters_enabled)
bool performance_counters_enabled,
AllocateFunc framework_allocator,
DestroyFunc framework_deallocator)
{
shared_ptr<runtime::Executable> rc;
auto it = m_exec_map.find(func);
......@@ -89,7 +94,11 @@ shared_ptr<runtime::Executable>
}
else
{
rc = make_shared<CPU_Executable>(func, pass_config, performance_counters_enabled);
rc = make_shared<CPU_Executable>(func,
pass_config,
performance_counters_enabled,
framework_allocator,
framework_deallocator);
m_exec_map.insert({func, rc});
}
return rc;
......@@ -97,14 +106,17 @@ shared_ptr<runtime::Executable>
runtime::cpu::CPU_Executable::CPU_Executable(shared_ptr<Function> func,
ngraph::pass::PassConfig& pass_config,
bool performance_counters_enabled)
bool performance_counters_enabled,
AllocateFunc& framework_allocator,
DestroyFunc& framework_deallocator)
{
FunctionInstance& instance = m_function_instance;
if (instance.m_external_function == nullptr)
{
instance.m_external_function = make_shared<CPU_ExternalFunction>(func);
instance.m_external_function->m_emit_timing = performance_counters_enabled;
auto cf = instance.m_external_function->make_call_frame(pass_config);
auto cf = instance.m_external_function->make_call_frame(
pass_config, framework_allocator, framework_deallocator);
instance.m_call_frame = dynamic_pointer_cast<CPU_CallFrame>(cf);
}
set_parameters_and_results(*func);
......
......@@ -22,6 +22,7 @@
#include "cpu_backend_visibility.h"
#include "ngraph/pass/pass_config.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/cpu/cpu_mkl_allocator.hpp"
namespace ngraph
{
......@@ -37,7 +38,9 @@ namespace ngraph
public:
std::shared_ptr<CPU_CallFrame>
make_call_frame(const std::shared_ptr<CPU_ExternalFunction>& external_function,
ngraph::pass::PassConfig& pass_config);
ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator,
DestroyFunc& framework_deallocator);
std::shared_ptr<ngraph::runtime::Tensor>
create_tensor(const ngraph::element::Type& element_type,
......@@ -55,7 +58,9 @@ namespace ngraph
std::shared_ptr<ngraph::runtime::Executable>
compile(std::shared_ptr<Function> func,
ngraph::pass::PassConfig& pass_config,
bool enable_performance_counters = false);
bool enable_performance_counters = false,
AllocateFunc framework_allocator = nullptr,
DestroyFunc framework_deallocator = nullptr);
void remove_compiled_function(std::shared_ptr<Executable> exec) override;
......@@ -72,7 +77,9 @@ namespace ngraph
public:
CPU_Executable(std::shared_ptr<Function> func,
ngraph::pass::PassConfig& pass_config,
bool performance_counters_enabled);
bool performance_counters_enabled,
AllocateFunc& framework_allocator,
DestroyFunc& framework_deallocator);
bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override;
......
......@@ -33,8 +33,12 @@ using namespace ngraph;
runtime::cpu::CPU_CallFrame::CPU_CallFrame(std::shared_ptr<CPU_ExternalFunction> external_function,
InitContextFuncCG compiled_init_ctx_func,
DestroyContextFuncCG compiled_destroy_ctx_func,
EntryPoint compiled_function)
EntryPoint compiled_function,
AllocateFunc& framework_allocator,
DestroyFunc& framework_deallocator)
: m_external_function(external_function)
, m_framework_allocator(framework_allocator)
, m_framework_deallocator(framework_deallocator)
, m_compiled_init_ctx_func(compiled_init_ctx_func)
, m_compiled_destroy_ctx_func(compiled_destroy_ctx_func)
, m_compiled_function(compiled_function)
......@@ -140,8 +144,10 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context()
// Create temporary buffer pools
size_t alignment = runtime::cpu::CPU_ExternalFunction::s_memory_pool_alignment;
ngraph::runtime::cpu::CPUAllocator::framework_allocator = nullptr;
ngraph::runtime::cpu::CPUAllocator::framework_deallocator = nullptr;
// assign the passed memory allocators
ngraph::runtime::cpu::CPUAllocator::framework_allocator = m_framework_allocator;
ngraph::runtime::cpu::CPUAllocator::framework_deallocator = m_framework_deallocator;
ngraph::runtime::cpu::CPUAllocator::alignment = alignment;
for (auto buffer_size : m_external_function->get_memory_buffer_sizes())
......
......@@ -23,6 +23,7 @@
#include "ngraph/function.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_mkl_allocator.hpp"
#include "ngraph/runtime/cpu/cpu_runtime_context.hpp"
#include "ngraph/runtime/tensor.hpp"
......@@ -57,7 +58,10 @@ namespace ngraph
CPU_CallFrame(std::shared_ptr<CPU_ExternalFunction> external_function,
InitContextFuncCG compiled_init_ctx_func,
DestroyContextFuncCG compiled_destroy_ctx_func,
EntryPoint compiled_function);
EntryPoint compiled_function,
AllocateFunc& framework_allocator,
DestroyFunc& framework_deallocator);
~CPU_CallFrame();
/// \brief Invoke the function with values matching the signature of the function.
......@@ -85,6 +89,10 @@ namespace ngraph
CPURuntimeContext* ctx = nullptr;
// memeber function pointers to hold the framework allocators
AllocateFunc m_framework_allocator;
DestroyFunc m_framework_deallocator;
/* Codegen specific */
/// Function that initializes the context used in codegen mode.
......
......@@ -141,6 +141,7 @@
#include "ngraph/runtime/cpu/cpu_emitter.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/cpu/cpu_mkl_allocator.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view.hpp"
#include "ngraph/runtime/cpu/cpu_tracing.hpp"
......@@ -1701,7 +1702,9 @@ void*& runtime::cpu::CPU_ExternalFunction::get_tensor_data(const std::string& na
}
shared_ptr<ngraph::runtime::cpu::CPU_CallFrame>
runtime::cpu::CPU_ExternalFunction::make_call_frame(ngraph::pass::PassConfig& pass_config)
runtime::cpu::CPU_ExternalFunction::make_call_frame(ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator,
DestroyFunc& framework_deallocator)
{
#if !defined(NGRAPH_DEX_ONLY)
if (!m_is_compiled && !m_direct_execution)
......@@ -1718,7 +1721,9 @@ shared_ptr<ngraph::runtime::cpu::CPU_CallFrame>
return make_shared<ngraph::runtime::cpu::CPU_CallFrame>(shared_from_this(),
m_compiled_init_ctx_func,
m_compiled_destroy_ctx_func,
m_compiled_function);
m_compiled_function,
framework_allocator,
framework_deallocator);
}
const runtime::cpu::LayoutDescriptorPtrs&
......
......@@ -45,6 +45,7 @@
#include "ngraph/pass/pass_config.hpp"
#include "ngraph/runtime/cpu/cpu_call_frame.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_mkl_allocator.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#include "ngraph/runtime/cpu/mkldnn_emitter.hpp"
#include "ngraph/runtime/performance_counter.hpp"
......@@ -100,7 +101,9 @@ namespace ngraph
bool release_function = true);
~CPU_ExternalFunction();
std::shared_ptr<ngraph::runtime::cpu::CPU_CallFrame>
make_call_frame(ngraph::pass::PassConfig& pass_config);
make_call_frame(ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator,
DestroyFunc& framework_deallocator);
const LayoutDescriptorPtrs& get_parameter_layout_descriptors();
const LayoutDescriptorPtrs& get_result_layout_descriptors();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment