Commit 17fb22a9 authored by Rob Earhart's avatar Rob Earhart Committed by Robert Kimball

Update PlaidML backend to current API (#2445)

parent 75964a23
...@@ -18,10 +18,10 @@ set(SRC ...@@ -18,10 +18,10 @@ set(SRC
plaidml_backend.cpp plaidml_backend.cpp
plaidml_builder.cpp plaidml_builder.cpp
plaidml_compilation_cache.cpp plaidml_compilation_cache.cpp
plaidml_compiled_function.cpp
plaidml_compiler.cpp plaidml_compiler.cpp
plaidml_config.cpp plaidml_config.cpp
plaidml_convpool_formatter.cpp plaidml_convpool_formatter.cpp
plaidml_executable.cpp
plaidml_impl.cpp plaidml_impl.cpp
plaidml_logger.cpp plaidml_logger.cpp
plaidml_ops_arithmetic.cpp plaidml_ops_arithmetic.cpp
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "ngraph/runtime/plaidml/plaidml_backend.hpp" #include "ngraph/runtime/plaidml/plaidml_backend.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/runtime/plaidml/plaidml_compiled_function.hpp" #include "ngraph/runtime/plaidml/plaidml_executable.hpp"
#include "ngraph/runtime/plaidml/plaidml_tensor.hpp" #include "ngraph/runtime/plaidml/plaidml_tensor.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -42,39 +42,31 @@ std::shared_ptr<ngraph::runtime::Tensor> ngraph::runtime::plaidml::PlaidML_Backe ...@@ -42,39 +42,31 @@ std::shared_ptr<ngraph::runtime::Tensor> ngraph::runtime::plaidml::PlaidML_Backe
this, &m_config, element_type, shape, "direct_data", memory_pointer); this, &m_config, element_type, shape, "direct_data", memory_pointer);
} }
std::shared_ptr<ngraph::Function> std::shared_ptr<ngraph::runtime::Executable>
ngraph::runtime::plaidml::PlaidML_Backend::compile(std::shared_ptr<Function> func) ngraph::runtime::plaidml::PlaidML_Backend::compile(std::shared_ptr<Function> func,
bool /* enable_performance_data */)
{ {
m_cache.compile(func, &m_compiler); return m_cache.compile(std::move(func), &m_compiler);
return func;
} }
bool ngraph::runtime::plaidml::PlaidML_Backend::call( bool ngraph::runtime::plaidml::PlaidML_Backend::is_supported(const Node& node) const
std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs)
{ {
auto cfunc = m_cache.compile(func, &m_compiler); return m_compiler.is_supported(node);
cfunc->schedule_invocation(inputs, outputs);
return true;
} }
void ngraph::runtime::plaidml::PlaidML_Backend::remove_compiled_function( bool ngraph::runtime::plaidml::PlaidML_Backend::is_supported_property(const Property prop) const
std::shared_ptr<Function> func)
{ {
m_cache.forget(func); return false;
} }
void ngraph::runtime::plaidml::PlaidML_Backend::save(std::shared_ptr<Function> func, void ngraph::runtime::plaidml::PlaidML_Backend::remove_compiled_function(
const std::string& filename, std::shared_ptr<Executable> exec)
plaidml_file_format format)
{ {
auto cfunc = m_cache.try_lookup(func); auto plaidml_exec = std::dynamic_pointer_cast<PlaidML_Executable>(std::move(exec));
if (!cfunc) if (plaidml_exec)
{ {
cfunc = m_compiler.compile(func); m_cache.forget(std::move(plaidml_exec));
} }
cfunc->save(filename, format);
} }
extern "C" const char* get_ngraph_version_string() extern "C" const char* get_ngraph_version_string()
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ngraph/runtime/plaidml/plaidml_compilation_cache.hpp" #include "ngraph/runtime/plaidml/plaidml_compilation_cache.hpp"
#include "ngraph/runtime/plaidml/plaidml_compiler.hpp" #include "ngraph/runtime/plaidml/plaidml_compiler.hpp"
#include "ngraph/runtime/plaidml/plaidml_config.hpp" #include "ngraph/runtime/plaidml/plaidml_config.hpp"
#include "ngraph/runtime/plaidml/plaidml_executable.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -46,17 +47,17 @@ public: ...@@ -46,17 +47,17 @@ public:
std::shared_ptr<ngraph::runtime::Tensor> create_tensor( std::shared_ptr<ngraph::runtime::Tensor> create_tensor(
const ngraph::element::Type& element_type, const Shape& shape, void* memory_pointer) final; const ngraph::element::Type& element_type, const Shape& shape, void* memory_pointer) final;
std::shared_ptr<Function> compile(std::shared_ptr<Function> func) final; // N.B. The returned Executable will always be an instance of
// PlaidML_Executable, and may be safely converted via static
// casting.
std::shared_ptr<Executable> compile(std::shared_ptr<Function> func,
bool enable_performance_data = false) final;
bool call(std::shared_ptr<Function> func, bool is_supported(const Node& node) const final;
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) final;
void remove_compiled_function(std::shared_ptr<Function> func) final; bool is_supported_property(const Property prop) const final;
void save(std::shared_ptr<Function> func, void remove_compiled_function(std::shared_ptr<Executable> exec) final;
const std::string& filename,
plaidml_file_format format);
private: private:
Config m_config; Config m_config;
......
...@@ -16,24 +16,12 @@ ...@@ -16,24 +16,12 @@
#include "ngraph/runtime/plaidml/plaidml_compilation_cache.hpp" #include "ngraph/runtime/plaidml/plaidml_compilation_cache.hpp"
std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction> std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
ngraph::runtime::plaidml::CompilationCache::try_lookup(std::shared_ptr<Function> func)
{
std::lock_guard<std::mutex> lock{m_mu};
auto it = m_cache.find(func);
if (it != m_cache.end())
{
return it->second;
}
return std::shared_ptr<CompiledFunction>{};
}
std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction>
ngraph::runtime::plaidml::CompilationCache::compile(std::shared_ptr<Function> func, ngraph::runtime::plaidml::CompilationCache::compile(std::shared_ptr<Function> func,
Compiler* compiler) Compiler* compiler)
{ {
std::lock_guard<std::mutex> lock{m_mu}; std::lock_guard<std::mutex> lock{m_mu};
auto it_inserted = m_cache.insert(std::make_pair(func, std::shared_ptr<CompiledFunction>{})); auto it_inserted = m_cache.insert(std::make_pair(func, std::shared_ptr<PlaidML_Executable>{}));
if (it_inserted.second) if (it_inserted.second)
{ {
try try
...@@ -49,8 +37,8 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction> ...@@ -49,8 +37,8 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction>
return it_inserted.first->second; return it_inserted.first->second;
} }
void ngraph::runtime::plaidml::CompilationCache::forget(std::shared_ptr<Function> func) void ngraph::runtime::plaidml::CompilationCache::forget(std::shared_ptr<PlaidML_Executable> exec)
{ {
std::lock_guard<std::mutex> lock{m_mu}; std::lock_guard<std::mutex> lock{m_mu};
m_cache.erase(func); m_cache.erase(exec->src_func());
} }
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
#include <unordered_map> #include <unordered_map>
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/runtime/plaidml/plaidml_compiled_function.hpp"
#include "ngraph/runtime/plaidml/plaidml_compiler.hpp" #include "ngraph/runtime/plaidml/plaidml_compiler.hpp"
#include "ngraph/runtime/plaidml/plaidml_executable.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -39,19 +39,17 @@ namespace ngraph ...@@ -39,19 +39,17 @@ namespace ngraph
class ngraph::runtime::plaidml::CompilationCache final class ngraph::runtime::plaidml::CompilationCache final
{ {
public: public:
// Looks up the supplied function in the compilation cache. If the function is not in the
// cache, returns an empty pointer.
std::shared_ptr<CompiledFunction> try_lookup(std::shared_ptr<Function> func);
// Looks up the supplied function in the compilation cache. If the function is not in the // Looks up the supplied function in the compilation cache. If the function is not in the
// cache, compiles it using the specified compiler (which must not be nullptr), adds the // cache, compiles it using the specified compiler (which must not be nullptr), adds the
// compiled function to the cache, and returns the compiled function. // compiled function to the cache, and returns the compiled function.
std::shared_ptr<CompiledFunction> compile(std::shared_ptr<Function> func, Compiler* compiler); std::shared_ptr<PlaidML_Executable> compile(std::shared_ptr<Function> func, Compiler* compiler);
// Drops the supplied function's compiled function from the compilation cache. // Drops the supplied function's compiled function from the compilation cache.
void forget(std::shared_ptr<Function> func); void forget(std::shared_ptr<PlaidML_Executable> func);
private: private:
std::mutex m_mu; std::mutex m_mu;
std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<CompiledFunction>> m_cache;
// N.B. The key here is the original source function, *not* the copy that's been processed by the compilation passes.
std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<PlaidML_Executable>> m_cache;
}; };
...@@ -73,7 +73,7 @@ ngraph::runtime::plaidml::Compiler::Compiler(Config* config) ...@@ -73,7 +73,7 @@ ngraph::runtime::plaidml::Compiler::Compiler(Config* config)
{ {
} }
std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction> std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
ngraph::runtime::plaidml::Compiler::compile(std::shared_ptr<Function> func) ngraph::runtime::plaidml::Compiler::compile(std::shared_ptr<Function> func)
{ {
// N.B. ngraph::pass::Manager::run_passes() is *not* a const // N.B. ngraph::pass::Manager::run_passes() is *not* a const
...@@ -121,15 +121,20 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction> ...@@ -121,15 +121,20 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction>
// The caller may wish to perform operations (e.g. clone) on their // The caller may wish to perform operations (e.g. clone) on their
// supplied function that will cause validation to occur. So // supplied function that will cause validation to occur. So
// before we rewrite, we make our own copy of the function. // before we rewrite, we make our own copy of the function.
func = clone_function(*func); auto rewrite_func = clone_function(*func);
// Apply passes. // Apply passes.
pass_manager.run_passes(func); pass_manager.run_passes(rewrite_func);
// Compile the resulting function. // Compile the resulting function.
Build b; Build b;
build(std::move(func), &b); build(std::move(rewrite_func), &b);
return std::make_shared<CompiledFunction>(std::move(b)); return std::make_shared<PlaidML_Executable>(std::move(b), std::move(func));
}
bool ngraph::runtime::plaidml::Compiler::is_supported(const Node& node) const
{
return GlobalOpImplMap()->count(std::type_index(typeid(node))) != 0;
} }
void ngraph::runtime::plaidml::Compiler::build(std::shared_ptr<Function> func, Build* b) void ngraph::runtime::plaidml::Compiler::build(std::shared_ptr<Function> func, Build* b)
......
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
#include <plaidml/plaidml++.h> #include <plaidml/plaidml++.h>
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/runtime/plaidml/plaidml_compiled_function.hpp"
#include "ngraph/runtime/plaidml/plaidml_config.hpp" #include "ngraph/runtime/plaidml/plaidml_config.hpp"
#include "ngraph/runtime/plaidml/plaidml_executable.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -42,10 +42,12 @@ class ngraph::runtime::plaidml::Compiler final ...@@ -42,10 +42,12 @@ class ngraph::runtime::plaidml::Compiler final
public: public:
Compiler(Config* config); Compiler(Config* config);
std::shared_ptr<CompiledFunction> compile(std::shared_ptr<Function> func); std::shared_ptr<PlaidML_Executable> compile(std::shared_ptr<Function> func);
void build(std::shared_ptr<Function> func, Build* build); bool is_supported(const Node& node) const;
private: private:
void build(std::shared_ptr<Function> func, Build* build);
Config* m_config; Config* m_config;
}; };
...@@ -18,25 +18,28 @@ ...@@ -18,25 +18,28 @@
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/runtime/plaidml/plaidml_build.hpp" #include "ngraph/runtime/plaidml/plaidml_build.hpp"
#include "ngraph/runtime/plaidml/plaidml_compiled_function.hpp" #include "ngraph/runtime/plaidml/plaidml_executable.hpp"
#include "ngraph/runtime/plaidml/plaidml_tensor.hpp" #include "ngraph/runtime/plaidml/plaidml_tensor.hpp"
#include "ngraph/runtime/plaidml/plaidml_translate.hpp" #include "ngraph/runtime/plaidml/plaidml_translate.hpp"
namespace vp = vertexai::plaidml; namespace vp = vertexai::plaidml;
ngraph::runtime::plaidml::CompiledFunction::CompiledFunction(Build build) ngraph::runtime::plaidml::PlaidML_Executable::PlaidML_Executable(Build build,
std::shared_ptr<Function> func)
: m_config{build.config} : m_config{build.config}
, m_func{std::move(build.func)} , m_func{std::move(build.func)}
, m_src_func{std::move(func)}
, m_input_names{std::move(build.input_names)} , m_input_names{std::move(build.input_names)}
, m_output_names{std::move(build.output_names)} , m_output_names{std::move(build.output_names)}
, m_invoker{build.config->ctx, std::move(build.composer)} , m_invoker{build.config->ctx, std::move(build.composer)}
{ {
set_parameters_and_results(*m_func);
NGRAPH_DEBUG << "Compiled PlaidML function " << this; NGRAPH_DEBUG << "Compiled PlaidML function " << this;
} }
bool ngraph::runtime::plaidml::CompiledFunction::schedule_invocation( bool ngraph::runtime::plaidml::PlaidML_Executable::call(
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs, const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs) const const std::vector<std::shared_ptr<runtime::Tensor>>& inputs)
{ {
std::lock_guard<std::mutex> lock{m_mu}; std::lock_guard<std::mutex> lock{m_mu};
...@@ -120,7 +123,13 @@ bool ngraph::runtime::plaidml::CompiledFunction::schedule_invocation( ...@@ -120,7 +123,13 @@ bool ngraph::runtime::plaidml::CompiledFunction::schedule_invocation(
return true; return true;
} }
void ngraph::runtime::plaidml::CompiledFunction::save(const std::string& filename, std::vector<ngraph::runtime::PerformanceCounter>
ngraph::runtime::plaidml::PlaidML_Executable::get_performance_data() const
{
return std::vector<ngraph::runtime::PerformanceCounter>{};
}
void ngraph::runtime::plaidml::PlaidML_Executable::save(const std::string& filename,
plaidml_file_format format) const plaidml_file_format format) const
{ {
std::lock_guard<std::mutex> lock{m_mu}; std::lock_guard<std::mutex> lock{m_mu};
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <plaidml/plaidml++.h> #include <plaidml/plaidml++.h>
#include "ngraph/function.hpp" #include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/plaidml/plaidml_config.hpp" #include "ngraph/runtime/plaidml/plaidml_config.hpp"
#include "ngraph/runtime/tensor.hpp" #include "ngraph/runtime/tensor.hpp"
...@@ -35,27 +35,31 @@ namespace ngraph ...@@ -35,27 +35,31 @@ namespace ngraph
namespace plaidml namespace plaidml
{ {
struct Build; struct Build;
class CompiledFunction; class PlaidML_Executable;
} }
} }
} }
// A PlaidML compiled function object produced by compiling an nGraph function. // A PlaidML executable object produced by compiling an nGraph function.
class ngraph::runtime::plaidml::CompiledFunction final class ngraph::runtime::plaidml::PlaidML_Executable final : public Executable
{ {
public: public:
CompiledFunction(Build build); PlaidML_Executable(Build build, std::shared_ptr<Function> func);
virtual ~PlaidML_Executable() {}
bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) final;
bool schedule_invocation(const std::vector<std::shared_ptr<runtime::Tensor>>& inputs, std::vector<PerformanceCounter> get_performance_data() const final;
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs) const;
void save(const std::string& filename, plaidml_file_format format) const; void save(const std::string& filename, plaidml_file_format format) const;
const std::shared_ptr<Function>& src_func() const { return m_src_func; }
private: private:
mutable std::mutex m_mu; // Locks the invoker while scheduling invocations. mutable std::mutex m_mu; // Locks the invoker while scheduling invocations.
mutable bool m_bound = false; mutable bool m_bound = false;
Config* m_config; Config* m_config;
std::shared_ptr<Function> m_func; std::shared_ptr<Function> m_func;
std::shared_ptr<Function> m_src_func; // The original source function.
std::unordered_map<descriptor::Tensor*, std::string> m_input_names; std::unordered_map<descriptor::Tensor*, std::string> m_input_names;
std::unordered_map<descriptor::Tensor*, std::string> m_output_names; std::unordered_map<descriptor::Tensor*, std::string> m_output_names;
mutable std::vector<std::weak_ptr<runtime::Tensor>> m_bound_inputs; mutable std::vector<std::weak_ptr<runtime::Tensor>> m_bound_inputs;
......
...@@ -142,7 +142,9 @@ OPTIONS ...@@ -142,7 +142,9 @@ OPTIONS
return EXIT_FAILURE; return EXIT_FAILURE;
} }
backend->save(f, output, format); auto exec = backend->compile(f);
static_cast<ngraph::runtime::plaidml::PlaidML_Executable*>(exec.get())->save(output, format);
std::cerr << "Wrote output to " << output << "\n"; std::cerr << "Wrote output to " << output << "\n";
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment