Commit 18e41513 authored by Pruthvi, committed by Robert Kimball

Pruthvi/rnn fusion (#1677)

* WIP: input * weights RNN optimization

* concat + slicing + replacing with a new node works (see the sketch after this message)

* WIP unit test case for fusing RNN inputs

* - Added a unit test case for fusing RNN input weights
  - Registered CPURnnMatFusion_v1/v2 in codegen and DEX

* Fixed redeclaration of a variable

* Refactored the RNN input transformation passes into a single pass

* Refactored the CPURnnMatFusion callback functions

* Changed the random generator range in the unit test to include negative values

* Addressed PR comments

* Don't fuse if the shapes of the data slices don't match
parent ee712ae8
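
The fusion pass itself is not part of this diff, but the transformation the message describes can be sketched with the same graph-building API the new unit test below uses. Everything in the sketch is illustrative: the helper name fuse_timestep_linear_transform is hypothetical, the {10, 50} / {400, 50} shapes are borrowed from the test, and op::Slice / Coordinate are assumed to be available. It is a rough sketch of the idea, not the CPURnnMatFusion implementation.

// A minimal, illustrative sketch of the "concat + slicing + replacing" idea.
#include "ngraph/ngraph.hpp"

using namespace ngraph;

// Hypothetical helper: given T per-timestep data nodes of shape {10, 50} that all
// feed the same W {400, 50} and bias {400}, build one fused Dot + Add and slice it
// back so each timestep gets a result of the original {10, 400} shape.
static NodeVector fuse_timestep_linear_transform(const NodeVector& data_slices,
                                                 const std::shared_ptr<Node>& W,
                                                 const std::shared_ptr<Node>& bias)
{
    // Guard from the last bullet: only fuse when every data slice has the same shape.
    const Shape slice_shape = data_slices.at(0)->get_shape();
    for (const auto& slice : data_slices)
    {
        if (slice->get_shape() != slice_shape)
        {
            return NodeVector{}; // shapes differ -> do not fuse
        }
    }
    const size_t rows_per_step = slice_shape.at(0); // 10 in the unit test below

    // One concatenated input {rows_per_step * T, 50}, one transposed weight {50, 400},
    // a single Dot and a single broadcast-bias Add instead of T copies of each.
    auto concat = std::make_shared<op::Concat>(data_slices, 0);
    auto W_t = std::make_shared<op::Reshape>(W, AxisVector{1, 0}, Shape{50, 400});
    auto dot = std::make_shared<op::Dot>(concat, W_t);
    auto bias_bcast = std::make_shared<op::Broadcast>(bias, dot->get_shape(), AxisSet{0});
    auto fused = std::make_shared<op::Add>(dot, bias_bcast);

    // Slice the fused result back per timestep; the real pass would replace each
    // original per-timestep Dot/Add output with the corresponding slice.
    NodeVector per_timestep_results;
    for (size_t t = 0; t < data_slices.size(); t++)
    {
        per_timestep_results.push_back(std::make_shared<op::Slice>(
            fused, Coordinate{t * rows_per_step, 0}, Coordinate{(t + 1) * rows_per_step, 400}));
    }
    return per_timestep_results;
}

The payoff is that T small Dot/Add pairs become a single larger one, which is what the MatmulBias count in the new test below checks.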
@@ -1000,11 +1000,12 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
     pass_manager.register_pass<ngraph::pass::NopElimination>();
     // TODO (pruthvi): Enable all the disabeled RNN fusion graph pass after fixing
     // failing mxnet unit tests.
-    // pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
-    // pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
+    pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
+    pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
     pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
-    // pass_manager.register_pass<runtime::cpu::pass::MultiLayerRNNFusion>();
+    pass_manager.register_pass<runtime::cpu::pass::MultiLayerRNNFusion>();
     // pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
+    pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
     pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
     pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
     pass_manager.register_pass<ngraph::pass::CoreFusion>();
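
With this change the RNN fusion passes run as part of register_common_passes. For reference, the same passes can also be applied to a standalone Function through a pass::Manager, which is the pattern the new unit test below uses for CPURnnMatFusion. The fragment here is a sketch only: it assumes it sits inside a function body and reuses the test helper defined further down to build a graph.

// Sketch: run the newly enabled fusion passes on a standalone graph, in the
// same order as the registration above. Uses the test helper defined below.
auto func = create_rnn_input_linear_transformation_function(10);
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<runtime::cpu::pass::MultiLayerRNNFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.run_passes(func);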
@@ -2822,3 +2822,64 @@ TEST(cpu_fusion, dot_batch_forward)
        EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
    }
}

// Builds a graph that applies the same weights and bias to the data of each of
// num_timesteps timesteps (Reshape -> Dot -> broadcast-bias Add) and concatenates
// the per-timestep results along axis 0.
static std::shared_ptr<Function>
    create_rnn_input_linear_transformation_function(size_t num_timesteps)
{
    auto W = std::make_shared<op::Parameter>(element::f32, Shape{400, 50});
    auto bias = std::make_shared<op::Parameter>(element::f32, Shape{400});
    op::ParameterVector params{W, bias};
    auto create_graph = [&]() -> std::shared_ptr<Node> {
        // One data parameter per timestep, squeezed from {10, 1, 50} to {10, 50}
        auto data_param = std::make_shared<op::Parameter>(element::f32, Shape{10, 1, 50});
        params.push_back(data_param);
        auto data_param_reshape =
            std::make_shared<op::Reshape>(data_param, AxisVector{0, 1, 2}, Shape{10, 50});
        // Transpose W to {50, 400} so the Dot contracts over the feature dimension
        auto W_reshape = std::make_shared<op::Reshape>(W, AxisVector{1, 0}, Shape{50, 400});
        auto dot = std::make_shared<op::Dot>(data_param_reshape, W_reshape);
        auto bias_broadcast = make_shared<op::Broadcast>(bias, dot->get_shape(), AxisSet{0});
        auto add_bias = std::make_shared<op::Add>(dot, bias_broadcast);
        return add_bias;
    };
    NodeVector graph_nodes;
    for (size_t i = 0; i < num_timesteps; i++)
    {
        graph_nodes.push_back(create_graph());
    }
    auto concat = std::make_shared<op::Concat>(graph_nodes, 0);
    return make_shared<Function>(NodeVector{concat}, params);
}

// After CPURnnMatFusion + CPUFusion, the ten per-timestep Dot/Add pairs should
// collapse into a single MatmulBias node.
TEST(cpu_fusion, fuse_rnn_input_across_time_steps)
{
    auto func = create_rnn_input_linear_transformation_function(10);
    pass::Manager pass_manager;
    pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
    pass_manager.run_passes(func);
    size_t ref_matmulbias_count = 1;
    auto matmulbias_count = count_ops_of_type<op::MatmulBias>(func);
    EXPECT_EQ(ref_matmulbias_count, matmulbias_count);
}

TEST(cpu_fusion, rnn_input_fusion_inter_vs_cpu)
{
    shared_ptr<Function> cpu_func = create_rnn_input_linear_transformation_function(10);
    shared_ptr<Function> int_func = create_rnn_input_linear_transformation_function(10);
    test::Uniform<float> rng(-10.0f, 10.0f);
    vector<vector<float>> args;
    for (shared_ptr<op::Parameter> param : int_func->get_parameters())
    {
        vector<float> tensor_val(shape_size(param->get_shape()));
        rng.initialize(tensor_val);
        args.push_back(tensor_val);
    }
    auto int_results = execute(int_func, args, "INTERPRETER");
    auto cpu_results = execute(cpu_func, args, "CPU");
    for (size_t i = 0; i < cpu_results.size(); i++)
    {
        EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
    }
}