Commit 323d0b91 authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

check if the nodes are reachable before returning the users of a node (#2069)

* - check if the nodes is reachable before returning the users of a node

* address PR comments

* test case fixes
parent 67c0488b
......@@ -179,14 +179,17 @@ namespace ngraph
result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front();
for (auto user_sp : independent_node->get_users())
for (auto& output : independent_node->get_outputs())
{
Node* user = user_sp.get();
node_dependency_count[user] -= 1;
size_t count = node_dependency_count[user];
if (count == 0)
for (auto& input : output.get_inputs())
{
independent_nodes.push_back(user);
auto user = input->get_raw_pointer_node();
node_dependency_count[user] -= 1;
size_t count = node_dependency_count[user];
if (count == 0)
{
independent_nodes.push_back(user);
}
}
}
......
......@@ -21,6 +21,7 @@
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/descriptor/layout/tensor_layout.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
......@@ -382,7 +383,7 @@ descriptor::Output* Node::get_output_to(const shared_ptr<Node>& dst)
throw ngraph_error("Error: dst is not one of self's output Node");
}
NodeVector Node::get_users() const
NodeVector Node::get_users(bool check_is_used) const
{
NodeVector result;
......@@ -390,7 +391,17 @@ NodeVector Node::get_users() const
{
for (auto input : get_output_inputs(i))
{
result.push_back(input->get_node());
if (check_is_used)
{
if (is_used(input->get_node().get()))
{
result.push_back(input->get_node());
}
}
else
{
result.push_back(input->get_node());
}
}
}
......
......@@ -245,7 +245,7 @@ namespace ngraph
descriptor::Output* get_output_to(const std::shared_ptr<Node>& dst);
/// Get all the nodes that uses the current node
NodeVector get_users() const;
NodeVector get_users(bool check_is_used = false) const;
virtual std::shared_ptr<Node> get_default_value() const { return nullptr; }
/// Use instance ids for comparison instead of memory addresses to improve determinism
......
......@@ -382,6 +382,7 @@ TEST(graph_util, test_subgraph_topological_sort)
auto C = make_shared<op::Parameter>(element::f32, shape);
auto add = A + B;
auto mul = C * add;
auto result = make_shared<op::Result>(mul);
auto sorted = ngraph::subgraph_topological_sort(NodeVector{mul, add, A});
std::list<std::shared_ptr<Node>> expected{A, add, mul};
ASSERT_EQ(expected, sorted);
......@@ -399,6 +400,7 @@ TEST(graph_util, test_subgraph_topological_sort_control_dependencies)
add->add_control_dependency(D);
add->add_control_dependency(E);
auto mul = C * add;
auto result = make_shared<op::Result>(mul);
auto sorted = ngraph::subgraph_topological_sort(NodeVector{mul, add, A, D}, true);
std::list<std::shared_ptr<Node>> expected{A, D, add, mul};
ASSERT_EQ(expected, sorted);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment