Commit cc546c80 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

add method to get underlying executable (#2630)

* add method to get underlying executable

* Copy the function because copy cleans up the graph linkage/users

* style
parent 56e160ba
...@@ -36,10 +36,14 @@ runtime::hybrid::HybridExecutable::HybridExecutable( ...@@ -36,10 +36,14 @@ runtime::hybrid::HybridExecutable::HybridExecutable(
const shared_ptr<Function>& func, const shared_ptr<Function>& func,
bool enable_performance_collection, bool enable_performance_collection,
bool debug_enabled) bool debug_enabled)
: m_function{func} : m_function{clone_function(*func)}
, m_backend_list{backend_list} , m_backend_list{backend_list}
, m_debug_enabled{debug_enabled} , m_debug_enabled{debug_enabled}
{ {
if (backend_list.size() == 0)
{
throw runtime_error("Hybrid Executable constructed with zero-sized backend list");
}
// Run placement pass // Run placement pass
ngraph::pass::Manager pass_manager; ngraph::pass::Manager pass_manager;
pass_manager.register_pass<runtime::hybrid::pass::DefaultPlacement>(m_backend_list); pass_manager.register_pass<runtime::hybrid::pass::DefaultPlacement>(m_backend_list);
......
...@@ -31,8 +31,8 @@ namespace ngraph ...@@ -31,8 +31,8 @@ namespace ngraph
{ {
class HybridExecutable; class HybridExecutable;
} }
} } // namespace runtime
} } // namespace ngraph
class ngraph::runtime::hybrid::HybridExecutable : public runtime::Executable class ngraph::runtime::hybrid::HybridExecutable : public runtime::Executable
{ {
...@@ -45,6 +45,12 @@ public: ...@@ -45,6 +45,12 @@ public:
bool call(const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& outputs, bool call(const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& inputs) override; const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& inputs) override;
template <typename T>
std::shared_ptr<T> get_as() const
{
return std::dynamic_pointer_cast<T>(m_executable);
}
private: private:
std::shared_ptr<ngraph::Function> m_function; std::shared_ptr<ngraph::Function> m_function;
std::shared_ptr<Executable> m_executable; std::shared_ptr<Executable> m_executable;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment