Unverified Commit 55502f91 authored by adstraw's avatar adstraw Committed by GitHub

creating utilities for topological sort + graph and function clone (#289)

topological sort pass to call the utility function
clone util to be used by frameworks e.g. for fprop cache
parent aa357702
...@@ -27,38 +27,6 @@ using namespace std; ...@@ -27,38 +27,6 @@ using namespace std;
bool ngraph::pass::TopologicalSort::run_on_function(shared_ptr<ngraph::Function> func) bool ngraph::pass::TopologicalSort::run_on_function(shared_ptr<ngraph::Function> func)
{ {
list<shared_ptr<Node>> result_list; func->set_ordered_ops(topological_sort(func->get_ops()));
deque<Node*> independent_nodes;
unordered_map<const Node*, size_t> node_depencency_count;
unordered_map<Node*, shared_ptr<Node>> node_map;
traverse_nodes(func, [&](shared_ptr<Node> node) {
node_map[node.get()] = node;
node_depencency_count[node.get()] = node->get_arguments().size();
if (node->get_arguments().size() == 0)
{
independent_nodes.push_back(node.get());
}
});
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front();
for (auto user : independent_node->users())
{
node_depencency_count[user] -= 1;
size_t count = node_depencency_count[user];
if (count == 0)
{
independent_nodes.push_back(user);
}
}
}
func->set_ordered_ops(result_list);
return false; return false;
} }
...@@ -259,3 +259,108 @@ void ngraph::replace_node_users_arguments(std::shared_ptr<Node> target, ...@@ -259,3 +259,108 @@ void ngraph::replace_node_users_arguments(std::shared_ptr<Node> target,
} }
const_cast<std::multiset<Node*>&>(target->users()).clear(); const_cast<std::multiset<Node*>&>(target->users()).clear();
} }
std::list<std::shared_ptr<ngraph::Node>>
ngraph::topological_sort(const std::list<std::shared_ptr<Node>>& nodes)
{
deque<ngraph::Node*> independent_nodes;
unordered_map<const ngraph::Node*, size_t> node_depencency_count;
unordered_map<ngraph::Node*, shared_ptr<ngraph::Node>> node_map;
for (auto node : nodes)
{
node_map[node.get()] = node;
node_depencency_count[node.get()] = node->get_arguments().size();
if (node->get_arguments().size() == 0)
{
independent_nodes.push_back(node.get());
}
}
list<shared_ptr<ngraph::Node>> result_list;
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front();
for (auto user : independent_node->users())
{
node_depencency_count[user] -= 1;
size_t count = node_depencency_count[user];
if (count == 0)
{
independent_nodes.push_back(user);
}
}
}
return result_list;
}
void ngraph::NodeMap::Add(std::shared_ptr<ngraph::Node> orig,
std::shared_ptr<ngraph::Node> replacement)
{
if (Exists(orig))
{
throw ngraph_error("NodeMap: key already exists");
}
node_map_[orig] = replacement;
}
std::shared_ptr<ngraph::Node> ngraph::NodeMap::operator[](std::shared_ptr<ngraph::Node> orig) const
{
if (!Exists(orig))
{
throw ngraph_error("NodeMap: key does not exist");
}
return node_map_.at(orig);
}
std::list<std::shared_ptr<ngraph::Node>>
ngraph::clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map)
{
// for each node in topological order
auto sorted_nodes = topological_sort(nodes);
for (auto node : sorted_nodes)
{
if (!node_map.Exists(node))
{
// get (already) cloned arguments and clone the node
Nodes cloned_args;
for (auto arg : node->get_arguments())
{
cloned_args.push_back(node_map[arg]);
}
node_map.Add(node, node->copy_with_new_args(cloned_args));
}
}
// create and return list of cloned nodes
// order matches input list (not necessarily topological)
std::list<std::shared_ptr<ngraph::Node>> cloned_nodes;
for (auto node : nodes)
{
cloned_nodes.push_back(node_map[node]);
}
return cloned_nodes;
}
std::shared_ptr<ngraph::Function> ngraph::clone_function(std::shared_ptr<ngraph::Function> func,
NodeMap& node_map)
{
// clone function operations
clone_nodes(func->get_ops(), node_map);
// get cloned function result and parameters
auto cloned_result = node_map[func->get_result()];
std::vector<std::shared_ptr<op::Parameter>> cloned_params;
for (auto param : func->get_parameters())
{
cloned_params.push_back(std::dynamic_pointer_cast<op::Parameter>(node_map[param]));
}
// create and return cloned function
return std::make_shared<ngraph::Function>(
cloned_result, func->get_result_type(), cloned_params);
}
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
namespace ngraph namespace ngraph
...@@ -233,4 +234,42 @@ namespace ngraph ...@@ -233,4 +234,42 @@ namespace ngraph
void replace_node_users_arguments(std::shared_ptr<Node> target, void replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement); std::shared_ptr<Node> replacement);
std::list<std::shared_ptr<Node>>
topological_sort(const std::list<std::shared_ptr<Node>>& nodes);
// maps original to replacement nodes e.g. for clone utilities
// performs index checking on access
class NodeMap
{
public:
// map original node to replcacement node
// throws ngraph_error if key already exists
void Add(std::shared_ptr<ngraph::Node> orig, std::shared_ptr<ngraph::Node> replacement);
// get replacement node from original node
// throws ngrah_error if key does not exist
std::shared_ptr<ngraph::Node> operator[](std::shared_ptr<ngraph::Node> orig) const;
// returns true if original node is already mapped
bool Exists(std::shared_ptr<ngraph::Node> orig) const
{
return (node_map_.count(orig) != 0);
}
private:
std::unordered_map<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>> node_map_;
};
// input nodes are cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned nodes
std::list<std::shared_ptr<ngraph::Node>>
clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map);
// input function is cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned function ops
std::shared_ptr<ngraph::Function> clone_function(std::shared_ptr<ngraph::Function> func,
NodeMap& node_map);
} // end namespace ngraph } // end namespace ngraph
...@@ -241,3 +241,91 @@ TEST(util, traverse_functions) ...@@ -241,3 +241,91 @@ TEST(util, traverse_functions)
traverse_functions(h, [&](shared_ptr<Function> fp) { functions.push_back(fp.get()); }); traverse_functions(h, [&](shared_ptr<Function> fp) { functions.push_back(fp.get()); });
ASSERT_EQ(3, functions.size()); ASSERT_EQ(3, functions.size());
} }
class CloneTest : public ::testing::Test
{
public:
// (A + B) * C
Shape shape = Shape{2, 2};
std::shared_ptr<op::Parameter> A =
make_shared<op::Parameter>(element::Float32::element_type(), shape);
std::shared_ptr<op::Parameter> B =
make_shared<op::Parameter>(element::Float32::element_type(), shape);
std::shared_ptr<op::Parameter> C =
make_shared<op::Parameter>(element::Float32::element_type(), shape);
std::shared_ptr<Node> AplusB = A + B;
std::shared_ptr<Node> AplusBtimesC = AplusB * C;
NodeMap node_map;
std::list<std::shared_ptr<ngraph::Node>> nodes;
std::shared_ptr<TensorViewType> type =
make_shared<TensorViewType>(element::Float32::element_type(), shape);
std::shared_ptr<Function> func =
make_shared<Function>(AplusBtimesC, type, op::Parameters{A, B, C}, "f");
void SetUp()
{
nodes.push_back(AplusBtimesC);
nodes.push_back(AplusB);
nodes.push_back(A);
nodes.push_back(B);
nodes.push_back(C);
}
bool CompareNodes(const std::list<std::shared_ptr<ngraph::Node>>& orig,
const std::list<std::shared_ptr<ngraph::Node>>& clone,
const NodeMap& nm)
{
if (orig.size() != clone.size())
{
return false;
}
auto origit = orig.begin();
auto cloneit = clone.begin();
while (origit != orig.end() && cloneit != clone.end())
{
if (*cloneit != nm[*origit])
{
return false;
}
++origit;
++cloneit;
}
return true;
}
};
TEST_F(CloneTest, clone_nodes_full)
{
auto cloned_nodes = clone_nodes(nodes, node_map);
ASSERT_TRUE(CompareNodes(nodes, cloned_nodes, node_map));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Parameter>(node_map[A]));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Parameter>(node_map[B]));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Parameter>(node_map[C]));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Add>(node_map[AplusB]));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Multiply>(node_map[AplusBtimesC]));
auto sorted_nodes = topological_sort(nodes);
auto sorted_cloned_nodes = topological_sort(cloned_nodes);
ASSERT_TRUE(CompareNodes(sorted_nodes, sorted_cloned_nodes, node_map));
}
TEST_F(CloneTest, clone_nodes_partial)
{
// map A -> A' prior to clone
auto Aprime = make_shared<op::Parameter>(element::Float32::element_type(), shape);
node_map.Add(A, Aprime);
auto cloned_nodes = clone_nodes(nodes, node_map);
ASSERT_TRUE(CompareNodes(nodes, cloned_nodes, node_map));
// ensure A -> A' after clone
ASSERT_EQ(Aprime, node_map[A]);
}
TEST_F(CloneTest, clone_function_full)
{
auto cloned_func = clone_function(func, node_map);
ASSERT_TRUE(CompareNodes(func->get_ops(), cloned_func->get_ops(), node_map));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment