Commit 85f04dfb authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

add check to make sure we don't replace unreachable nodes (#1039)

* add assert to make sure we don't replace unreachable nodes

* fix unittest failures

* sparsity fix
parent 4847b2de
......@@ -112,6 +112,11 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
throw ngraph_error("Result nodes cannot be replaced.");
}
if (target->get_users().size() == 0)
{
throw ngraph_error("replacing an unreachable node");
}
// Fix input/output descriptors
assert(target->get_outputs().size() == replacement->get_outputs().size());
......
......@@ -261,6 +261,13 @@ void pass::CoreFusion::construct_optimized_strided_conv()
NGRAPH_DEBUG << "In a callback for construct_conv_skip against "
<< m.get_match_root()->get_name();
if (m.get_match_root()->get_users().empty())
{
NGRAPH_DEBUG << m.get_match_root()
<< " has already been replaced by a preceding callback";
return false;
}
auto pattern_map = m.get_pattern_map();
auto m_eltwise = pattern_map[eltwise_label];
auto strided_convs = m_eltwise->get_users();
......
......@@ -111,13 +111,16 @@ void ngraph::runtime::cpu::pass::ConcatInputs::concat_lstm_inputs()
{
auto goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(goes->get_node());
lstm_outputs.insert(goes->get_node());
// first output node of lstm
//first output node of lstm
if (goe_node->get_n() == 0)
{
NGRAPH_DEBUG << "Replacing 1st output Lstm node " << goe_node->get_name()
<< " with " << lstm_ht_out->get_name();
if (goe_node->get_users().size() > 0)
{
ngraph::replace_node(goe_node, lstm_ht_out);
}
}
else if (goe_node->get_n() == 1)
{
for (auto& goe_ct_user : goe_node->get_users())
......
......@@ -815,6 +815,10 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
for (auto& rnn_goes : rnn_node->get_users())
{
NGRAPH_DEBUG << "rnn_goes: " << rnn_goes->get_name();
if (rnn_goes->get_users().size() == 0)
{
continue;
}
if (auto rnn_goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(rnn_goes))
{
if (rnn_goe_node->get_n() == 0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment