Commit 17fb22a9 authored by Rob Earhart's avatar Rob Earhart Committed by Robert Kimball

Update PlaidML backend to current API (#2445)

parent 75964a23
......@@ -18,10 +18,10 @@ set(SRC
plaidml_backend.cpp
plaidml_builder.cpp
plaidml_compilation_cache.cpp
plaidml_compiled_function.cpp
plaidml_compiler.cpp
plaidml_config.cpp
plaidml_convpool_formatter.cpp
plaidml_executable.cpp
plaidml_impl.cpp
plaidml_logger.cpp
plaidml_ops_arithmetic.cpp
......
......@@ -16,7 +16,7 @@
#include "ngraph/runtime/plaidml/plaidml_backend.hpp"
#include "ngraph/node.hpp"
#include "ngraph/runtime/plaidml/plaidml_compiled_function.hpp"
#include "ngraph/runtime/plaidml/plaidml_executable.hpp"
#include "ngraph/runtime/plaidml/plaidml_tensor.hpp"
#include "ngraph/util.hpp"
......@@ -42,39 +42,31 @@ std::shared_ptr<ngraph::runtime::Tensor> ngraph::runtime::plaidml::PlaidML_Backe
this, &m_config, element_type, shape, "direct_data", memory_pointer);
}
std::shared_ptr<ngraph::Function>
ngraph::runtime::plaidml::PlaidML_Backend::compile(std::shared_ptr<Function> func)
std::shared_ptr<ngraph::runtime::Executable>
ngraph::runtime::plaidml::PlaidML_Backend::compile(std::shared_ptr<Function> func,
bool /* enable_performance_data */)
{
m_cache.compile(func, &m_compiler);
return func;
return m_cache.compile(std::move(func), &m_compiler);
}
bool ngraph::runtime::plaidml::PlaidML_Backend::call(
std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs)
bool ngraph::runtime::plaidml::PlaidML_Backend::is_supported(const Node& node) const
{
auto cfunc = m_cache.compile(func, &m_compiler);
cfunc->schedule_invocation(inputs, outputs);
return true;
return m_compiler.is_supported(node);
}
void ngraph::runtime::plaidml::PlaidML_Backend::remove_compiled_function(
std::shared_ptr<Function> func)
bool ngraph::runtime::plaidml::PlaidML_Backend::is_supported_property(const Property prop) const
{
m_cache.forget(func);
return false;
}
void ngraph::runtime::plaidml::PlaidML_Backend::save(std::shared_ptr<Function> func,
const std::string& filename,
plaidml_file_format format)
void ngraph::runtime::plaidml::PlaidML_Backend::remove_compiled_function(
std::shared_ptr<Executable> exec)
{
auto cfunc = m_cache.try_lookup(func);
if (!cfunc)
auto plaidml_exec = std::dynamic_pointer_cast<PlaidML_Executable>(std::move(exec));
if (plaidml_exec)
{
cfunc = m_compiler.compile(func);
m_cache.forget(std::move(plaidml_exec));
}
cfunc->save(filename, format);
}
extern "C" const char* get_ngraph_version_string()
......
......@@ -22,6 +22,7 @@
#include "ngraph/runtime/plaidml/plaidml_compilation_cache.hpp"
#include "ngraph/runtime/plaidml/plaidml_compiler.hpp"
#include "ngraph/runtime/plaidml/plaidml_config.hpp"
#include "ngraph/runtime/plaidml/plaidml_executable.hpp"
namespace ngraph
{
......@@ -46,17 +47,17 @@ public:
std::shared_ptr<ngraph::runtime::Tensor> create_tensor(
const ngraph::element::Type& element_type, const Shape& shape, void* memory_pointer) final;
std::shared_ptr<Function> compile(std::shared_ptr<Function> func) final;
// N.B. The returned Executable will always be an instance of
// PlaidML_Executable, and may be safely converted via static
// casting.
std::shared_ptr<Executable> compile(std::shared_ptr<Function> func,
bool enable_performance_data = false) final;
bool call(std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) final;
bool is_supported(const Node& node) const final;
void remove_compiled_function(std::shared_ptr<Function> func) final;
bool is_supported_property(const Property prop) const final;
void save(std::shared_ptr<Function> func,
const std::string& filename,
plaidml_file_format format);
void remove_compiled_function(std::shared_ptr<Executable> exec) final;
private:
Config m_config;
......
......@@ -16,24 +16,12 @@
#include "ngraph/runtime/plaidml/plaidml_compilation_cache.hpp"
std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction>
ngraph::runtime::plaidml::CompilationCache::try_lookup(std::shared_ptr<Function> func)
{
std::lock_guard<std::mutex> lock{m_mu};
auto it = m_cache.find(func);
if (it != m_cache.end())
{
return it->second;
}
return std::shared_ptr<CompiledFunction>{};
}
std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction>
std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
ngraph::runtime::plaidml::CompilationCache::compile(std::shared_ptr<Function> func,
Compiler* compiler)
{
std::lock_guard<std::mutex> lock{m_mu};
auto it_inserted = m_cache.insert(std::make_pair(func, std::shared_ptr<CompiledFunction>{}));
auto it_inserted = m_cache.insert(std::make_pair(func, std::shared_ptr<PlaidML_Executable>{}));
if (it_inserted.second)
{
try
......@@ -49,8 +37,8 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction>
return it_inserted.first->second;
}
void ngraph::runtime::plaidml::CompilationCache::forget(std::shared_ptr<Function> func)
void ngraph::runtime::plaidml::CompilationCache::forget(std::shared_ptr<PlaidML_Executable> exec)
{
std::lock_guard<std::mutex> lock{m_mu};
m_cache.erase(func);
m_cache.erase(exec->src_func());
}
......@@ -21,8 +21,8 @@
#include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/runtime/plaidml/plaidml_compiled_function.hpp"
#include "ngraph/runtime/plaidml/plaidml_compiler.hpp"
#include "ngraph/runtime/plaidml/plaidml_executable.hpp"
namespace ngraph
{
......@@ -39,19 +39,17 @@ namespace ngraph
class ngraph::runtime::plaidml::CompilationCache final
{
public:
// Looks up the supplied function in the compilation cache. If the function is not in the
// cache, returns an empty pointer.
std::shared_ptr<CompiledFunction> try_lookup(std::shared_ptr<Function> func);
// Looks up the supplied function in the compilation cache. If the function is not in the
// cache, compiles it using the specified compiler (which must not be nullptr), adds the
// compiled function to the cache, and returns the compiled function.
std::shared_ptr<CompiledFunction> compile(std::shared_ptr<Function> func, Compiler* compiler);
std::shared_ptr<PlaidML_Executable> compile(std::shared_ptr<Function> func, Compiler* compiler);
// Drops the supplied function's compiled function from the compilation cache.
void forget(std::shared_ptr<Function> func);
void forget(std::shared_ptr<PlaidML_Executable> func);
private:
std::mutex m_mu;
std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<CompiledFunction>> m_cache;
// N.B. The key here is the original source function, *not* the copy that's been processed by the compilation passes.
std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<PlaidML_Executable>> m_cache;
};
......@@ -73,7 +73,7 @@ ngraph::runtime::plaidml::Compiler::Compiler(Config* config)
{
}
std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction>
std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
ngraph::runtime::plaidml::Compiler::compile(std::shared_ptr<Function> func)
{
// N.B. ngraph::pass::Manager::run_passes() is *not* a const
......@@ -121,15 +121,20 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction>
// The caller may wish to perform operations (e.g. clone) on their
// supplied function that will cause validation to occur. So
// before we rewrite, we make our own copy of the function.
func = clone_function(*func);
auto rewrite_func = clone_function(*func);
// Apply passes.
pass_manager.run_passes(func);
pass_manager.run_passes(rewrite_func);
// Compile the resulting function.
Build b;
build(std::move(func), &b);
return std::make_shared<CompiledFunction>(std::move(b));
build(std::move(rewrite_func), &b);
return std::make_shared<PlaidML_Executable>(std::move(b), std::move(func));
}
bool ngraph::runtime::plaidml::Compiler::is_supported(const Node& node) const
{
return GlobalOpImplMap()->count(std::type_index(typeid(node))) != 0;
}
void ngraph::runtime::plaidml::Compiler::build(std::shared_ptr<Function> func, Build* b)
......
......@@ -21,8 +21,8 @@
#include <plaidml/plaidml++.h>
#include "ngraph/function.hpp"
#include "ngraph/runtime/plaidml/plaidml_compiled_function.hpp"
#include "ngraph/runtime/plaidml/plaidml_config.hpp"
#include "ngraph/runtime/plaidml/plaidml_executable.hpp"
namespace ngraph
{
......@@ -42,10 +42,12 @@ class ngraph::runtime::plaidml::Compiler final
public:
Compiler(Config* config);
std::shared_ptr<CompiledFunction> compile(std::shared_ptr<Function> func);
std::shared_ptr<PlaidML_Executable> compile(std::shared_ptr<Function> func);
void build(std::shared_ptr<Function> func, Build* build);
bool is_supported(const Node& node) const;
private:
void build(std::shared_ptr<Function> func, Build* build);
Config* m_config;
};
......@@ -18,25 +18,28 @@
#include "ngraph/log.hpp"
#include "ngraph/runtime/plaidml/plaidml_build.hpp"
#include "ngraph/runtime/plaidml/plaidml_compiled_function.hpp"
#include "ngraph/runtime/plaidml/plaidml_executable.hpp"
#include "ngraph/runtime/plaidml/plaidml_tensor.hpp"
#include "ngraph/runtime/plaidml/plaidml_translate.hpp"
namespace vp = vertexai::plaidml;
ngraph::runtime::plaidml::CompiledFunction::CompiledFunction(Build build)
ngraph::runtime::plaidml::PlaidML_Executable::PlaidML_Executable(Build build,
std::shared_ptr<Function> func)
: m_config{build.config}
, m_func{std::move(build.func)}
, m_src_func{std::move(func)}
, m_input_names{std::move(build.input_names)}
, m_output_names{std::move(build.output_names)}
, m_invoker{build.config->ctx, std::move(build.composer)}
{
set_parameters_and_results(*m_func);
NGRAPH_DEBUG << "Compiled PlaidML function " << this;
}
bool ngraph::runtime::plaidml::CompiledFunction::schedule_invocation(
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs) const
bool ngraph::runtime::plaidml::PlaidML_Executable::call(
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs)
{
std::lock_guard<std::mutex> lock{m_mu};
......@@ -120,8 +123,14 @@ bool ngraph::runtime::plaidml::CompiledFunction::schedule_invocation(
return true;
}
void ngraph::runtime::plaidml::CompiledFunction::save(const std::string& filename,
plaidml_file_format format) const
std::vector<ngraph::runtime::PerformanceCounter>
ngraph::runtime::plaidml::PlaidML_Executable::get_performance_data() const
{
return std::vector<ngraph::runtime::PerformanceCounter>{};
}
void ngraph::runtime::plaidml::PlaidML_Executable::save(const std::string& filename,
plaidml_file_format format) const
{
std::lock_guard<std::mutex> lock{m_mu};
......
......@@ -24,7 +24,7 @@
#include <plaidml/plaidml++.h>
#include "ngraph/function.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/plaidml/plaidml_config.hpp"
#include "ngraph/runtime/tensor.hpp"
......@@ -35,27 +35,31 @@ namespace ngraph
namespace plaidml
{
struct Build;
class CompiledFunction;
class PlaidML_Executable;
}
}
}
// A PlaidML compiled function object produced by compiling an nGraph function.
class ngraph::runtime::plaidml::CompiledFunction final
// A PlaidML executable object produced by compiling an nGraph function.
class ngraph::runtime::plaidml::PlaidML_Executable final : public Executable
{
public:
CompiledFunction(Build build);
PlaidML_Executable(Build build, std::shared_ptr<Function> func);
virtual ~PlaidML_Executable() {}
bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) final;
bool schedule_invocation(const std::vector<std::shared_ptr<runtime::Tensor>>& inputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs) const;
std::vector<PerformanceCounter> get_performance_data() const final;
void save(const std::string& filename, plaidml_file_format format) const;
const std::shared_ptr<Function>& src_func() const { return m_src_func; }
private:
mutable std::mutex m_mu; // Locks the invoker while scheduling invocations.
mutable bool m_bound = false;
Config* m_config;
std::shared_ptr<Function> m_func;
std::shared_ptr<Function> m_src_func; // The original source function.
std::unordered_map<descriptor::Tensor*, std::string> m_input_names;
std::unordered_map<descriptor::Tensor*, std::string> m_output_names;
mutable std::vector<std::weak_ptr<runtime::Tensor>> m_bound_inputs;
......
......@@ -142,7 +142,9 @@ OPTIONS
return EXIT_FAILURE;
}
backend->save(f, output, format);
auto exec = backend->compile(f);
static_cast<ngraph::runtime::plaidml::PlaidML_Executable*>(exec.get())->save(output, format);
std::cerr << "Wrote output to " << output << "\n";
return EXIT_SUCCESS;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment