Commit f117269f authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Adam Procter

Common pass registration for codegen and Dex (#1642)

* Common pass registration for codegen and Dex

* Make return indices optional for cpu workspace insertion
parent 28228857
...@@ -368,20 +368,8 @@ static void ...@@ -368,20 +368,8 @@ static void
writer << "}\n"; writer << "}\n";
} }
void runtime::cpu::CPU_ExternalFunction::compile() void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Manager& pass_manager)
{ {
if (m_is_compiled)
{
return;
}
m_mkldnn_emitter.reset(new MKLDNNEmitter());
ngraph::pass::Manager pass_manager;
// nv_cwi is required only by some frontends
// in which case they should run this pass(CPUWorkspaceInsertion) explicitly
NodeVector nv_cwi;
pass_manager.register_pass<ngraph::pass::LikeReplacement>(); pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::NopElimination>(); pass_manager.register_pass<ngraph::pass::NopElimination>();
// TODO (pruthvi): Enable all the disabeled RNN fusion graph pass after fixing // TODO (pruthvi): Enable all the disabeled RNN fusion graph pass after fixing
...@@ -396,11 +384,25 @@ void runtime::cpu::CPU_ExternalFunction::compile() ...@@ -396,11 +384,25 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<ngraph::pass::CoreFusion>(); pass_manager.register_pass<ngraph::pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUCollapseDims>(); pass_manager.register_pass<runtime::cpu::pass::CPUCollapseDims>();
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi); NodeVector nv_cwi; // We don't need CPUWorkspaceInsertion to return a list of indices
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi, false);
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this); pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this); pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>(); pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>(); pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
}
void runtime::cpu::CPU_ExternalFunction::compile()
{
if (m_is_compiled)
{
return;
}
m_mkldnn_emitter.reset(new MKLDNNEmitter());
ngraph::pass::Manager pass_manager;
register_common_passes(pass_manager);
unordered_map<Node*, Node*> node_function_map; unordered_map<Node*, Node*> node_function_map;
string common_function_string; string common_function_string;
auto femitter = bind(&ngraph::runtime::cpu::CPU_ExternalFunction::emit_op_as_function, auto femitter = bind(&ngraph::runtime::cpu::CPU_ExternalFunction::emit_op_as_function,
...@@ -1132,27 +1134,8 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1132,27 +1134,8 @@ void runtime::cpu::CPU_ExternalFunction::build()
m_mkldnn_emitter.reset(new MKLDNNEmitter()); m_mkldnn_emitter.reset(new MKLDNNEmitter());
ngraph::pass::Manager pass_manager; ngraph::pass::Manager pass_manager;
register_common_passes(pass_manager);
// nv_cwi is required only by some frontends
// in which case they should run this pass(CPUWorkspaceInsertion) explicitly
NodeVector nv_cwi;
pass_manager.register_pass<ngraph::pass::NopElimination>();
// TODO (pruthvi): Enable all the disabled RNN fusion graph passes after fixing
// failing mxnet unit tests.
// pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
// pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
// pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUCollapseDims>();
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi);
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
pass_manager.register_pass<ngraph::pass::Liveness>(); pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::MemoryLayout>(size_t(s_memory_pool_alignment), true); pass_manager.register_pass<ngraph::pass::MemoryLayout>(size_t(s_memory_pool_alignment), true);
pass_manager.run_passes(m_function, false); pass_manager.run_passes(m_function, false);
......
...@@ -36,6 +36,7 @@ ...@@ -36,6 +36,7 @@
#endif #endif
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/cpu/cpu_call_frame.hpp" #include "ngraph/runtime/cpu/cpu_call_frame.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp" #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp" #include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
...@@ -139,6 +140,9 @@ namespace ngraph ...@@ -139,6 +140,9 @@ namespace ngraph
#endif #endif
private: private:
// Register passes that are common to codegen and DEX
void register_common_passes(ngraph::pass::Manager& pass_manager);
// For non-destructive passthrough kernels, propagate function // For non-destructive passthrough kernels, propagate function
// input buffers to internal ops // input buffers to internal ops
void propagate_in_place_input(ngraph::descriptor::Output* output, void propagate_in_place_input(ngraph::descriptor::Output* output,
......
...@@ -168,6 +168,9 @@ bool runtime::cpu::pass::CPUWorkspaceInsertion::transform(pattern::Matcher& m) ...@@ -168,6 +168,9 @@ bool runtime::cpu::pass::CPUWorkspaceInsertion::transform(pattern::Matcher& m)
m_max_pool->get_padding_above()); m_max_pool->get_padding_above());
ngraph::replace_node(m_max_pool_bprop, max_pool_with_indices_bprop); ngraph::replace_node(m_max_pool_bprop, max_pool_with_indices_bprop);
if (m_return_indices)
{
m_indices_list.push_back(max_pool_with_indices_indices); m_indices_list.push_back(max_pool_with_indices_indices);
}
return true; return true;
} }
...@@ -36,9 +36,10 @@ namespace ngraph ...@@ -36,9 +36,10 @@ namespace ngraph
class ngraph::runtime::cpu::pass::CPUWorkspaceInsertion : public ngraph::pass::FunctionPass class ngraph::runtime::cpu::pass::CPUWorkspaceInsertion : public ngraph::pass::FunctionPass
{ {
public: public:
CPUWorkspaceInsertion(ngraph::NodeVector& indices_list) CPUWorkspaceInsertion(ngraph::NodeVector& indices_list, bool return_indices = true)
: FunctionPass() : FunctionPass()
, m_indices_list(indices_list) , m_indices_list(indices_list)
, m_return_indices(return_indices)
{ {
} }
...@@ -46,5 +47,6 @@ public: ...@@ -46,5 +47,6 @@ public:
private: private:
ngraph::NodeVector& m_indices_list; ngraph::NodeVector& m_indices_list;
bool m_return_indices;
bool transform(ngraph::pattern::Matcher& m); bool transform(ngraph::pattern::Matcher& m);
}; };
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment