Commit 6696c145 authored by Adam Procter's avatar Adam Procter

Fix crash when NGRAPH_ENABLE_{VISUALIZE,SERIALIZE}_TRACING=1

parent 9937f8b5
...@@ -43,6 +43,25 @@ public: ...@@ -43,6 +43,25 @@ public:
template <typename T, class... Args> template <typename T, class... Args>
void register_pass(Args&&... args) void register_pass(Args&&... args)
{
push_pass<T>(std::forward<Args>(args)...);
if (m_per_pass_validation)
{
push_pass<Validate>();
}
}
void run_passes(std::shared_ptr<Function>, bool transitive = true);
ManagerState& get_state();
PassConfig& get_pass_config() { return m_pass_config; }
void set_pass_config(const PassConfig& pass_config) { m_pass_config = pass_config; }
void set_pass_visualization(bool new_state) { m_visualize = new_state; }
void set_pass_serialization(bool new_state) { m_serialize = new_state; }
void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
private:
template <typename T, class... Args>
void push_pass(Args&&... args)
{ {
static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base"); static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
auto pass = std::make_shared<T>(std::forward<Args>(args)...); auto pass = std::make_shared<T>(std::forward<Args>(args)...);
...@@ -61,23 +80,8 @@ public: ...@@ -61,23 +80,8 @@ public:
m_pass_names.push_back(typeid(T).name()); m_pass_names.push_back(typeid(T).name());
#endif #endif
} }
if (m_per_pass_validation)
{
auto validate = std::make_shared<Validate>();
auto validate_base = std::static_pointer_cast<PassBase>(validate);
m_pass_list.push_back(validate_base);
} }
}
void run_passes(std::shared_ptr<Function>, bool transitive = true);
ManagerState& get_state();
PassConfig& get_pass_config() { return m_pass_config; }
void set_pass_config(const PassConfig& pass_config) { m_pass_config = pass_config; }
void set_pass_visualization(bool new_state) { m_visualize = new_state; }
void set_pass_serialization(bool new_state) { m_serialize = new_state; }
void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
private:
std::vector<std::string> m_pass_names; std::vector<std::string> m_pass_names;
std::vector<std::shared_ptr<PassBase>> m_pass_list; std::vector<std::shared_ptr<PassBase>> m_pass_list;
ManagerState m_state; ManagerState m_state;
......
...@@ -41,3 +41,29 @@ TEST(pass_manager, add) ...@@ -41,3 +41,29 @@ TEST(pass_manager, add)
EXPECT_EQ(node_count, sorted.size()); EXPECT_EQ(node_count, sorted.size());
EXPECT_TRUE(validate_list(sorted)); EXPECT_TRUE(validate_list(sorted));
} }
namespace
{
class DummyPass : public pass::FunctionPass
{
public:
DummyPass()
: FunctionPass()
{
}
bool run_on_function(std::shared_ptr<ngraph::Function> f) override { return false; }
};
}
// Regression test: We've had an issue in the past where enabling per-pass validation and
// per-pass serialization at the same time causes a crash.
TEST(pass_manager, serialize_with_revalidate_does_not_crash)
{
pass::Manager pass_manager;
pass_manager.set_per_pass_validation(true);
pass_manager.set_pass_serialization(true);
pass_manager.register_pass<DummyPass>();
auto graph = make_test_graph();
pass_manager.run_passes(graph);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment