Commit c693cb7e authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Make validation a pass and add it after every pass by default (#3296)

* Make validation a pass and add it after every pass by default

* cleanup

* update per review comments

* Switch plaid to new API for disabling  pass validation

* address review comment
parent aedd8c2e
......@@ -433,6 +433,8 @@ set (SRC
pass/shape_relevance.hpp
pass/validate_graph.cpp
pass/validate_graph.hpp
pass/validate.cpp
pass/validate.hpp
pass/visualize_tree.cpp
pass/visualize_tree.hpp
pass/zero_dim_tensor_elimination.cpp
......
......@@ -53,13 +53,9 @@ pass::Manager::~Manager()
{
}
void pass::Manager::initialize_default_passes()
void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
{
}
void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive, bool revalidate)
{
bool profile_enabled = getenv("NGRAPH_PROFILE_PASS_ENABLE") != nullptr;
static bool profile_enabled = getenv("NGRAPH_PROFILE_PASS_ENABLE") != nullptr;
get_state().set_function(func);
vector<std::pair<shared_ptr<Function>, bool>> fs{std::make_pair(func, func->is_dynamic())};
......@@ -138,16 +134,6 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive, bool
}
}
// Better to do this in node replacement but this will do for now
if (revalidate)
{
for (auto f_pair : fs)
{
shared_ptr<Function> f = f_pair.first;
f->validate_nodes_and_infer_types();
}
}
if (m_visualize || m_serialize)
{
// visualizations and serializations will be named after the outermost function
......
......@@ -24,6 +24,7 @@
#include "ngraph/pass/manager_state.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/pass_config.hpp"
#include "ngraph/pass/validate.hpp"
namespace ngraph
{
......@@ -40,8 +41,6 @@ public:
Manager();
~Manager();
void initialize_default_passes();
template <typename T, class... Args>
void register_pass(Args&&... args)
{
......@@ -62,15 +61,22 @@ public:
m_pass_names.push_back(typeid(T).name());
#endif
}
if (m_per_pass_validation)
{
auto validate = std::make_shared<Validate>();
auto validate_base = std::static_pointer_cast<PassBase>(validate);
m_pass_list.push_back(validate_base);
}
}
void run_passes(std::shared_ptr<Function>, bool transitive = true, bool revalidate = true);
void run_passes(std::shared_ptr<Function>, bool transitive = true);
ManagerState& get_state();
PassConfig& get_pass_config() { return m_pass_config; }
void set_pass_config(const PassConfig& pass_config) { m_pass_config = pass_config; }
void set_pass_visualization(bool new_state) { m_visualize = new_state; }
void set_pass_serialization(bool new_state) { m_serialize = new_state; }
void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
private:
std::vector<std::string> m_pass_names;
std::vector<std::shared_ptr<PassBase>> m_pass_list;
......@@ -78,4 +84,5 @@ private:
PassConfig m_pass_config;
bool m_visualize = false;
bool m_serialize = false;
bool m_per_pass_validation = true;
};
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/validate.hpp"
#include "ngraph/graph_util.hpp"
using namespace ngraph;
bool pass::Validate::run_on_function(std::shared_ptr<Function> f)
{
f->validate_nodes_and_infer_types();
return false;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class Validate : public FunctionPass
{
public:
Validate()
: FunctionPass()
{
}
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
}
}
......@@ -84,6 +84,7 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
// compilation.
ngraph::pass::Manager pass_manager;
pass_manager.set_per_pass_validation(false);
// We apply the same general-purposes passes as the CPU backend.
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
......@@ -124,7 +125,7 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
auto rewrite_func = clone_function(*func);
// Apply passes, with revalidation disabled.
pass_manager.run_passes(rewrite_func, true, false);
pass_manager.run_passes(rewrite_func);
// Compile the resulting function.
Build b;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment