Commit 4cd2c602 authored by Pruthvi's avatar Pruthvi Committed by Robert Kimball

Disabeled RNN fusion pass in IA transformer (#1217)

parent 785c1ce7
......@@ -348,11 +348,13 @@ void runtime::cpu::CPU_ExternalFunction::compile()
//in which case they should run this pass(CPUWorkspaceInsertion) explicitly
NodeVector nv_cwi;
pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
// TODO (pruthvi): Enable all the disabeled RNN fusion graph pass after fixing
// failing mxnet unit tests.
// pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
// pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<runtime::cpu::pass::MultiLayerRNNFusion>();
pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
// pass_manager.register_pass<runtime::cpu::pass::MultiLayerRNNFusion>();
// pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
......
......@@ -2591,7 +2591,9 @@ TEST(cpu_fusion, fuse_rnn_across_2layer_1timestep)
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
EXPECT_EQ(1, count_ops_of_type<op::Rnn>(cpu_f));
// TODO (pruthvi): Enable this after fixing failing
// mxnet rnn unit tests
// EXPECT_EQ(1, count_ops_of_type<op::Rnn>(cpu_f));
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(1), int_results.at(1), 1.0e-4f, 1.0e-4f));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment