Commit 1cc36521 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Pass Manager returns shared_ptr to each pass created. (#3843)

* Return pass created with register_pass call

* Add test
parent 476dfe14
......@@ -42,13 +42,14 @@ public:
~Manager();
template <typename T, class... Args>
void register_pass(Args&&... args)
std::shared_ptr<T> register_pass(Args&&... args)
{
push_pass<T>(std::forward<Args>(args)...);
auto rc = push_pass<T>(std::forward<Args>(args)...);
if (m_per_pass_validation)
{
push_pass<Validate>();
}
return rc;
}
void run_passes(std::shared_ptr<Function>, bool transitive = true);
......@@ -61,7 +62,7 @@ public:
void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
private:
template <typename T, class... Args>
void push_pass(Args&&... args)
std::shared_ptr<T> push_pass(Args&&... args)
{
static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
......@@ -80,6 +81,7 @@ private:
m_pass_names.push_back(typeid(T).name());
#endif
}
return pass;
}
std::vector<std::string> m_pass_names;
......
......@@ -62,7 +62,7 @@ TEST(pass_manager, serialize_with_revalidate_does_not_crash)
pass::Manager pass_manager;
pass_manager.set_per_pass_validation(true);
pass_manager.set_pass_serialization(true);
pass_manager.register_pass<DummyPass>();
shared_ptr<DummyPass> dummy = pass_manager.register_pass<DummyPass>();
auto graph = make_test_graph();
pass_manager.run_passes(graph);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment