Commit b4338a52 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

make sure that data/weights/add are all the right shapes (#1805)

parent 7277a9fd
......@@ -558,3 +558,16 @@ bool ngraph::is_strided(const Strides& strides)
{
return std::any_of(strides.begin(), strides.end(), [](size_t stride) { return stride != 1; });
}
bool ngraph::is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t> valid_ranks)
{
auto node_rank = node->get_shape().size();
for (auto rank : valid_ranks)
{
if (rank == node_rank)
{
return true;
}
}
return false;
}
......@@ -315,4 +315,6 @@ namespace ngraph
bool possibly_overwritten(Node* node);
bool is_strided(const Strides& strides);
bool is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t> valid_ranks);
}
......@@ -23,6 +23,7 @@
#include <unordered_map>
#include "cpu_mat_fusion.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
......@@ -147,13 +148,26 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
auto matched_weight = matcher_v2->get_pattern_map()[W]->get_argument(0);
auto matched_data = matcher_v2->get_pattern_map()[input_data];
auto matched_bias = matcher_v2->get_pattern_map()[b]->get_argument(0);
if (matcher_v2->get_match_root()->get_shape().size() != 2 &&
matcher_v2->get_match_root()->get_shape().size() != 3)
std::vector<size_t> supported_ranks{2, 3};
if (!ngraph::is_valid_rank(matcher_v2->get_match_root(), supported_ranks))
{
NGRAPH_DEBUG << "mat fusion (v2) root " << matcher_v2->get_match_root()->get_name()
NGRAPH_DEBUG << "Add (mat_fusion_v2) " << matcher_v2->get_match_root()->get_name()
<< " isn't 2D or 3D";
continue;
}
if (!ngraph::is_valid_rank(matched_weight, supported_ranks))
{
NGRAPH_DEBUG << "Weights (mat_fusion_v2) " << matched_weight << " isn't 2D or 3D";
continue;
}
if (!ngraph::is_valid_rank(matched_data, supported_ranks))
{
NGRAPH_DEBUG << "Data (mat_fusion_v2) " << matched_data << " isn't 2D or 3D";
continue;
}
map_weights_to_pattern[matched_weight].push_back(matcher_v2->get_match_root());
map_weights_bias_to_data[std::make_pair(matched_weight, matched_bias)].push_back(
matched_data);
......@@ -241,6 +255,7 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
concated_data, data_order, Shape{data_shape[0] * data_shape[1], data_shape[2]});
}
auto new_input_node = data_shape.size() == 2 ? concated_data : input_reshape_node;
NGRAPH_ASSERT(new_input_node);
auto w_reshape_node = std::make_shared<op::Reshape>(
weights, AxisVector{1, 0}, Shape{w_shape[1], w_shape[0]});
auto new_dot = std::make_shared<op::Dot>(new_input_node, w_reshape_node);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment