Commit d169f929 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Make private members protected in hybrid classes (#2975)

* make private members protected in hybrid classes

* allow overriding the passes
parent 5650e913
...@@ -46,21 +46,25 @@ runtime::hybrid::HybridExecutable::HybridExecutable( ...@@ -46,21 +46,25 @@ runtime::hybrid::HybridExecutable::HybridExecutable(
} }
// Run placement pass // Run placement pass
ngraph::pass::Manager pass_manager; ngraph::pass::Manager pass_manager;
configure_passes(pass_manager);
pass_manager.run_passes(m_function);
runtime::hybrid::rewrite_function(m_function, m_backend_list);
m_executable = backend_list[0]->compile(m_function);
set_parameters_and_results(*func);
}
void runtime::hybrid::HybridExecutable::configure_passes(ngraph::pass::Manager& pass_manager)
{
pass_manager.register_pass<runtime::hybrid::pass::DefaultPlacement>(m_backend_list); pass_manager.register_pass<runtime::hybrid::pass::DefaultPlacement>(m_backend_list);
pass_manager.register_pass<runtime::hybrid::pass::FixGetOutputElement>(); pass_manager.register_pass<runtime::hybrid::pass::FixGetOutputElement>();
pass_manager.register_pass<runtime::hybrid::pass::Liveness>(); pass_manager.register_pass<runtime::hybrid::pass::Liveness>();
pass_manager.register_pass<runtime::hybrid::pass::Dump>("graph.dump"); pass_manager.register_pass<runtime::hybrid::pass::Dump>("graph.dump");
// pass_manager.register_pass<runtime::hybrid::pass::MemoryLayout>();
if (m_debug_enabled) if (m_debug_enabled)
{ {
pass_manager.register_pass<ngraph::pass::VisualizeTree>("graph.png", node_modifiers); pass_manager.register_pass<ngraph::pass::VisualizeTree>("graph.dot", node_modifiers);
} }
pass_manager.run_passes(m_function);
runtime::hybrid::rewrite_function(m_function, m_backend_list);
m_executable = backend_list[0]->compile(m_function);
set_parameters_and_results(*func);
} }
bool runtime::hybrid::HybridExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs, bool runtime::hybrid::HybridExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/executable.hpp" #include "ngraph/runtime/executable.hpp"
namespace ngraph namespace ngraph
...@@ -53,7 +54,11 @@ public: ...@@ -53,7 +54,11 @@ public:
return std::dynamic_pointer_cast<T>(m_executable); return std::dynamic_pointer_cast<T>(m_executable);
} }
private: /// Allow overriding the configuration of the pass manager. If you overload this method
/// you must define all passes.
virtual void configure_passes(ngraph::pass::Manager& pass_manager);
protected:
std::shared_ptr<ngraph::Function> m_function; std::shared_ptr<ngraph::Function> m_function;
std::shared_ptr<Executable> m_executable; std::shared_ptr<Executable> m_executable;
std::unordered_map<std::shared_ptr<ngraph::op::Parameter>, std::shared_ptr<ngraph::op::Result>> std::unordered_map<std::shared_ptr<ngraph::op::Parameter>, std::shared_ptr<ngraph::op::Result>>
......
...@@ -66,7 +66,7 @@ public: ...@@ -66,7 +66,7 @@ public:
/// \param n Number of bytes to read, must be integral number of elements. /// \param n Number of bytes to read, must be integral number of elements.
void read(void* p, size_t tensor_offset, size_t n) const override; void read(void* p, size_t tensor_offset, size_t n) const override;
private: protected:
HybridTensor(const HybridTensor&) = delete; HybridTensor(const HybridTensor&) = delete;
HybridTensor(HybridTensor&&) = delete; HybridTensor(HybridTensor&&) = delete;
HybridTensor& operator=(const HybridTensor&) = delete; HybridTensor& operator=(const HybridTensor&) = delete;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment