Commit 880594ba authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Robert Kimball

Support to selectively enable/disable passes through env variable (#2049)

* Support to selectively enable/disable passes through env variable

* Address compiler warning about zero arg macros

* Move pass selection options to PassConfig object

* remove unnecessary header include

* use default copy constructor

* Address PR feedback

* switch to map to keep xcode clang happy. doesn't seem to have a hashing function for string
parent 803c38aa
......@@ -140,6 +140,7 @@ set (SRC
pass/memory_visualize.cpp
pass/nop_elimination.cpp
pass/pass.cpp
pass/pass_config.cpp
pass/propagate_cacheability.cpp
pass/reshape_elimination.cpp
pass/zero_dim_tensor_elimination.cpp
......
......@@ -23,6 +23,7 @@
#include "ngraph/pass/manager_state.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/pass_config.hpp"
namespace ngraph
{
......@@ -57,12 +58,15 @@ public:
void run_passes(std::shared_ptr<Function>, bool transitive = true);
ManagerState& get_state();
PassConfig& get_pass_config() { return m_pass_config; }
void set_pass_config(const PassConfig& pass_config) { m_pass_config = pass_config; }
void set_pass_visualization(bool new_state) { m_visualize = new_state; }
void set_pass_serialization(bool new_state) { m_serialize = new_state; }
private:
std::vector<std::string> m_pass_names;
std::vector<std::shared_ptr<PassBase>> m_pass_list;
ManagerState m_state;
PassConfig m_pass_config;
bool m_visualize = false;
bool m_serialize = false;
};
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/pass_config.hpp"
#include "ngraph/except.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
ngraph::pass::PassConfig::PassConfig()
{
/**
* Parses the semi-colon separated environment string passed through NGRAPH_PASS_ENABLES
* and returns the pass names and whether they should be enabled or disabled in the
* provided unordered_map. Implementation of pass selection is up to the backend
* E.g., NGRAPH_PASS_ENABLES="CoreFusion:0;LikeReplacement:1;CPUCollapseDims" would
* set disables on CoreFusion and enables on LikeReplacement and CPUCollapseDims
**/
const char* env_str = std::getenv("NGRAPH_PASS_ENABLES");
if (env_str)
{
std::stringstream ss;
ss << env_str;
while (ss.good())
{
std::string substr;
std::getline(ss, substr, ';');
auto split_str = split(substr, ':', false);
switch (split_str.size())
{
case 1: m_enables.emplace(split_str[0], true); break;
case 2: m_enables.emplace(split_str[0], parse_string<bool>(split_str[1])); break;
default: throw ngraph_error("Unexpected string in get_pass_enables: " + substr);
}
}
}
}
void ngraph::pass::PassConfig::set_pass_enable(std::string name, bool enable)
{
m_enables[name] = enable;
}
bool ngraph::pass::PassConfig::get_pass_enable(std::string name)
{
if (m_enables.find(name) == m_enables.end())
{
return false;
}
return m_enables[name];
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <map>
namespace ngraph
{
namespace pass
{
class PassConfig;
}
}
class ngraph::pass::PassConfig
{
public:
PassConfig();
const std::map<std::string, bool>& get_enables() { return m_enables; }
void set_pass_enable(std::string name, bool enable);
bool get_pass_enable(std::string name);
private:
std::map<std::string, bool> m_enables;
};
......@@ -129,6 +129,7 @@
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/propagate_cacheability.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/cpu/cpu_backend.hpp"
......@@ -168,6 +169,7 @@
#include "ngraph/runtime/cpu/pass/cpu_mat_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_memory_optimization.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_reshape_sinking.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
#include "ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp"
......@@ -179,6 +181,34 @@
using namespace std;
using namespace ngraph;
#define STR(s) #s
#define REGISTER_KNOBBED_PASS(name, enable_by_default, prefix) \
if (pass_map.find(STR(name)) != pass_map.end()) \
{ \
if (pass_map[STR(name)]) \
{ \
pass_manager.register_pass<prefix::name>(); \
} \
} \
else if (enable_by_default) \
{ \
pass_manager.register_pass<prefix::name>(); \
}
#define REGISTER_KNOBBED_PASS_WITH_ARGS(name, enable_by_default, prefix, ...) \
if (pass_map.find(STR(name)) != pass_map.end()) \
{ \
if (pass_map[STR(name)]) \
{ \
pass_manager.register_pass<prefix::name>(__VA_ARGS__); \
} \
} \
else if (enable_by_default) \
{ \
pass_manager.register_pass<prefix::name>(__VA_ARGS__); \
}
runtime::cpu::CPU_ExternalFunction::CPU_ExternalFunction(
const shared_ptr<ngraph::Function>& function, bool release_function)
: m_function(function)
......@@ -1033,36 +1063,38 @@ using namespace ngraph::runtime;
void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Manager& pass_manager)
{
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
// TODO (pruthvi): Enable all the disabeled RNN fusion graph pass after fixing
// failing mxnet unit tests.
// pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
// pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
// pass_manager.register_pass<runtime::cpu::pass::MultiLayerRNNFusion>();
// pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUHorizontalFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUCollapseDims>();
auto pass_map = pass_manager.get_pass_config().get_enables();
REGISTER_KNOBBED_PASS(LikeReplacement, true, ngraph::pass);
REGISTER_KNOBBED_PASS(NopElimination, true, ngraph::pass);
REGISTER_KNOBBED_PASS(ZeroDimTensorElimination, true, ngraph::pass);
REGISTER_KNOBBED_PASS(LSTMFusion, false, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(RNNFusion, false, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(AlgebraicSimplification, true, ngraph::pass);
REGISTER_KNOBBED_PASS(MultiLayerRNNFusion, false, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(ConcatInputs, false, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPURnnMatFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUBatchFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUReshapeSinking, false, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(ReshapeElimination, false, ngraph::pass);
REGISTER_KNOBBED_PASS(CoreFusion, true, ngraph::pass);
REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUHorizontalFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUCollapseDims, true, runtime::cpu::pass);
#if defined(NGRAPH_HALIDE)
pass_manager.register_pass<ngraph::runtime::cpu::pass::HalideSubgraphExtraction>();
REGISTER_KNOBBED_PASS(HalideSubgraphExtraction, true, ngraph::runtime::cpu::pass);
#endif
NodeVector nv_cwi; // We dont need CPUWorkspaceInsertion to return list of indices
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi, false);
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
pass_manager.register_pass<ngraph::pass::ConstantFolding>();
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>(
runtime::cpu::get_cse_handlers_map());
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
pass_manager.register_pass<runtime::cpu::pass::CPUMemoryOptimization>();
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
REGISTER_KNOBBED_PASS_WITH_ARGS(CPUWorkspaceInsertion, true, runtime::cpu::pass, nv_cwi, false);
REGISTER_KNOBBED_PASS_WITH_ARGS(CPUAssignment, true, runtime::cpu::pass, this);
REGISTER_KNOBBED_PASS(ConstantFolding, true, ngraph::pass);
REGISTER_KNOBBED_PASS_WITH_ARGS(CPULayout, true, runtime::cpu::pass, this);
REGISTER_KNOBBED_PASS_WITH_ARGS(
CommonSubexpressionElimination, true, ngraph::pass, runtime::cpu::get_cse_handlers_map());
REGISTER_KNOBBED_PASS(CPUPostLayoutOptimizations, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUMemoryOptimization, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(GetOutputElementElimination, true, ngraph::pass);
pass_manager.get_state().set_visualize_tree_ops_map(runtime::cpu::get_visualize_tree_ops_map());
}
......
......@@ -42,6 +42,7 @@
#include "ngraph/function.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass_config.hpp"
#include "ngraph/runtime/cpu/cpu_call_frame.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment