Commit 323d0b91 authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

check if the nodes are reachable before returning the users of a node (#2069)

* - check if the nodes is reachable before returning the users of a node

* address PR comments

* test case fixes
parent 67c0488b
...@@ -179,9 +179,11 @@ namespace ngraph ...@@ -179,9 +179,11 @@ namespace ngraph
result_list.push_back(node_map[independent_node]); result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front(); independent_nodes.pop_front();
for (auto user_sp : independent_node->get_users()) for (auto& output : independent_node->get_outputs())
{
for (auto& input : output.get_inputs())
{ {
Node* user = user_sp.get(); auto user = input->get_raw_pointer_node();
node_dependency_count[user] -= 1; node_dependency_count[user] -= 1;
size_t count = node_dependency_count[user]; size_t count = node_dependency_count[user];
if (count == 0) if (count == 0)
...@@ -189,6 +191,7 @@ namespace ngraph ...@@ -189,6 +191,7 @@ namespace ngraph
independent_nodes.push_back(user); independent_nodes.push_back(user);
} }
} }
}
if (include_control_deps) if (include_control_deps)
{ {
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "ngraph/autodiff/adjoints.hpp" #include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/descriptor/layout/tensor_layout.hpp" #include "ngraph/descriptor/layout/tensor_layout.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/parameter.hpp" #include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp" #include "ngraph/op/result.hpp"
...@@ -382,17 +383,27 @@ descriptor::Output* Node::get_output_to(const shared_ptr<Node>& dst) ...@@ -382,17 +383,27 @@ descriptor::Output* Node::get_output_to(const shared_ptr<Node>& dst)
throw ngraph_error("Error: dst is not one of self's output Node"); throw ngraph_error("Error: dst is not one of self's output Node");
} }
NodeVector Node::get_users() const NodeVector Node::get_users(bool check_is_used) const
{ {
NodeVector result; NodeVector result;
for (size_t i = 0; i < get_output_size(); ++i) for (size_t i = 0; i < get_output_size(); ++i)
{ {
for (auto input : get_output_inputs(i)) for (auto input : get_output_inputs(i))
{
if (check_is_used)
{
if (is_used(input->get_node().get()))
{
result.push_back(input->get_node());
}
}
else
{ {
result.push_back(input->get_node()); result.push_back(input->get_node());
} }
} }
}
return result; return result;
} }
......
...@@ -245,7 +245,7 @@ namespace ngraph ...@@ -245,7 +245,7 @@ namespace ngraph
descriptor::Output* get_output_to(const std::shared_ptr<Node>& dst); descriptor::Output* get_output_to(const std::shared_ptr<Node>& dst);
/// Get all the nodes that uses the current node /// Get all the nodes that uses the current node
NodeVector get_users() const; NodeVector get_users(bool check_is_used = false) const;
virtual std::shared_ptr<Node> get_default_value() const { return nullptr; } virtual std::shared_ptr<Node> get_default_value() const { return nullptr; }
/// Use instance ids for comparison instead of memory addresses to improve determinism /// Use instance ids for comparison instead of memory addresses to improve determinism
......
...@@ -382,6 +382,7 @@ TEST(graph_util, test_subgraph_topological_sort) ...@@ -382,6 +382,7 @@ TEST(graph_util, test_subgraph_topological_sort)
auto C = make_shared<op::Parameter>(element::f32, shape); auto C = make_shared<op::Parameter>(element::f32, shape);
auto add = A + B; auto add = A + B;
auto mul = C * add; auto mul = C * add;
auto result = make_shared<op::Result>(mul);
auto sorted = ngraph::subgraph_topological_sort(NodeVector{mul, add, A}); auto sorted = ngraph::subgraph_topological_sort(NodeVector{mul, add, A});
std::list<std::shared_ptr<Node>> expected{A, add, mul}; std::list<std::shared_ptr<Node>> expected{A, add, mul};
ASSERT_EQ(expected, sorted); ASSERT_EQ(expected, sorted);
...@@ -399,6 +400,7 @@ TEST(graph_util, test_subgraph_topological_sort_control_dependencies) ...@@ -399,6 +400,7 @@ TEST(graph_util, test_subgraph_topological_sort_control_dependencies)
add->add_control_dependency(D); add->add_control_dependency(D);
add->add_control_dependency(E); add->add_control_dependency(E);
auto mul = C * add; auto mul = C * add;
auto result = make_shared<op::Result>(mul);
auto sorted = ngraph::subgraph_topological_sort(NodeVector{mul, add, A, D}, true); auto sorted = ngraph::subgraph_topological_sort(NodeVector{mul, add, A, D}, true);
std::list<std::shared_ptr<Node>> expected{A, D, add, mul}; std::list<std::shared_ptr<Node>> expected{A, D, add, mul};
ASSERT_EQ(expected, sorted); ASSERT_EQ(expected, sorted);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment