Commit 8e1922be authored by Pruthvi, committed by Robert Kimball

fix bug in rnn matrix fusion call back (#2279)

* - changed the slicing logic in the RNN input matrix fusion callback
- this fixes a bug in GNMT

* - fix unit test segfault
- add slice-sorting logic to make the replace_node step easier

* i) add check for overlapping slices
ii) addressed PR comments

* remove ambiguity check
parent ea6a5b85
...@@ -327,21 +327,35 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi ...@@ -327,21 +327,35 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
auto add_node = std::make_shared<op::Add>(dot_node, bias_broadcast_node); auto add_node = std::make_shared<op::Add>(dot_node, bias_broadcast_node);
const auto& add_shape = add_node->get_shape(); const auto& add_shape = add_node->get_shape();
// we will sort the captured Add(Dot(X, W) + B) nodes according to the slice ordering of X;
// this simplifies the replace_node logic
// Ordering predicate for std::sort over the matched op nodes: orders two nodes by the
// bounds of the Slice recorded as their Type::DATA entry in op_seg_map.
//
// The comparison is lexicographic on (lower_bounds, upper_bounds). Requiring both
// "lower < lower" AND "upper < upper" is not a strict weak ordering (nested or
// overlapping slices make incomparability non-transitive), which std::sort requires.
// For slices whose lower and upper bounds increase together, the resulting order is
// the same as with the conjunctive test.
//
// Parameters are taken by const reference so the sort does not copy shared_ptrs
// (an atomic refcount increment/decrement) on every comparison.
auto compare_slices = [&](const std::shared_ptr<Node>& node1,
                          const std::shared_ptr<Node>& node2) {
    const auto node1_slice =
        std::static_pointer_cast<op::Slice>(op_seg_map[node1].at(Type::DATA));
    const auto node2_slice =
        std::static_pointer_cast<op::Slice>(op_seg_map[node2].at(Type::DATA));
    const auto& node1_lower = node1_slice->get_lower_bounds();
    const auto& node2_lower = node2_slice->get_lower_bounds();
    // primary key: where the slice starts
    if (node1_lower < node2_lower)
    {
        return true;
    }
    if (node2_lower < node1_lower)
    {
        return false;
    }
    // tie-break on where the slice ends so equal starts are still totally ordered
    return node1_slice->get_upper_bounds() < node2_slice->get_upper_bounds();
};
std::sort(op_nodes.begin(), op_nodes.end(), compare_slices);
size_t num_timesteps = op_nodes.size();
size_t batch_size = add_shape[0] / num_timesteps;
// create a slice for each user of the dot op matching the original dot op's output // create a slice for each user of the dot op matching the original dot op's output
for (auto op : op_nodes) for (size_t i = 0, start_index = 0; i < op_nodes.size(); i++, start_index += batch_size)
{ {
const auto old_slice = // calculate the lower and upper bounds for the slice of the new fused node
std::static_pointer_cast<op::Slice>(op_seg_map[op].at(Type::DATA)); // ((<x0 | x1..|xt>*W)+b), which will be used to replace the nodes matched in the pattern
const auto& old_lower_bounds = old_slice->get_lower_bounds(); const Coordinate lower_bounds{start_index, 0};
// lower bound matching the current time step const Coordinate upper_bounds{start_index + batch_size, add_shape[1]};
const Coordinate lower_bounds{old_lower_bounds[1], 0};
// striding by the number of data auto slice_node = std::make_shared<op::Slice>(add_node, lower_bounds, upper_bounds);
const Strides strides{data_shape[1], 1};
auto slice_node =
std::make_shared<op::Slice>(add_node, lower_bounds, add_shape, strides);
// replace old nodes // replace old nodes
function->replace_node(op, slice_node); function->replace_node(op_nodes[i], slice_node);
} }
modify_graph = true; modify_graph = true;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment