Commit 917efb94 authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

Pruthvi/fix input matrix fusion (#2381)

* -   check to verify if the data_slices share the same weights

* add the serialized graph

* - explicitly fuse the data slices, so all the parameters partitioned by slices are in a contiguous memory location
- fixes all the failing test cases
parent 5f38fd1a
......@@ -306,15 +306,41 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
NodeVector params = p.first;
NodeVector& op_nodes = p.second;
auto data_node = params.at(Type::DATA);
// we will sort the captured Add(Dot(X, W) + B) as per the slice ordering of X
// this will simplify the replace_node logic
auto compare_slices = [&](const std::shared_ptr<Node> node1,
const std::shared_ptr<Node> node2) {
const auto node1_slice =
std::static_pointer_cast<op::Slice>(op_seg_map[node1].at(Type::DATA));
const auto node2_slice =
std::static_pointer_cast<op::Slice>(op_seg_map[node2].at(Type::DATA));
return (node1_slice->get_lower_bounds() < node2_slice->get_lower_bounds() &&
node1_slice->get_upper_bounds() < node2_slice->get_upper_bounds());
};
std::sort(op_nodes.begin(), op_nodes.end(), compare_slices);
// we fuse all the data slices captured in the pattern to make bigger GEMM call
auto fuse_data_slices = [&]() {
NodeVector data_slices;
for (auto& op : op_nodes)
{
auto data_node = op_seg_map.at(op).at(Type::DATA);
data_slices.push_back(data_node);
}
return std::make_shared<op::Concat>(data_slices, 0);
};
auto data_node = op_nodes.size() > 1 ? fuse_data_slices() : params.at(Type::DATA);
auto weights_node = params.at(Type::WEIGHTS);
auto bias_node = params.at(Type::BIAS);
auto& data_shape = data_node->get_shape();
const auto& data_shape = data_node->get_shape();
// construct new op nodes
auto data_order = ngraph::get_default_order(data_node->get_shape());
auto data_reshape_node = std::make_shared<op::Reshape>(
data_node, data_order, Shape{data_shape[0] * data_shape[1], data_shape[2]});
auto data_reshape_node =
std::make_shared<op::Reshape>(data_node,
AxisVector{0, 1, 2},
Shape{data_shape[0] * data_shape[1], data_shape[2]});
auto old_weights_reshape_node = op_seg_map.at(op_nodes.at(0)).at(Type::WEIGHTS);
auto weights_reshape_node =
......@@ -327,30 +353,16 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
auto add_node = std::make_shared<op::Add>(dot_node, bias_broadcast_node);
const auto& add_shape = add_node->get_shape();
// we will sort the captured Add(Dot(X, W) + B) as per the slice ordering of X
// this will simplify the replace_node logic
auto compare_slices = [&](const std::shared_ptr<Node> node1,
const std::shared_ptr<Node> node2) {
const auto node1_slice =
std::static_pointer_cast<op::Slice>(op_seg_map[node1].at(Type::DATA));
const auto node2_slice =
std::static_pointer_cast<op::Slice>(op_seg_map[node2].at(Type::DATA));
return (node1_slice->get_lower_bounds() < node2_slice->get_lower_bounds() &&
node1_slice->get_upper_bounds() < node2_slice->get_upper_bounds());
};
std::sort(op_nodes.begin(), op_nodes.end(), compare_slices);
size_t num_timesteps = op_nodes.size();
size_t batch_size = add_shape[0] / num_timesteps;
size_t feature_size = add_shape[1];
// create a slice for each user of the dot op matching the original dot op's output
for (size_t i = 0, start_index = 0; i < op_nodes.size(); i++, start_index += batch_size)
{
// calculate the lower and upper bounds for the slice of the new fused node
// ((<x0 | x1..|xt>*W)+b), which will be used to replace the nodes matched in the pattern
const Coordinate lower_bounds{start_index, 0};
const Coordinate upper_bounds{start_index + batch_size, add_shape[1]};
const Coordinate upper_bounds{start_index + batch_size, feature_size};
auto slice_node = std::make_shared<op::Slice>(add_node, lower_bounds, upper_bounds);
......
......@@ -3450,3 +3450,26 @@ TEST(cpu_fusion, rnn_input_fusion_inter_vs_cpu)
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
// Regression test for GRU input fusion: running the serialized GRU graph on
// the CPU backend (which applies CPURnnMatFusion) must produce the same
// outputs as the unfused INTERPRETER reference backend.
TEST(cpu_fusion, validate_fuse_gru_inputs)
{
    const std::string json_path("mxnet/gru_debug.json");
    auto cpu_function = make_function_from_file(json_path);
    auto ref_function = make_function_from_file(json_path);

    // Populate every graph parameter with uniform random data in [-10, 10).
    test::Uniform<float> generator(-10.0f, 10.0f);
    vector<vector<float>> input_tensors;
    for (const shared_ptr<op::Parameter>& param : ref_function->get_parameters())
    {
        vector<float> values(shape_size(param->get_shape()));
        generator.initialize(values);
        input_tensors.push_back(values);
    }

    // Run the reference backend first, then CPU, and compare element-wise
    // within a small absolute/relative tolerance.
    auto reference_results = execute(ref_function, input_tensors, "INTERPRETER");
    auto cpu_results = execute(cpu_function, input_tensors, "CPU");
    for (size_t idx = 0; idx < cpu_results.size(); ++idx)
    {
        EXPECT_TRUE(
            test::all_close(cpu_results.at(idx), reference_results.at(idx), 1.0e-4f, 1.0e-4f));
    }
}
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment