Commit 85f04dfb authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

add check to make sure we don't replace unreachable nodes (#1039)

* add assert to make sure we don't replace unreachable nodes

* fix unittest failures

* sparsity fix
parent 4847b2de
...@@ -112,6 +112,11 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re ...@@ -112,6 +112,11 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
throw ngraph_error("Result nodes cannot be replaced."); throw ngraph_error("Result nodes cannot be replaced.");
} }
if (target->get_users().size() == 0)
{
throw ngraph_error("replacing an unreachable node");
}
// Fix input/output descriptors // Fix input/output descriptors
assert(target->get_outputs().size() == replacement->get_outputs().size()); assert(target->get_outputs().size() == replacement->get_outputs().size());
......
...@@ -261,6 +261,13 @@ void pass::CoreFusion::construct_optimized_strided_conv() ...@@ -261,6 +261,13 @@ void pass::CoreFusion::construct_optimized_strided_conv()
NGRAPH_DEBUG << "In a callback for construct_conv_skip against " NGRAPH_DEBUG << "In a callback for construct_conv_skip against "
<< m.get_match_root()->get_name(); << m.get_match_root()->get_name();
if (m.get_match_root()->get_users().empty())
{
NGRAPH_DEBUG << m.get_match_root()
<< " has already been replaced by a preceding callback";
return false;
}
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto m_eltwise = pattern_map[eltwise_label]; auto m_eltwise = pattern_map[eltwise_label];
auto strided_convs = m_eltwise->get_users(); auto strided_convs = m_eltwise->get_users();
......
...@@ -111,12 +111,15 @@ void ngraph::runtime::cpu::pass::ConcatInputs::concat_lstm_inputs() ...@@ -111,12 +111,15 @@ void ngraph::runtime::cpu::pass::ConcatInputs::concat_lstm_inputs()
{ {
auto goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(goes->get_node()); auto goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(goes->get_node());
lstm_outputs.insert(goes->get_node()); lstm_outputs.insert(goes->get_node());
// first output node of lstm //first output node of lstm
if (goe_node->get_n() == 0) if (goe_node->get_n() == 0)
{ {
NGRAPH_DEBUG << "Replacing 1st output Lstm node " << goe_node->get_name() NGRAPH_DEBUG << "Replacing 1st output Lstm node " << goe_node->get_name()
<< " with " << lstm_ht_out->get_name(); << " with " << lstm_ht_out->get_name();
ngraph::replace_node(goe_node, lstm_ht_out); if (goe_node->get_users().size() > 0)
{
ngraph::replace_node(goe_node, lstm_ht_out);
}
} }
else if (goe_node->get_n() == 1) else if (goe_node->get_n() == 1)
{ {
......
...@@ -815,6 +815,10 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_ ...@@ -815,6 +815,10 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
for (auto& rnn_goes : rnn_node->get_users()) for (auto& rnn_goes : rnn_node->get_users())
{ {
NGRAPH_DEBUG << "rnn_goes: " << rnn_goes->get_name(); NGRAPH_DEBUG << "rnn_goes: " << rnn_goes->get_name();
if (rnn_goes->get_users().size() == 0)
{
continue;
}
if (auto rnn_goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(rnn_goes)) if (auto rnn_goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(rnn_goes))
{ {
if (rnn_goe_node->get_n() == 0) if (rnn_goe_node->get_n() == 0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment