Commit f117269f authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Adam Procter

Common pass registration for codegen and Dex (#1642)

* Common pass registration for codegen and Dex

* Make return indices optional for cpu workspace insertion
parent 28228857
......@@ -368,20 +368,8 @@ static void
writer << "}\n";
}
void runtime::cpu::CPU_ExternalFunction::compile()
void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Manager& pass_manager)
{
if (m_is_compiled)
{
return;
}
m_mkldnn_emitter.reset(new MKLDNNEmitter());
ngraph::pass::Manager pass_manager;
// nv_cwi is required only by some frontends
// in which case they should run this pass(CPUWorkspaceInsertion) explicitly
NodeVector nv_cwi;
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::NopElimination>();
// TODO (pruthvi): Enable all the disabeled RNN fusion graph pass after fixing
......@@ -396,11 +384,25 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<ngraph::pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUCollapseDims>();
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi);
NodeVector nv_cwi; // We dont need CPUWorkspaceInsertion to return list of indices
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi, false);
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
}
void runtime::cpu::CPU_ExternalFunction::compile()
{
if (m_is_compiled)
{
return;
}
m_mkldnn_emitter.reset(new MKLDNNEmitter());
ngraph::pass::Manager pass_manager;
register_common_passes(pass_manager);
unordered_map<Node*, Node*> node_function_map;
string common_function_string;
auto femitter = bind(&ngraph::runtime::cpu::CPU_ExternalFunction::emit_op_as_function,
......@@ -1132,27 +1134,8 @@ void runtime::cpu::CPU_ExternalFunction::build()
m_mkldnn_emitter.reset(new MKLDNNEmitter());
ngraph::pass::Manager pass_manager;
register_common_passes(pass_manager);
// nv_cwi is required only by some frontends
// in which case they should run this pass(CPUWorkspaceInsertion) explicitly
NodeVector nv_cwi;
pass_manager.register_pass<ngraph::pass::NopElimination>();
// TODO (pruthvi): Enable all the disabeled RNN fusion graph pass after fixing
// failing mxnet unit tests.
// pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
// pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
// pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUCollapseDims>();
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi);
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::MemoryLayout>(size_t(s_memory_pool_alignment), true);
pass_manager.run_passes(m_function, false);
......
......@@ -36,6 +36,7 @@
#endif
#include "ngraph/function.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/cpu/cpu_call_frame.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
......@@ -139,6 +140,9 @@ namespace ngraph
#endif
private:
// Register passes that are common to codegen and DEX
void register_common_passes(ngraph::pass::Manager& pass_manager);
// For non-destructive passthrough kernels, propagate function
// input buffers to internal ops
void propagate_in_place_input(ngraph::descriptor::Output* output,
......
......@@ -168,6 +168,9 @@ bool runtime::cpu::pass::CPUWorkspaceInsertion::transform(pattern::Matcher& m)
m_max_pool->get_padding_above());
ngraph::replace_node(m_max_pool_bprop, max_pool_with_indices_bprop);
m_indices_list.push_back(max_pool_with_indices_indices);
if (m_return_indices)
{
m_indices_list.push_back(max_pool_with_indices_indices);
}
return true;
}
......@@ -36,9 +36,10 @@ namespace ngraph
class ngraph::runtime::cpu::pass::CPUWorkspaceInsertion : public ngraph::pass::FunctionPass
{
public:
CPUWorkspaceInsertion(ngraph::NodeVector& indices_list)
CPUWorkspaceInsertion(ngraph::NodeVector& indices_list, bool return_indices = true)
: FunctionPass()
, m_indices_list(indices_list)
, m_return_indices(return_indices)
{
}
......@@ -46,5 +47,6 @@ public:
private:
ngraph::NodeVector& m_indices_list;
bool m_return_indices;
bool transform(ngraph::pattern::Matcher& m);
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment