Commit 7277a9fd authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

make sure Slice is reshaped if needed (#1803)

parent b339ea71
...@@ -147,6 +147,13 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi ...@@ -147,6 +147,13 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
auto matched_weight = matcher_v2->get_pattern_map()[W]->get_argument(0); auto matched_weight = matcher_v2->get_pattern_map()[W]->get_argument(0);
auto matched_data = matcher_v2->get_pattern_map()[input_data]; auto matched_data = matcher_v2->get_pattern_map()[input_data];
auto matched_bias = matcher_v2->get_pattern_map()[b]->get_argument(0); auto matched_bias = matcher_v2->get_pattern_map()[b]->get_argument(0);
if (matcher_v2->get_match_root()->get_shape().size() != 2 &&
matcher_v2->get_match_root()->get_shape().size() != 3)
{
NGRAPH_DEBUG << "mat fusion (v2) root " << matcher_v2->get_match_root()->get_name()
<< " isn't 2D or 3D";
continue;
}
map_weights_to_pattern[matched_weight].push_back(matcher_v2->get_match_root()); map_weights_to_pattern[matched_weight].push_back(matcher_v2->get_match_root());
map_weights_bias_to_data[std::make_pair(matched_weight, matched_bias)].push_back( map_weights_bias_to_data[std::make_pair(matched_weight, matched_bias)].push_back(
matched_data); matched_data);
...@@ -248,8 +255,15 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi ...@@ -248,8 +255,15 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
size_t end_index = batch_size; size_t end_index = batch_size;
for (auto& matched_root_node : map_weights_to_pattern[weights]) for (auto& matched_root_node : map_weights_to_pattern[weights])
{ {
auto slice_node = std::make_shared<op::Slice>( std::shared_ptr<Node> slice_node = std::make_shared<op::Slice>(
new_add_bias, Coordinate{start_index, 0}, Coordinate{end_index, shape_axis_1}); new_add_bias, Coordinate{start_index, 0}, Coordinate{end_index, shape_axis_1});
if (matched_root_node->get_shape().size() != 2)
{
NGRAPH_ASSERT(matched_root_node->get_shape().size() == 3);
slice_node = std::make_shared<op::Reshape>(
slice_node, AxisVector{0, 1}, matched_root_node->get_shape());
}
start_index += batch_size; start_index += batch_size;
end_index += batch_size; end_index += batch_size;
NGRAPH_DEBUG << "Replacing op " << matched_root_node->get_name() << " with " NGRAPH_DEBUG << "Replacing op " << matched_root_node->get_name() << " with "
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment