Commit fe4b0f49 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

mat_fusion 4d test case (#1809)

parent e7443b19
......@@ -2916,16 +2916,20 @@ TEST(cpu_fusion, dot_batch_forward)
}
}
static std::shared_ptr<Function>
create_rnn_input_linear_transformation_function(size_t num_timesteps)
create_rnn_input_linear_transformation_function(size_t num_timesteps, bool data_is_4d = false)
{
auto W = std::make_shared<op::Parameter>(element::f32, Shape{400, 50});
auto bias = std::make_shared<op::Parameter>(element::f32, Shape{400});
op::ParameterVector params{W, bias};
auto create_graph = [&]() -> std::shared_ptr<Node> {
auto data_param = std::make_shared<op::Parameter>(element::f32, Shape{10, 1, 50});
auto data_param = (data_is_4d)
? std::make_shared<op::Parameter>(element::f32, Shape{2, 5, 1, 50})
: std::make_shared<op::Parameter>(element::f32, Shape{10, 1, 50});
params.push_back(data_param);
auto reshape_axis_order = data_is_4d ? AxisVector{0, 1, 2, 3} : AxisVector{0, 1, 2};
auto data_param_reshape =
std::make_shared<op::Reshape>(data_param, AxisVector{0, 1, 2}, Shape{10, 50});
std::make_shared<op::Reshape>(data_param, reshape_axis_order, Shape{10, 50});
auto W_reshape = std::make_shared<op::Reshape>(W, AxisVector{1, 0}, Shape{50, 400});
auto dot = std::make_shared<op::Dot>(data_param_reshape, W_reshape);
auto bias_broadcast = make_shared<op::Broadcast>(bias, dot->get_shape(), AxisSet{0});
......@@ -2955,6 +2959,18 @@ TEST(cpu_fusion, fuse_rnn_input_across_time_steps)
EXPECT_EQ(ref_matmulbias_count, matmulbias_count);
}
TEST(cpu_fusion, fuse_rnn_input_across_time_steps_4d_data)
{
auto func = create_rnn_input_linear_transformation_function(10, true);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.run_passes(func);
size_t ref_matmulbias_count = 10; // no CPURnnMatFusion transformations
auto matmulbias_count = count_ops_of_type<op::MatmulBias>(func);
EXPECT_EQ(ref_matmulbias_count, matmulbias_count);
}
TEST(cpu_fusion, rnn_input_fusion_inter_vs_cpu)
{
shared_ptr<Function> cpu_func = create_rnn_input_linear_transformation_function(10);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment