Commit 4f26509c authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Workaround to allow BEs to disable revalidation in pass manager (#3254)

* Workaround to allow BEs to disable revalidation in pass manager

* No no no, your _other_ `false`.
parent b3602cf6
...@@ -57,7 +57,7 @@ void pass::Manager::initialize_default_passes() ...@@ -57,7 +57,7 @@ void pass::Manager::initialize_default_passes()
{ {
} }
void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive) void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive, bool revalidate)
{ {
bool profile_enabled = getenv("NGRAPH_PROFILE_PASS_ENABLE") != nullptr; bool profile_enabled = getenv("NGRAPH_PROFILE_PASS_ENABLE") != nullptr;
...@@ -139,11 +139,14 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive) ...@@ -139,11 +139,14 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
} }
// Better to do this in node replacement but this will do for now // Better to do this in node replacement but this will do for now
if (revalidate)
{
for (auto f_pair : fs) for (auto f_pair : fs)
{ {
shared_ptr<Function> f = f_pair.first; shared_ptr<Function> f = f_pair.first;
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
} }
}
if (m_visualize || m_serialize) if (m_visualize || m_serialize)
{ {
......
...@@ -64,7 +64,7 @@ public: ...@@ -64,7 +64,7 @@ public:
} }
} }
void run_passes(std::shared_ptr<Function>, bool transitive = true); void run_passes(std::shared_ptr<Function>, bool transitive = true, bool revalidate = true);
ManagerState& get_state(); ManagerState& get_state();
PassConfig& get_pass_config() { return m_pass_config; } PassConfig& get_pass_config() { return m_pass_config; }
......
...@@ -123,8 +123,8 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable> ...@@ -123,8 +123,8 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
// before we rewrite, we make our own copy of the function. // before we rewrite, we make our own copy of the function.
auto rewrite_func = clone_function(*func); auto rewrite_func = clone_function(*func);
// Apply passes. // Apply passes, with revalidation disabled.
pass_manager.run_passes(rewrite_func); pass_manager.run_passes(rewrite_func, true, false);
// Compile the resulting function. // Compile the resulting function.
Build b; Build b;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment