Commit b4338a52 authored by Nick Korovaiko, committed by Scott Cyphers

make sure that data/weights/add are all the right shapes (#1805)

parent 7277a9fd
...@@ -558,3 +558,16 @@ bool ngraph::is_strided(const Strides& strides)
{
return std::any_of(strides.begin(), strides.end(), [](size_t stride) { return stride != 1; });
}
bool ngraph::is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t> valid_ranks)
{
auto node_rank = node->get_shape().size();
for (auto rank : valid_ranks)
{
if (rank == node_rank)
{
return true;
}
}
return false;
}
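
For context, a minimal standalone sketch of the same rank check written with std::find instead of an explicit loop; the free function name and the plain std::vector<size_t> standing in for a node's shape are illustrative only, not the actual ngraph::Node API:

// A minimal sketch, assuming only <algorithm>, <cstddef>, and <vector>;
// "shape" stands in for node->get_shape() and is not the real ngraph API.
#include <algorithm>
#include <cstddef>
#include <vector>

bool is_valid_rank_sketch(const std::vector<std::size_t>& shape,
                          const std::vector<std::size_t>& valid_ranks)
{
    // The rank of a tensor is the number of dimensions in its shape.
    return std::find(valid_ranks.begin(), valid_ranks.end(), shape.size()) !=
           valid_ranks.end();
}

// Usage: a 2x3 matrix has rank 2, so this returns true for valid ranks {2, 3}.
// bool ok = is_valid_rank_sketch({2, 3}, {2, 3});
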
...@@ -315,4 +315,6 @@ namespace ngraph
bool possibly_overwritten(Node* node);
bool is_strided(const Strides& strides);
bool is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t> valid_ranks);
}
...@@ -23,6 +23,7 @@
#include <unordered_map>
#include "cpu_mat_fusion.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
...@@ -147,13 +148,26 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
auto matched_weight = matcher_v2->get_pattern_map()[W]->get_argument(0);
auto matched_data = matcher_v2->get_pattern_map()[input_data];
auto matched_bias = matcher_v2->get_pattern_map()[b]->get_argument(0);
std::vector<size_t> supported_ranks{2, 3};
if (!ngraph::is_valid_rank(matcher_v2->get_match_root(), supported_ranks))
{
NGRAPH_DEBUG << "Add (mat_fusion_v2) " << matcher_v2->get_match_root()->get_name()
             << " isn't 2D or 3D";
continue;
}
if (!ngraph::is_valid_rank(matched_weight, supported_ranks))
{
NGRAPH_DEBUG << "Weights (mat_fusion_v2) " << matched_weight << " isn't 2D or 3D";
continue;
}
if (!ngraph::is_valid_rank(matched_data, supported_ranks))
{
NGRAPH_DEBUG << "Data (mat_fusion_v2) " << matched_data << " isn't 2D or 3D";
continue;
}
map_weights_to_pattern[matched_weight].push_back(matcher_v2->get_match_root());
map_weights_bias_to_data[std::make_pair(matched_weight, matched_bias)].push_back(
matched_data);
...@@ -241,6 +255,7 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
concated_data, data_order, Shape{data_shape[0] * data_shape[1], data_shape[2]});
}
auto new_input_node = data_shape.size() == 2 ? concated_data : input_reshape_node;
NGRAPH_ASSERT(new_input_node);
auto w_reshape_node = std::make_shared<op::Reshape>(
weights, AxisVector{1, 0}, Shape{w_shape[1], w_shape[0]});
auto new_dot = std::make_shared<op::Dot>(new_input_node, w_reshape_node);
......