Commit d169f929 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Make private members protected in hybrid classes (#2975)

* make private members protected in hybrid classes

* allow overriding the passes
parent 5650e913
......@@ -46,21 +46,25 @@ runtime::hybrid::HybridExecutable::HybridExecutable(
}
// Run placement pass
ngraph::pass::Manager pass_manager;
configure_passes(pass_manager);
pass_manager.run_passes(m_function);
runtime::hybrid::rewrite_function(m_function, m_backend_list);
m_executable = backend_list[0]->compile(m_function);
set_parameters_and_results(*func);
}
void runtime::hybrid::HybridExecutable::configure_passes(ngraph::pass::Manager& pass_manager)
{
pass_manager.register_pass<runtime::hybrid::pass::DefaultPlacement>(m_backend_list);
pass_manager.register_pass<runtime::hybrid::pass::FixGetOutputElement>();
pass_manager.register_pass<runtime::hybrid::pass::Liveness>();
pass_manager.register_pass<runtime::hybrid::pass::Dump>("graph.dump");
// pass_manager.register_pass<runtime::hybrid::pass::MemoryLayout>();
if (m_debug_enabled)
{
pass_manager.register_pass<ngraph::pass::VisualizeTree>("graph.png", node_modifiers);
pass_manager.register_pass<ngraph::pass::VisualizeTree>("graph.dot", node_modifiers);
}
pass_manager.run_passes(m_function);
runtime::hybrid::rewrite_function(m_function, m_backend_list);
m_executable = backend_list[0]->compile(m_function);
set_parameters_and_results(*func);
}
bool runtime::hybrid::HybridExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
......
......@@ -21,6 +21,7 @@
#include <string>
#include <vector>
#include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/executable.hpp"
namespace ngraph
......@@ -53,7 +54,11 @@ public:
return std::dynamic_pointer_cast<T>(m_executable);
}
private:
/// Allow overriding the configuration of the pass manager. If you overload this method
/// you must define all passes.
virtual void configure_passes(ngraph::pass::Manager& pass_manager);
protected:
std::shared_ptr<ngraph::Function> m_function;
std::shared_ptr<Executable> m_executable;
std::unordered_map<std::shared_ptr<ngraph::op::Parameter>, std::shared_ptr<ngraph::op::Result>>
......
......@@ -66,7 +66,7 @@ public:
/// \param n Number of bytes to read, must be integral number of elements.
void read(void* p, size_t tensor_offset, size_t n) const override;
private:
protected:
HybridTensor(const HybridTensor&) = delete;
HybridTensor(HybridTensor&&) = delete;
HybridTensor& operator=(const HybridTensor&) = delete;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment