Commit 293ba8b7 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

subgraph_topological_sort (#1608)

parent e20876db
......@@ -127,6 +127,86 @@ namespace ngraph
return result_list;
}
template <typename T>
std::list<std::shared_ptr<Node>> subgraph_topological_sort(const T& nodes,
bool include_control_deps = false)
{
std::deque<ngraph::Node*> independent_nodes;
std::unordered_map<const ngraph::Node*, size_t> node_dependency_count;
std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>> node_map;
std::unordered_map<ngraph::Node*, std::set<Node*>> control_deps_users;
std::unordered_set<std::shared_ptr<ngraph::Node>> nodes_set(nodes.begin(), nodes.end());
for (auto node : nodes)
{
//build an equivalent of node->get_users() but for control dependencies
size_t deps_count = 0;
if (include_control_deps)
{
for (auto cd : node->get_control_dependencies())
{
if (nodes_set.count(cd) != 0)
{
control_deps_users[cd.get()].insert(node.get());
deps_count++;
}
}
}
node_map[node.get()] = node;
for (auto arg : node->get_arguments())
{
if (nodes_set.count(arg) != 0)
{
deps_count++;
}
}
node_dependency_count[node.get()] = deps_count;
if (deps_count == 0)
{
independent_nodes.push_back(node.get());
}
}
std::list<std::shared_ptr<ngraph::Node>> result_list;
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front();
for (auto user_sp : independent_node->get_users())
{
Node* user = user_sp.get();
node_dependency_count[user] -= 1;
size_t count = node_dependency_count[user];
if (count == 0)
{
independent_nodes.push_back(user);
}
}
if (include_control_deps)
{
auto cdit = control_deps_users.find(independent_node);
if (cdit != control_deps_users.end())
for (auto cd_user : cdit->second)
{
node_dependency_count[cd_user] -= 1;
size_t count = node_dependency_count[cd_user];
if (count == 0)
{
independent_nodes.push_back(cd_user);
}
}
}
}
NGRAPH_ASSERT(nodes.size() == result_list.size());
return result_list;
}
template <typename T>
void validate_nodes_and_infer_types(const T& nodes)
{
......
......@@ -384,3 +384,33 @@ TEST(util, test_fprop_cache)
EXPECT_EQ(fprop_cache.fprop->get_results().size(), 2);
EXPECT_EQ(fprop_cache.bprop->get_parameters().size(), 5);
}
TEST(graph_util, test_subgraph_topological_sort)
{
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto add = A + B;
auto mul = C * add;
auto sorted = ngraph::subgraph_topological_sort(NodeVector{mul, add, A});
std::list<std::shared_ptr<Node>> expected{A, add, mul};
ASSERT_EQ(expected, sorted);
}
TEST(graph_util, test_subgraph_topological_sort_control_dependencies)
{
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto D = make_shared<op::Abs>(A);
auto E = make_shared<op::Abs>(B);
auto add = A + B;
add->add_control_dependency(D);
add->add_control_dependency(E);
auto mul = C * add;
auto sorted = ngraph::subgraph_topological_sort(NodeVector{mul, add, A, D}, true);
std::list<std::shared_ptr<Node>> expected{A, D, add, mul};
ASSERT_EQ(expected, sorted);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment