Commit c693cb7e authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Make validation a pass and add it after every pass by default (#3296)

* Make validation a pass and add it after every pass by default

* cleanup

* update per review comments

* Switch plaid to new API for disabling  pass validation

* address review comment
parent aedd8c2e
...@@ -433,6 +433,8 @@ set (SRC ...@@ -433,6 +433,8 @@ set (SRC
pass/shape_relevance.hpp pass/shape_relevance.hpp
pass/validate_graph.cpp pass/validate_graph.cpp
pass/validate_graph.hpp pass/validate_graph.hpp
pass/validate.cpp
pass/validate.hpp
pass/visualize_tree.cpp pass/visualize_tree.cpp
pass/visualize_tree.hpp pass/visualize_tree.hpp
pass/zero_dim_tensor_elimination.cpp pass/zero_dim_tensor_elimination.cpp
......
...@@ -53,13 +53,9 @@ pass::Manager::~Manager() ...@@ -53,13 +53,9 @@ pass::Manager::~Manager()
{ {
} }
void pass::Manager::initialize_default_passes() void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
{ {
} static bool profile_enabled = getenv("NGRAPH_PROFILE_PASS_ENABLE") != nullptr;
void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive, bool revalidate)
{
bool profile_enabled = getenv("NGRAPH_PROFILE_PASS_ENABLE") != nullptr;
get_state().set_function(func); get_state().set_function(func);
vector<std::pair<shared_ptr<Function>, bool>> fs{std::make_pair(func, func->is_dynamic())}; vector<std::pair<shared_ptr<Function>, bool>> fs{std::make_pair(func, func->is_dynamic())};
...@@ -138,16 +134,6 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive, bool ...@@ -138,16 +134,6 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive, bool
} }
} }
// Better to do this in node replacement but this will do for now
if (revalidate)
{
for (auto f_pair : fs)
{
shared_ptr<Function> f = f_pair.first;
f->validate_nodes_and_infer_types();
}
}
if (m_visualize || m_serialize) if (m_visualize || m_serialize)
{ {
// visualizations and serializations will be named after the outermost function // visualizations and serializations will be named after the outermost function
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "ngraph/pass/manager_state.hpp" #include "ngraph/pass/manager_state.hpp"
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "ngraph/pass/pass_config.hpp" #include "ngraph/pass/pass_config.hpp"
#include "ngraph/pass/validate.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -40,8 +41,6 @@ public: ...@@ -40,8 +41,6 @@ public:
Manager(); Manager();
~Manager(); ~Manager();
void initialize_default_passes();
template <typename T, class... Args> template <typename T, class... Args>
void register_pass(Args&&... args) void register_pass(Args&&... args)
{ {
...@@ -62,15 +61,22 @@ public: ...@@ -62,15 +61,22 @@ public:
m_pass_names.push_back(typeid(T).name()); m_pass_names.push_back(typeid(T).name());
#endif #endif
} }
if (m_per_pass_validation)
{
auto validate = std::make_shared<Validate>();
auto validate_base = std::static_pointer_cast<PassBase>(validate);
m_pass_list.push_back(validate_base);
}
} }
void run_passes(std::shared_ptr<Function>, bool transitive = true, bool revalidate = true); void run_passes(std::shared_ptr<Function>, bool transitive = true);
ManagerState& get_state(); ManagerState& get_state();
PassConfig& get_pass_config() { return m_pass_config; } PassConfig& get_pass_config() { return m_pass_config; }
void set_pass_config(const PassConfig& pass_config) { m_pass_config = pass_config; } void set_pass_config(const PassConfig& pass_config) { m_pass_config = pass_config; }
void set_pass_visualization(bool new_state) { m_visualize = new_state; } void set_pass_visualization(bool new_state) { m_visualize = new_state; }
void set_pass_serialization(bool new_state) { m_serialize = new_state; } void set_pass_serialization(bool new_state) { m_serialize = new_state; }
void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
private: private:
std::vector<std::string> m_pass_names; std::vector<std::string> m_pass_names;
std::vector<std::shared_ptr<PassBase>> m_pass_list; std::vector<std::shared_ptr<PassBase>> m_pass_list;
...@@ -78,4 +84,5 @@ private: ...@@ -78,4 +84,5 @@ private:
PassConfig m_pass_config; PassConfig m_pass_config;
bool m_visualize = false; bool m_visualize = false;
bool m_serialize = false; bool m_serialize = false;
bool m_per_pass_validation = true;
}; };
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/validate.hpp"
#include "ngraph/graph_util.hpp"
using namespace ngraph;
bool pass::Validate::run_on_function(std::shared_ptr<Function> f)
{
f->validate_nodes_and_infer_types();
return false;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class Validate : public FunctionPass
{
public:
Validate()
: FunctionPass()
{
}
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
}
}
...@@ -84,6 +84,7 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable> ...@@ -84,6 +84,7 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
// compilation. // compilation.
ngraph::pass::Manager pass_manager; ngraph::pass::Manager pass_manager;
pass_manager.set_per_pass_validation(false);
// We apply the same general-purposes passes as the CPU backend. // We apply the same general-purposes passes as the CPU backend.
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>(); pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
...@@ -124,7 +125,7 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable> ...@@ -124,7 +125,7 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
auto rewrite_func = clone_function(*func); auto rewrite_func = clone_function(*func);
// Apply passes, with revalidation disabled. // Apply passes, with revalidation disabled.
pass_manager.run_passes(rewrite_func, true, false); pass_manager.run_passes(rewrite_func);
// Compile the resulting function. // Compile the resulting function.
Build b; Build b;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment