Unverified Commit c488f12b authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Cyphers/topmaster (#3199)

* Stabilize node sorting and fix some bugs.

* Review comments

* Fix broken tests

* Implement traverse nodes with pointers

* Let sort gather nodes for get_ordered_ops

* Use stacks for stacks

* Keep control deps ordered

* Optimize subgraph sort

* Add unordered map over function ops

* Don't recheck children

* Use vectors in stacks, avoid std::list::size()
parent 8e6d8a99
......@@ -14,8 +14,10 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/descriptor/output.hpp"
#include <algorithm>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/node.hpp"
using namespace std;
......@@ -31,12 +33,20 @@ descriptor::Output::Output(Node* node, size_t index, const shared_ptr<Tensor>& t
// Add an input to the vector of inputs that use this output.
void descriptor::Output::add_input(Input* input)
{
m_inputs.insert(input);
// Keep the inputs in insertion order to keep sorts deterministic
if (find(m_inputs.begin(), m_inputs.end(), input) == m_inputs.end())
{
m_inputs.push_back(input);
}
}
void descriptor::Output::remove_input(Input* input)
{
m_inputs.erase(input);
auto it = find(m_inputs.begin(), m_inputs.end(), input);
if (it != m_inputs.end())
{
m_inputs.erase(it);
}
}
shared_ptr<Node> descriptor::Output::get_node() const
......
......@@ -17,7 +17,7 @@
#pragma once
#include <memory>
#include <set>
#include <vector>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/tensor.hpp"
......@@ -48,7 +48,7 @@ namespace ngraph
void set_tensor_ptr(const std::shared_ptr<Tensor>& tensor) { m_tensor = tensor; }
void add_input(Input* input);
void remove_input(Input* input);
const std::set<Input*>& get_inputs() const { return m_inputs; }
const std::vector<Input*>& get_inputs() const { return m_inputs; }
Tensor& get_tensor() const;
/// \return the shape of the output
......@@ -64,7 +64,7 @@ namespace ngraph
Node* m_node;
size_t m_index;
std::shared_ptr<Tensor> m_tensor;
std::set<Input*> m_inputs;
std::vector<Input*> m_inputs;
private:
Output(const Output&) = delete;
......
......@@ -97,7 +97,48 @@ void Function::init()
std::list<shared_ptr<Node>> Function::get_ordered_ops(bool include_control_deps) const
{
return topological_sort(get_ops(include_control_deps), include_control_deps);
NodeVector nodes;
for (auto& r : get_results())
{
nodes.push_back(r);
}
for (auto& param : get_parameters())
{
nodes.push_back(param);
}
return topological_sort(nodes, include_control_deps);
}
/// Applies \p f exactly once to every node reachable from this function's
/// results and parameters by following node inputs and control dependencies.
/// Nodes are handled in the order they come off a LIFO work stack, so the
/// visiting order is not a topological order.
/// \param f callback invoked with each distinct node
void Function::map_unordered_ops(std::function<void(Node*)> f) const
{
    std::unordered_set<Node*> visited;
    std::stack<Node*, std::vector<Node*>> pending;

    // Seed the work stack: results first, then parameters.
    for (const auto& result : get_results())
    {
        pending.push(result.get());
    }
    for (const auto& parameter : get_parameters())
    {
        pending.push(parameter.get());
    }

    while (!pending.empty())
    {
        Node* current = pending.top();
        pending.pop();

        // insert() reports whether the node is new; skip anything already handled.
        if (!visited.insert(current).second)
        {
            continue;
        }

        f(current);

        // Queue the producers of every input, then the control dependencies.
        const size_t input_count = current->get_input_size();
        for (size_t idx = 0; idx < input_count; ++idx)
        {
            pending.push(current->input(idx).get_source_output().get_node());
        }
        for (const auto& control_dep : current->get_control_dependencies())
        {
            pending.push(control_dep.get());
        }
    }
}
const std::string& Function::get_friendly_name() const
......
......@@ -88,6 +88,8 @@ namespace ngraph
std::list<std::shared_ptr<Node>> get_ops(bool include_control_deps = true) const;
std::list<std::shared_ptr<Node>> get_ordered_ops(bool include_control_deps = true) const;
/// Calls \p f once on each node reachable from the function's results and
/// parameters through node inputs and control dependencies. Nodes are
/// visited in work-stack order, not in topologically sorted order.
void map_unordered_ops(std::function<void(Node*)> f) const;
friend std::ostream& operator<<(std::ostream&, const Function&);
size_t get_instance_id() { return m_instance_id; }
size_t get_temporary_pool_size();
......
......@@ -14,7 +14,6 @@
// limitations under the License.
//*****************************************************************************
#include <deque>
#include <unordered_map>
#include <unordered_set>
#include <vector>
......@@ -69,42 +68,34 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results,
bool include_control_deps,
const NodeVector& subgraph_params)
{
std::unordered_set<std::shared_ptr<Node>> instances_seen{subgraph_params.begin(),
subgraph_params.end()};
std::deque<std::shared_ptr<Node>> stack;
for (auto r : subgraph_results)
std::unordered_set<Node*> instances_seen;
std::stack<Node*, std::vector<Node*>> stack;
for (auto& node_ptr : subgraph_params)
{
if (instances_seen.count(r) == 0)
{
stack.push_front(r);
}
instances_seen.insert(node_ptr.get());
}
for (auto& node_ptr : subgraph_results)
{
stack.push(node_ptr.get());
}
while (stack.size() > 0)
{
std::shared_ptr<Node> n = stack.front();
if (instances_seen.count(n) == 0)
{
instances_seen.insert(n);
f(n);
}
stack.pop_front();
for (auto arg : n->get_arguments())
Node* n = stack.top();
stack.pop();
if (instances_seen.insert(n).second)
{
if (instances_seen.count(arg) == 0)
f(n->shared_from_this());
for (auto& arg : n->get_arguments())
{
stack.push_front(arg);
stack.push(arg.get());
}
}
if (include_control_deps)
{
for (auto cdep : n->get_control_dependencies())
if (include_control_deps)
{
if (instances_seen.count(cdep) == 0)
for (auto& cdep : n->get_control_dependencies())
{
stack.push_front(cdep);
stack.push(cdep.get());
}
}
}
......@@ -182,25 +173,25 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
bool ngraph::is_post_dominated(Node* X, Node* Y)
{
std::unordered_set<Node*> visited;
std::deque<Node*> stack;
stack.push_front(X);
std::stack<Node*, std::vector<Node*>> stack;
stack.push(X);
while (stack.size() > 0)
{
ngraph::Node* curr = stack.front();
ngraph::Node* curr = stack.top();
visited.insert(curr);
if (curr->is_output())
{
return false;
}
stack.pop_front();
stack.pop();
if (curr != Y)
{
for (const auto& next : curr->get_users())
{
if (visited.count(next.get()) == 0)
{
stack.push_front(next.get());
stack.push(next.get());
}
}
}
......@@ -503,12 +494,12 @@ NodeVector ngraph::extract_subgraph(const NodeVector& results, const NodeVector&
bool ngraph::is_used(Node* node)
{
std::unordered_set<Node*> instances_seen;
std::deque<Node*> stack;
stack.push_front(node);
std::stack<Node*, std::vector<Node*>> stack;
stack.push(node);
while (stack.size() > 0)
{
ngraph::Node* n = stack.front();
ngraph::Node* n = stack.top();
if (instances_seen.count(n) == 0)
{
if (n->is_output())
......@@ -517,12 +508,12 @@ bool ngraph::is_used(Node* node)
}
instances_seen.insert(n);
}
stack.pop_front();
stack.pop();
for (const auto& arg : n->get_users())
{
if (instances_seen.count(arg.get()) == 0)
{
stack.push_front(arg.get());
stack.push(arg.get());
}
}
}
......
......@@ -20,6 +20,7 @@
#include <functional>
#include <list>
#include <memory>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
......@@ -81,154 +82,131 @@ namespace ngraph
NodeVector find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
/// Topological sort of nodes needed to compute root_nodes
template <typename T>
std::list<std::shared_ptr<Node>> topological_sort(const T& nodes,
std::list<std::shared_ptr<Node>> topological_sort(T root_nodes,
bool include_control_deps = false)
{
std::deque<ngraph::Node*> independent_nodes;
std::unordered_map<const ngraph::Node*, size_t> node_dependency_count;
std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>> node_map;
std::unordered_map<ngraph::Node*, std::set<Node*>> control_deps_users;
std::stack<Node*, std::vector<Node*>> nodes_to_do;
std::unordered_set<Node*> nodes_done;
std::list<std::shared_ptr<Node>> result;
for (auto node : nodes)
for (auto& node : root_nodes)
{
//build an equivalent of node->get_users() but for control dependencies
size_t control_deps_count = 0;
if (include_control_deps)
{
for (auto cd : node->get_control_dependencies())
{
control_deps_count++;
control_deps_users[cd.get()].insert(node.get());
}
}
node_map[node.get()] = node;
size_t deps_count = node->get_input_size() + control_deps_count;
node_dependency_count[node.get()] = deps_count;
if (deps_count == 0)
{
independent_nodes.push_back(node.get());
}
nodes_to_do.push(node.get());
}
std::list<std::shared_ptr<ngraph::Node>> result_list;
while (independent_nodes.size() > 0)
while (nodes_to_do.size() > 0)
{
auto independent_node = independent_nodes.front();
result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front();
for (const std::shared_ptr<Node>& user : independent_node->get_users())
Node* node = nodes_to_do.top();
if (nodes_done.count(node) == 0)
{
if (--node_dependency_count[user.get()] == 0)
bool can_add = true;
size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i)
{
independent_nodes.push_back(user.get());
Node* dep = node->input(arg_count - i - 1).get_source_output().get_node();
if (nodes_done.count(dep) == 0)
{
can_add = false;
nodes_to_do.push(dep);
}
}
}
if (include_control_deps)
{
auto cdit = control_deps_users.find(independent_node);
if (cdit != control_deps_users.end())
for (auto cd_user : cdit->second)
if (include_control_deps)
{
for (auto& depptr : node->get_control_dependencies())
{
node_dependency_count[cd_user] -= 1;
size_t count = node_dependency_count[cd_user];
if (count == 0)
Node* dep = depptr.get();
if (nodes_done.count(dep) == 0)
{
independent_nodes.push_back(cd_user);
can_add = false;
nodes_to_do.push(dep);
}
}
}
if (can_add)
{
result.push_back(node->shared_from_this());
nodes_to_do.pop();
nodes_done.insert(node);
}
}
else
{
nodes_to_do.pop();
}
}
NGRAPH_CHECK(nodes.size() == result_list.size());
return result_list;
return result;
}
// For cases where `nodes` is a subset of the entire graph
/// Topological sort of just nodes
template <typename T>
std::list<std::shared_ptr<Node>> subgraph_topological_sort(const T& nodes,
std::list<std::shared_ptr<Node>> subgraph_topological_sort(T nodes,
bool include_control_deps = false)
{
std::deque<ngraph::Node*> independent_nodes;
std::unordered_map<const ngraph::Node*, size_t> node_dependency_count;
std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>> node_map;
std::unordered_map<ngraph::Node*, std::set<Node*>> control_deps_users;
std::unordered_set<std::shared_ptr<ngraph::Node>> nodes_set(nodes.begin(), nodes.end());
std::stack<Node*, std::vector<Node*>> nodes_to_do;
std::unordered_set<Node*> nodes_done;
std::unordered_set<Node*> nodes_to_emit;
std::list<std::shared_ptr<Node>> result;
for (auto node : nodes)
for (auto& node : nodes)
{
nodes_to_emit.insert(node.get());
nodes_to_do.push(node.get());
}
// NB: Some centos versions implement std::list::size() by counting elements
size_t nodes_remaining = nodes_to_emit.size();
while (nodes_to_do.size() > 0 && nodes_remaining > 0)
{
//build an equivalent of node->get_users() but for control dependencies
size_t deps_count = 0;
if (include_control_deps)
Node* node = nodes_to_do.top();
if (nodes_done.count(node) == 0)
{
for (auto cd : node->get_control_dependencies())
bool can_add = true;
size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i)
{
if (nodes_set.count(cd) != 0)
Node* dep = node->input(arg_count - i - 1).get_source_output().get_node();
if (nodes_done.count(dep) == 0)
{
control_deps_users[cd.get()].insert(node.get());
deps_count++;
can_add = false;
nodes_to_do.push(dep);
}
}
}
node_map[node.get()] = node;
for (auto arg : node->get_arguments())
{
if (nodes_set.count(arg) != 0)
if (include_control_deps)
{
deps_count++;
for (auto& depptr : node->get_control_dependencies())
{
Node* dep = depptr.get();
if (nodes_done.count(dep) == 0)
{
can_add = false;
nodes_to_do.push(dep);
}
}
}
}
node_dependency_count[node.get()] = deps_count;
if (deps_count == 0)
{
independent_nodes.push_back(node.get());
}
}
std::list<std::shared_ptr<ngraph::Node>> result_list;
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front();
for (const std::shared_ptr<Node>& user : independent_node->get_users())
{
if (--node_dependency_count[user.get()] == 0)
if (can_add)
{
independent_nodes.push_back(user.get());
if (nodes_to_emit.count(node) != 0)
{
result.push_back(node->shared_from_this());
nodes_remaining--;
}
nodes_to_do.pop();
nodes_done.insert(node);
}
}
if (include_control_deps)
else
{
auto cdit = control_deps_users.find(independent_node);
if (cdit != control_deps_users.end())
for (auto cd_user : cdit->second)
{
node_dependency_count[cd_user] -= 1;
size_t count = node_dependency_count[cd_user];
if (count == 0)
{
independent_nodes.push_back(cd_user);
}
}
nodes_to_do.pop();
}
}
NGRAPH_CHECK(nodes.size() == result_list.size());
return result_list;
return result;
}
template <typename T>
void validate_nodes_and_infer_types(const T& nodes)
{
for (auto node : subgraph_topological_sort(nodes))
for (auto& node : subgraph_topological_sort(nodes))
{
node->revalidate_and_infer_types();
}
......
......@@ -329,14 +329,18 @@ NodeVector Node::get_arguments() const
return result;
}
const std::set<std::shared_ptr<Node>>& Node::get_control_dependencies() const
const std::vector<std::shared_ptr<Node>>& Node::get_control_dependencies() const
{
return m_control_dependencies;
}
void Node::add_control_dependency(std::shared_ptr<Node> node)
{
m_control_dependencies.insert(node);
if (find(m_control_dependencies.begin(), m_control_dependencies.end(), node) ==
m_control_dependencies.end())
{
m_control_dependencies.push_back(node);
}
}
namespace ngraph
......@@ -444,7 +448,7 @@ shared_ptr<descriptor::Tensor> Node::get_output_tensor_ptr() const
return m_outputs.at(0).get_tensor_ptr();
}
const std::set<descriptor::Input*>& Node::get_output_inputs(size_t i) const
const std::vector<descriptor::Input*>& Node::get_output_inputs(size_t i) const
{
return m_outputs.at(i).get_inputs();
}
......
......@@ -236,13 +236,17 @@ namespace ngraph
NGRAPH_DEPRECATED("use outputs() instead");
/// Get control dependencies registered on the node
const std::set<std::shared_ptr<Node>>& get_control_dependencies() const;
const std::vector<std::shared_ptr<Node>>& get_control_dependencies() const;
void add_control_dependency(std::shared_ptr<Node> node);
void remove_control_dependency(std::shared_ptr<Node> node)
{
m_control_dependencies.erase(node);
auto it = find(m_control_dependencies.begin(), m_control_dependencies.end(), node);
if (it != m_control_dependencies.end())
{
m_control_dependencies.erase(it);
}
}
/// Returns the number of outputs from the node.
......@@ -295,7 +299,7 @@ namespace ngraph
"output, or update calling code not to assume only one output");
/// Returns the set of inputs using output i
const std::set<descriptor::Input*>& get_output_inputs(size_t i) const
const std::vector<descriptor::Input*>& get_output_inputs(size_t i) const
NGRAPH_DEPRECATED("use node->output(i).get_target_inputs() instead");
/// Returns the number of inputs for the op
......@@ -390,7 +394,7 @@ namespace ngraph
descriptor::Input& get_input_descriptor(size_t position);
descriptor::Output& get_output_descriptor(size_t position);
std::set<std::shared_ptr<Node>> m_control_dependencies;
std::vector<std::shared_ptr<Node>> m_control_dependencies;
const std::string m_node_type;
size_t m_instance_id{m_next_instance_id.fetch_add(1)};
std::string m_friendly_name;
......
......@@ -87,8 +87,7 @@ TEST(control_dependencies, cdep_ops)
make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn});
auto f = make_shared<Function>(cdop, ParameterVector{A, B});
auto nodes = f->get_ordered_ops(true);
ASSERT_EQ(nodes.back()->get_argument(0), cdop);
test_ordered_ops(f, NodeVector{absn});
}
TEST(control_dependencies, two_cdep_ops)
......@@ -102,8 +101,7 @@ TEST(control_dependencies, two_cdep_ops)
std::set<std::shared_ptr<Node>>{absn, absn_c});
auto f = make_shared<Function>(cdop, ParameterVector{A, B, C});
auto nodes = f->get_ordered_ops(true);
ASSERT_EQ(nodes.back()->get_argument(0), cdop);
test_ordered_ops(f, NodeVector{absn, absn_c});
}
TEST(control_dependencies, two_cdep_ops_op_on_top)
......@@ -117,8 +115,7 @@ TEST(control_dependencies, two_cdep_ops_op_on_top)
auto absn_cdop = make_shared<op::Abs>(cdop);
auto f = make_shared<Function>(absn_cdop, ParameterVector{A, B});
auto nodes = f->get_ordered_ops(true);
ASSERT_EQ(nodes.back()->get_argument(0), absn_cdop);
test_ordered_ops(f, NodeVector{absn, absn_b});
}
TEST(control_dependencies, clone_function_cdop)
......@@ -129,6 +126,7 @@ TEST(control_dependencies, clone_function_cdop)
make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn});
auto f = make_shared<Function>(cdop, ParameterVector{A});
test_ordered_ops(f, NodeVector{absn});
auto clone = ngraph::clone_function(*f.get());
auto matcher = std::make_shared<pattern::Matcher>(cdop);
auto cdop_clone = clone->get_results().at(0)->get_argument(0);
......
......@@ -540,7 +540,6 @@ TEST(util, enum_mask_operators)
TEST(graph, huge)
{
Function* f;
std::vector<std::weak_ptr<Node>> weak_nodes;
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 3});
......@@ -549,16 +548,12 @@ TEST(graph, huge)
{
n = make_shared<op::Negative>(n);
}
f = new Function(NodeVector{n}, ParameterVector{param});
for (auto node : f->get_ops())
{
weak_nodes.push_back(node);
}
auto f = make_shared<Function>(NodeVector{n}, ParameterVector{param});
f->map_unordered_ops(
[&weak_nodes](Node* node) { weak_nodes.push_back(node->shared_from_this()); });
}
delete f;
for (auto weak_node : weak_nodes)
for (auto& weak_node : weak_nodes)
{
EXPECT_TRUE(weak_node.expired());
}
......
......@@ -313,3 +313,44 @@ std::shared_ptr<Function> make_function_from_file(const std::string& file_name)
return func;
}
#endif
/// Verifies that f->get_ordered_ops() yields a valid ordering.
///
/// The check fails if
///  - any node appears more than once in the ordered ops,
///  - the source node of any input of an op has not appeared before that op,
///  - any control dependency of an op has not appeared before that op, or
///  - any node in \p required_ops is absent from the ordered ops.
///
/// \param f the function whose ordered ops are checked
/// \param required_ops nodes that must be present in the ordered ops
/// \return AssertionSuccess() when all checks pass, otherwise an
///         AssertionFailure() carrying a message describing the first violation
::testing::AssertionResult test_ordered_ops(shared_ptr<Function> f, const NodeVector& required_ops)
{
    // Nodes already encountered while walking the ordered ops.
    unordered_set<Node*> seen;
    for (auto& node_ptr : f->get_ordered_ops())
    {
        Node* node = node_ptr.get();
        if (seen.count(node) > 0)
        {
            return ::testing::AssertionFailure() << "Duplication in ordered ops";
        }
        // Every data input must be produced by a node that was emitted earlier.
        size_t arg_count = node->get_input_size();
        for (size_t i = 0; i < arg_count; ++i)
        {
            Node* dep = node->input(i).get_source_output().get_node();
            if (seen.count(dep) == 0)
            {
                return ::testing::AssertionFailure() << "Argument " << *dep
                                                     << " does not occur before op " << *node;
            }
        }
        // Control dependencies must precede the node as well.
        for (auto& dep_ptr : node->get_control_dependencies())
        {
            if (seen.count(dep_ptr.get()) == 0)
            {
                return ::testing::AssertionFailure() << "Control dependency " << *dep_ptr
                                                     << " does not occur before op " << *node;
            }
        }
        // Mark the node only after its dependencies were checked.
        seen.insert(node);
    }
    for (auto& node_ptr : required_ops)
    {
        if (seen.count(node_ptr.get()) == 0)
        {
            return ::testing::AssertionFailure() << "Required op " << *node_ptr
                                                 << " does not occur in ordered ops";
        }
    }
    return ::testing::AssertionSuccess();
}
......@@ -25,6 +25,8 @@
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/descriptor/layout/tensor_layout.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/log.hpp"
......@@ -277,3 +279,6 @@ std::vector<T> read_binary_file(const std::string& path)
inputs_fs.read(reinterpret_cast<char*>(file_content.data()), size);
return file_content;
}
/// Checks the result of f->get_ordered_ops(): no node may appear twice, each
/// op's input source nodes and control dependencies must appear before the op,
/// and every node in required_ops must be present.
/// \return an assertion success, or an assertion failure describing the first
///         violation found
testing::AssertionResult test_ordered_ops(std::shared_ptr<ngraph::Function> f,
const ngraph::NodeVector& required_ops);
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment