Commit 9e54e66f authored by pruthvi

- Addressed PR comments

- Receive the framework memory allocators as function pointers passed by value instead of as references to function pointers
parent e762262e
...@@ -29,7 +29,7 @@ namespace ngraph ...@@ -29,7 +29,7 @@ namespace ngraph
class Allocator; class Allocator;
} }
} }
// Abstarct class for the allocator, for allocating and deallocating device memory // Abstract class for the allocator, for allocating and deallocating device memory
class ngraph::runtime::Allocator class ngraph::runtime::Allocator
{ {
public: public:
......
...@@ -42,8 +42,8 @@ vector<string> runtime::Backend::get_registered_devices() ...@@ -42,8 +42,8 @@ vector<string> runtime::Backend::get_registered_devices()
std::shared_ptr<runtime::Executable> std::shared_ptr<runtime::Executable>
runtime::Backend::compile(std::shared_ptr<Function> func, runtime::Backend::compile(std::shared_ptr<Function> func,
ngraph::pass::PassConfig& pass_config, ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator, AllocateFunc memory_allocator,
DestroyFunc& framework_deallocator, DestroyFunc memory_deallocator,
bool enable_performance_data) bool enable_performance_data)
{ {
return compile(func, enable_performance_data); return compile(func, enable_performance_data);
......
...@@ -95,8 +95,8 @@ public: ...@@ -95,8 +95,8 @@ public:
/// \returns compiled function or nullptr on failure /// \returns compiled function or nullptr on failure
virtual std::shared_ptr<Executable> compile(std::shared_ptr<Function> func, virtual std::shared_ptr<Executable> compile(std::shared_ptr<Function> func,
ngraph::pass::PassConfig& pass_config, ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator, AllocateFunc memory_allocator = nullptr,
DestroyFunc& framework_deallocator, DestroyFunc memory_deallocator = nullptr,
bool enable_performance_data = false); bool enable_performance_data = false);
/// \brief Test if a backend is capable of supporting an op /// \brief Test if a backend is capable of supporting an op
......
...@@ -53,11 +53,10 @@ namespace ...@@ -53,11 +53,10 @@ namespace
shared_ptr<runtime::cpu::CPU_CallFrame> runtime::cpu::CPU_Backend::make_call_frame( shared_ptr<runtime::cpu::CPU_CallFrame> runtime::cpu::CPU_Backend::make_call_frame(
const shared_ptr<runtime::cpu::CPU_ExternalFunction>& external_function, const shared_ptr<runtime::cpu::CPU_ExternalFunction>& external_function,
ngraph::pass::PassConfig& pass_config, ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator, AllocateFunc memory_allocator,
DestroyFunc& framework_deallocator) DestroyFunc memory_deallocator)
{ {
return external_function->make_call_frame( return external_function->make_call_frame(pass_config, memory_allocator, memory_deallocator);
pass_config, framework_allocator, framework_deallocator);
} }
shared_ptr<runtime::Tensor> shared_ptr<runtime::Tensor>
...@@ -76,19 +75,14 @@ shared_ptr<runtime::Executable> ...@@ -76,19 +75,14 @@ shared_ptr<runtime::Executable>
runtime::cpu::CPU_Backend::compile(shared_ptr<Function> func, bool performance_counters_enabled) runtime::cpu::CPU_Backend::compile(shared_ptr<Function> func, bool performance_counters_enabled)
{ {
ngraph::pass::PassConfig pass_config; ngraph::pass::PassConfig pass_config;
return compile(func, pass_config, nullptr, nullptr, performance_counters_enabled);
// TODO(pruthvi): reseat this with framework provided allocator and deallocator
AllocateFunc framework_alloc = nullptr;
DestroyFunc framework_dealloc = nullptr;
return compile(
func, pass_config, framework_alloc, framework_dealloc, performance_counters_enabled);
} }
shared_ptr<runtime::Executable> shared_ptr<runtime::Executable>
runtime::cpu::CPU_Backend::compile(shared_ptr<Function> func, runtime::cpu::CPU_Backend::compile(shared_ptr<Function> func,
ngraph::pass::PassConfig& pass_config, ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator, AllocateFunc memory_allocator,
DestroyFunc& framework_deallocator, DestroyFunc memory_deallocator,
bool performance_counters_enabled) bool performance_counters_enabled)
{ {
shared_ptr<runtime::Executable> rc; shared_ptr<runtime::Executable> rc;
...@@ -99,11 +93,8 @@ shared_ptr<runtime::Executable> ...@@ -99,11 +93,8 @@ shared_ptr<runtime::Executable>
} }
else else
{ {
rc = make_shared<CPU_Executable>(func, rc = make_shared<CPU_Executable>(
pass_config, func, pass_config, memory_allocator, memory_deallocator, performance_counters_enabled);
framework_allocator,
framework_deallocator,
performance_counters_enabled);
m_exec_map.insert({func, rc}); m_exec_map.insert({func, rc});
} }
return rc; return rc;
...@@ -111,8 +102,8 @@ shared_ptr<runtime::Executable> ...@@ -111,8 +102,8 @@ shared_ptr<runtime::Executable>
runtime::cpu::CPU_Executable::CPU_Executable(shared_ptr<Function> func, runtime::cpu::CPU_Executable::CPU_Executable(shared_ptr<Function> func,
ngraph::pass::PassConfig& pass_config, ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator, AllocateFunc memory_allocator,
DestroyFunc& framework_deallocator, DestroyFunc memory_deallocator,
bool performance_counters_enabled) bool performance_counters_enabled)
{ {
FunctionInstance& instance = m_function_instance; FunctionInstance& instance = m_function_instance;
...@@ -121,7 +112,7 @@ runtime::cpu::CPU_Executable::CPU_Executable(shared_ptr<Function> func, ...@@ -121,7 +112,7 @@ runtime::cpu::CPU_Executable::CPU_Executable(shared_ptr<Function> func,
instance.m_external_function = make_shared<CPU_ExternalFunction>(func); instance.m_external_function = make_shared<CPU_ExternalFunction>(func);
instance.m_external_function->m_emit_timing = performance_counters_enabled; instance.m_external_function->m_emit_timing = performance_counters_enabled;
auto cf = instance.m_external_function->make_call_frame( auto cf = instance.m_external_function->make_call_frame(
pass_config, framework_allocator, framework_deallocator); pass_config, memory_allocator, memory_deallocator);
instance.m_call_frame = dynamic_pointer_cast<CPU_CallFrame>(cf); instance.m_call_frame = dynamic_pointer_cast<CPU_CallFrame>(cf);
} }
set_parameters_and_results(*func); set_parameters_and_results(*func);
...@@ -179,6 +170,7 @@ bool runtime::cpu::CPU_Backend::is_supported(const Node& op) const ...@@ -179,6 +170,7 @@ bool runtime::cpu::CPU_Backend::is_supported(const Node& op) const
{ {
return true; return true;
} }
bool runtime::cpu::CPU_Backend::is_supported_property(const Property prop) const bool runtime::cpu::CPU_Backend::is_supported_property(const Property prop) const
{ {
if (prop == Property::memory_attach) if (prop == Property::memory_attach)
......
...@@ -39,8 +39,8 @@ namespace ngraph ...@@ -39,8 +39,8 @@ namespace ngraph
std::shared_ptr<CPU_CallFrame> std::shared_ptr<CPU_CallFrame>
make_call_frame(const std::shared_ptr<CPU_ExternalFunction>& external_function, make_call_frame(const std::shared_ptr<CPU_ExternalFunction>& external_function,
ngraph::pass::PassConfig& pass_config, ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator, AllocateFunc memory_allocator,
DestroyFunc& framework_deallocator); DestroyFunc memory_deallocator);
std::shared_ptr<ngraph::runtime::Tensor> std::shared_ptr<ngraph::runtime::Tensor>
create_tensor(const ngraph::element::Type& element_type, create_tensor(const ngraph::element::Type& element_type,
...@@ -58,8 +58,8 @@ namespace ngraph ...@@ -58,8 +58,8 @@ namespace ngraph
std::shared_ptr<ngraph::runtime::Executable> std::shared_ptr<ngraph::runtime::Executable>
compile(std::shared_ptr<Function> func, compile(std::shared_ptr<Function> func,
ngraph::pass::PassConfig& pass_config, ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator, AllocateFunc memory_allocator = nullptr,
DestroyFunc& framework_deallocator, DestroyFunc memory_deallocator = nullptr,
bool enable_performance_counters = false) override; bool enable_performance_counters = false) override;
void remove_compiled_function(std::shared_ptr<Executable> exec) override; void remove_compiled_function(std::shared_ptr<Executable> exec) override;
...@@ -77,8 +77,8 @@ namespace ngraph ...@@ -77,8 +77,8 @@ namespace ngraph
public: public:
CPU_Executable(std::shared_ptr<Function> func, CPU_Executable(std::shared_ptr<Function> func,
ngraph::pass::PassConfig& pass_config, ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator, AllocateFunc memory_allocator,
DestroyFunc& framework_deallocator, DestroyFunc memory_deallocator,
bool performance_counters_enabled); bool performance_counters_enabled);
bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs, bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override; const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override;
......
...@@ -33,11 +33,11 @@ runtime::cpu::CPU_CallFrame::CPU_CallFrame(std::shared_ptr<CPU_ExternalFunction> ...@@ -33,11 +33,11 @@ runtime::cpu::CPU_CallFrame::CPU_CallFrame(std::shared_ptr<CPU_ExternalFunction>
InitContextFuncCG compiled_init_ctx_func, InitContextFuncCG compiled_init_ctx_func,
DestroyContextFuncCG compiled_destroy_ctx_func, DestroyContextFuncCG compiled_destroy_ctx_func,
EntryPoint compiled_function, EntryPoint compiled_function,
AllocateFunc& framework_allocator, AllocateFunc memory_allocator,
DestroyFunc& framework_deallocator) DestroyFunc memory_deallocator)
: m_external_function(external_function) : m_external_function(external_function)
, m_framework_allocator(framework_allocator) , m_memory_allocator(memory_allocator)
, m_framework_deallocator(framework_deallocator) , m_memory_deallocator(memory_deallocator)
, m_compiled_init_ctx_func(compiled_init_ctx_func) , m_compiled_init_ctx_func(compiled_init_ctx_func)
, m_compiled_destroy_ctx_func(compiled_destroy_ctx_func) , m_compiled_destroy_ctx_func(compiled_destroy_ctx_func)
, m_compiled_function(compiled_function) , m_compiled_function(compiled_function)
...@@ -144,18 +144,7 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context() ...@@ -144,18 +144,7 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context()
// Create temporary buffer pools // Create temporary buffer pools
size_t alignment = runtime::cpu::CPU_ExternalFunction::s_memory_pool_alignment; size_t alignment = runtime::cpu::CPU_ExternalFunction::s_memory_pool_alignment;
ngraph::runtime::Allocator* allocator = ngraph::runtime::Allocator* allocator =
new ngraph::runtime::cpu::CPUAllocator(nullptr, nullptr); new ngraph::runtime::cpu::CPUAllocator(m_memory_allocator, m_memory_deallocator);
/*if (m_framework_allocator && m_framework_deallocator)
{
auto fw_allocator =
new ngraph::runtime::FrameworkAllocator(m_framework_allocator, m_framework_deallocator);
allocator = new ngraph::runtime::cpu::CPUAllocator(fw_allocator, alignment);
}
else
{
auto sys_allocator = new ngraph::runtime::SystemAllocator();
allocator = new ngraph::runtime::cpu::CPUAllocator(sys_allocator, alignment);
}*/
for (auto buffer_size : m_external_function->get_memory_buffer_sizes()) for (auto buffer_size : m_external_function->get_memory_buffer_sizes())
{ {
......
...@@ -59,8 +59,8 @@ namespace ngraph ...@@ -59,8 +59,8 @@ namespace ngraph
InitContextFuncCG compiled_init_ctx_func, InitContextFuncCG compiled_init_ctx_func,
DestroyContextFuncCG compiled_destroy_ctx_func, DestroyContextFuncCG compiled_destroy_ctx_func,
EntryPoint compiled_function, EntryPoint compiled_function,
AllocateFunc& framework_allocator, AllocateFunc memory_allocator,
DestroyFunc& framework_deallocator); DestroyFunc memory_deallocator);
~CPU_CallFrame(); ~CPU_CallFrame();
...@@ -90,8 +90,8 @@ namespace ngraph ...@@ -90,8 +90,8 @@ namespace ngraph
CPURuntimeContext* ctx = nullptr; CPURuntimeContext* ctx = nullptr;
// member function pointers to hold the framework allocators // member function pointers to hold the framework allocators
AllocateFunc m_framework_allocator; AllocateFunc m_memory_allocator;
DestroyFunc m_framework_deallocator; DestroyFunc m_memory_deallocator;
/* Codegen specific */ /* Codegen specific */
......
...@@ -507,6 +507,7 @@ void runtime::cpu::CPU_ExternalFunction::compile(ngraph::pass::PassConfig& pass_ ...@@ -507,6 +507,7 @@ void runtime::cpu::CPU_ExternalFunction::compile(ngraph::pass::PassConfig& pass_
R"( R"(
#include <cmath> #include <cmath>
#include "ngraph/except.hpp" #include "ngraph/except.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/cpu/cpu_eigen_utils.hpp" #include "ngraph/runtime/cpu/cpu_eigen_utils.hpp"
#include "ngraph/runtime/cpu/cpu_kernels.hpp" #include "ngraph/runtime/cpu/cpu_kernels.hpp"
#include "ngraph/runtime/cpu/cpu_runtime_context.hpp" #include "ngraph/runtime/cpu/cpu_runtime_context.hpp"
...@@ -1713,8 +1714,8 @@ void*& runtime::cpu::CPU_ExternalFunction::get_tensor_data(const std::string& na ...@@ -1713,8 +1714,8 @@ void*& runtime::cpu::CPU_ExternalFunction::get_tensor_data(const std::string& na
shared_ptr<ngraph::runtime::cpu::CPU_CallFrame> shared_ptr<ngraph::runtime::cpu::CPU_CallFrame>
runtime::cpu::CPU_ExternalFunction::make_call_frame(ngraph::pass::PassConfig& pass_config, runtime::cpu::CPU_ExternalFunction::make_call_frame(ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator, AllocateFunc memory_allocator,
DestroyFunc& framework_deallocator) DestroyFunc memory_deallocator)
{ {
#if defined(NGRAPH_DEX_ONLY) #if defined(NGRAPH_DEX_ONLY)
if (pass_config.get_compilation_mode() == ngraph::pass::CompilationMode::CODEGEN) if (pass_config.get_compilation_mode() == ngraph::pass::CompilationMode::CODEGEN)
...@@ -1743,8 +1744,8 @@ shared_ptr<ngraph::runtime::cpu::CPU_CallFrame> ...@@ -1743,8 +1744,8 @@ shared_ptr<ngraph::runtime::cpu::CPU_CallFrame>
m_compiled_init_ctx_func, m_compiled_init_ctx_func,
m_compiled_destroy_ctx_func, m_compiled_destroy_ctx_func,
m_compiled_function, m_compiled_function,
framework_allocator, memory_allocator,
framework_deallocator); memory_deallocator);
} }
const runtime::cpu::LayoutDescriptorPtrs& const runtime::cpu::LayoutDescriptorPtrs&
......
...@@ -102,8 +102,8 @@ namespace ngraph ...@@ -102,8 +102,8 @@ namespace ngraph
~CPU_ExternalFunction(); ~CPU_ExternalFunction();
std::shared_ptr<ngraph::runtime::cpu::CPU_CallFrame> std::shared_ptr<ngraph::runtime::cpu::CPU_CallFrame>
make_call_frame(ngraph::pass::PassConfig& pass_config, make_call_frame(ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator, AllocateFunc memory_allocator,
DestroyFunc& framework_deallocator); DestroyFunc memory_deallocator);
const LayoutDescriptorPtrs& get_parameter_layout_descriptors(); const LayoutDescriptorPtrs& get_parameter_layout_descriptors();
const LayoutDescriptorPtrs& get_result_layout_descriptors(); const LayoutDescriptorPtrs& get_result_layout_descriptors();
......
...@@ -43,9 +43,7 @@ TEST(cpu_codegen, abc) ...@@ -43,9 +43,7 @@ TEST(cpu_codegen, abc)
copy_data(c, test::NDArray<float, 2>({{9, 10}, {11, 12}}).get_vector()); copy_data(c, test::NDArray<float, 2>({{9, 10}, {11, 12}}).get_vector());
ngraph::pass::PassConfig pass_config{ngraph::pass::CompilationMode::CODEGEN}; ngraph::pass::PassConfig pass_config{ngraph::pass::CompilationMode::CODEGEN};
runtime::AllocateFunc framework_allocator = nullptr; auto handle = backend->compile(f, pass_config);
runtime::DestroyFunc framework_deallocator = nullptr;
auto handle = backend->compile(f, pass_config, framework_allocator, framework_deallocator);
handle->call_with_validate({result}, {a, b, c}); handle->call_with_validate({result}, {a, b, c});
EXPECT_TRUE(test::all_close_f(read_vector<float>(result), EXPECT_TRUE(test::all_close_f(read_vector<float>(result),
(test::NDArray<float, 2>({{54, 80}, {110, 144}})).get_vector(), (test::NDArray<float, 2>({{54, 80}, {110, 144}})).get_vector(),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment