Commit 1cc36521 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Pass Manager returns shared_ptr to each pass created. (#3843)

* Return pass created with register_pass call

* Add test
parent 476dfe14
...@@ -42,13 +42,14 @@ public: ...@@ -42,13 +42,14 @@ public:
~Manager(); ~Manager();
template <typename T, class... Args> template <typename T, class... Args>
void register_pass(Args&&... args) std::shared_ptr<T> register_pass(Args&&... args)
{ {
push_pass<T>(std::forward<Args>(args)...); auto rc = push_pass<T>(std::forward<Args>(args)...);
if (m_per_pass_validation) if (m_per_pass_validation)
{ {
push_pass<Validate>(); push_pass<Validate>();
} }
return rc;
} }
void run_passes(std::shared_ptr<Function>, bool transitive = true); void run_passes(std::shared_ptr<Function>, bool transitive = true);
...@@ -61,7 +62,7 @@ public: ...@@ -61,7 +62,7 @@ public:
void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; } void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
private: private:
template <typename T, class... Args> template <typename T, class... Args>
void push_pass(Args&&... args) std::shared_ptr<T> push_pass(Args&&... args)
{ {
static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base"); static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
auto pass = std::make_shared<T>(std::forward<Args>(args)...); auto pass = std::make_shared<T>(std::forward<Args>(args)...);
...@@ -80,6 +81,7 @@ private: ...@@ -80,6 +81,7 @@ private:
m_pass_names.push_back(typeid(T).name()); m_pass_names.push_back(typeid(T).name());
#endif #endif
} }
return pass;
} }
std::vector<std::string> m_pass_names; std::vector<std::string> m_pass_names;
......
...@@ -62,7 +62,7 @@ TEST(pass_manager, serialize_with_revalidate_does_not_crash) ...@@ -62,7 +62,7 @@ TEST(pass_manager, serialize_with_revalidate_does_not_crash)
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.set_per_pass_validation(true); pass_manager.set_per_pass_validation(true);
pass_manager.set_pass_serialization(true); pass_manager.set_pass_serialization(true);
pass_manager.register_pass<DummyPass>(); shared_ptr<DummyPass> dummy = pass_manager.register_pass<DummyPass>();
auto graph = make_test_graph(); auto graph = make_test_graph();
pass_manager.run_passes(graph); pass_manager.run_passes(graph);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment