Commit f75b8006 authored by Pruthvi's avatar Pruthvi Committed by Adam Procter

RNN fusion across layers (#1085)

* - Added graph pass for fusing RNN op across layer
- Added test case for inter v/s cpu for verifying layer fused RNN
- more sanity checks in the RNN fusion graph pass
- added support to replace the recurrent cell state correctly in the fused RNN op

* Fixed multi layer rnn fusion unit test failure

* Addressed PR comments
parent 7c8e9250
...@@ -344,9 +344,10 @@ void runtime::cpu::CPU_ExternalFunction::compile() ...@@ -344,9 +344,10 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<ngraph::pass::NopElimination>(); pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>(); pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::cpu::pass::RNNFusion>(); pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>(); pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<runtime::cpu::pass::MultiLayerRNNFusion>();
pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>(); pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::CoreFusion>(); pass_manager.register_pass<ngraph::pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
......
...@@ -29,6 +29,7 @@ namespace ngraph ...@@ -29,6 +29,7 @@ namespace ngraph
{ {
class LSTMFusion; class LSTMFusion;
class RNNFusion; class RNNFusion;
class MultiLayerRNNFusion;
} }
} }
} }
...@@ -61,3 +62,16 @@ public: ...@@ -61,3 +62,16 @@ public:
private: private:
void construct_rnn_lstm_fprop(); void construct_rnn_lstm_fprop();
}; };
class ngraph::runtime::cpu::pass::MultiLayerRNNFusion : public ngraph::pass::RecurrentGraphRewrite
{
public:
MultiLayerRNNFusion()
: RecurrentGraphRewrite()
{
construct_multi_layer_rnn_fusion_fprop();
}
private:
void construct_multi_layer_rnn_fusion_fprop();
};
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include "ngraph/op/relu.hpp" #include "ngraph/op/relu.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/op/tanh.hpp" #include "ngraph/op/tanh.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/graph_rewrite.hpp" #include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/reshape_elimination.hpp" #include "ngraph/pass/reshape_elimination.hpp"
...@@ -2197,3 +2198,45 @@ TEST(cpu_fusion, fuse_batch_dot_forward) ...@@ -2197,3 +2198,45 @@ TEST(cpu_fusion, fuse_batch_dot_forward)
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f)); EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
} }
} }
TEST(cpu_fusion, fuse_rnn_across_layer)
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<runtime::cpu::pass::MultiLayerRNNFusion>();
const string json_path =
file_util::path_join(SERIALIZED_ZOO, "mxnet/2rnn_layer_1timestep.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
size_t ref_rnn_count = 1;
auto rnn_count = count_ops_of_type<op::Rnn>(func);
EXPECT_EQ(ref_rnn_count, rnn_count);
}
TEST(cpu_fusion, fuse_rnn_across_2layer_1timestep)
{
const std::string file_name("mxnet/2rnn_layer_1timestep.json");
auto cpu_f = make_function(file_name);
auto int_f = make_function(file_name);
test::Uniform<float> rng(0.0f, 1.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
EXPECT_EQ(1, count_ops_of_type<op::Rnn>(cpu_f));
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(1), int_results.at(1), 1.0e-4f, 1.0e-4f));
}
}
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment