Commit 3454cc45 authored by pruthvi's avatar pruthvi

- fkix clang errors

- override compile method in the cpu_backend
parent 7240b3ae
......@@ -42,6 +42,8 @@ vector<string> runtime::Backend::get_registered_devices()
std::shared_ptr<runtime::Executable>
runtime::Backend::compile(std::shared_ptr<Function> func,
ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator,
DestroyFunc& framework_deallocator,
bool enable_performance_data)
{
return compile(func, enable_performance_data);
......
......@@ -24,6 +24,7 @@
#include "ngraph/runtime/performance_counter.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
......@@ -90,6 +91,8 @@ public:
/// \returns compiled function or nullptr on failure
virtual std::shared_ptr<Executable> compile(std::shared_ptr<Function> func,
ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator,
DestroyFunc& framework_deallocator,
bool enable_performance_data = false);
/// \brief Test if a backend is capable of supporting an op
......
......@@ -60,7 +60,7 @@ namespace ngraph
ngraph::pass::PassConfig& pass_config,
AllocateFunc& framework_allocator,
DestroyFunc& framework_deallocator,
bool enable_performance_counters = false);
bool enable_performance_counters = false) override;
void remove_compiled_function(std::shared_ptr<Executable> exec) override;
......
......@@ -38,6 +38,10 @@ namespace ngraph
class NodeMap;
class stopwatch;
// aliases for framework provided function pointers as defined in onnx runtime
using AllocateFunc = void* (*)(void*, size_t, size_t);
using DestroyFunc = void (*)(void*, void*);
namespace runtime
{
class Backend;
......
......@@ -16,6 +16,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/util.hpp"
#include "util/all_close.hpp"
#include "util/ndarray.hpp"
......@@ -43,7 +44,9 @@ TEST(cpu_codegen, abc)
copy_data(c, test::NDArray<float, 2>({{9, 10}, {11, 12}}).get_vector());
ngraph::pass::PassConfig pass_config{ngraph::pass::CompilationMode::CODEGEN};
auto handle = backend->compile(f, pass_config);
AllocateFunc framework_allocator = nullptr;
DestroyFunc framework_deallocator = nullptr;
auto handle = backend->compile(f, pass_config, framework_allocator, framework_deallocator);
handle->call_with_validate({result}, {a, b, c});
EXPECT_EQ(read_vector<float>(result),
(test::NDArray<float, 2>({{54, 80}, {110, 144}})).get_vector());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment