Commit 7277a9fd authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

make sure Slice is reshaped if needed (#1803)

parent b339ea71
......@@ -147,6 +147,13 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
auto matched_weight = matcher_v2->get_pattern_map()[W]->get_argument(0);
auto matched_data = matcher_v2->get_pattern_map()[input_data];
auto matched_bias = matcher_v2->get_pattern_map()[b]->get_argument(0);
if (matcher_v2->get_match_root()->get_shape().size() != 2 &&
matcher_v2->get_match_root()->get_shape().size() != 3)
{
NGRAPH_DEBUG << "mat fusion (v2) root " << matcher_v2->get_match_root()->get_name()
<< " isn't 2D or 3D";
continue;
}
map_weights_to_pattern[matched_weight].push_back(matcher_v2->get_match_root());
map_weights_bias_to_data[std::make_pair(matched_weight, matched_bias)].push_back(
matched_data);
......@@ -248,8 +255,15 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
size_t end_index = batch_size;
for (auto& matched_root_node : map_weights_to_pattern[weights])
{
auto slice_node = std::make_shared<op::Slice>(
std::shared_ptr<Node> slice_node = std::make_shared<op::Slice>(
new_add_bias, Coordinate{start_index, 0}, Coordinate{end_index, shape_axis_1});
if (matched_root_node->get_shape().size() != 2)
{
NGRAPH_ASSERT(matched_root_node->get_shape().size() == 3);
slice_node = std::make_shared<op::Reshape>(
slice_node, AxisVector{0, 1}, matched_root_node->get_shape());
}
start_index += batch_size;
end_index += batch_size;
NGRAPH_DEBUG << "Replacing op " << matched_root_node->get_name() << " with "
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment