Unverified Commit fd7ee58c authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Always use control deps in sorting (#4315)

* Always use control deps in sorting

* Fix build

* Deprecate old topological sort
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 2c7dd226
......@@ -231,7 +231,6 @@ namespace ngraph
ngraph::traverse_nodes(
ng_node_vector,
[&tag](std::shared_ptr<ngraph::Node> ng_node) { ng_node->add_provenance_tag(tag); },
false,
ng_inputs);
}
} // namespace onnx_import
......
......@@ -88,21 +88,19 @@ void Function::init()
{
validate_nodes_and_infer_types();
traverse_nodes(this,
[&](shared_ptr<Node> node) {
if (node->is_parameter())
{
auto it = std::find(m_parameters.begin(), m_parameters.end(), node);
if (it == m_parameters.end())
{
throw ngraph_error("Function references undeclared parameter");
}
}
},
true /*include control dependencies*/);
traverse_nodes(this, [&](shared_ptr<Node> node) {
if (node->is_parameter())
{
auto it = std::find(m_parameters.begin(), m_parameters.end(), node);
if (it == m_parameters.end())
{
throw ngraph_error("Function references undeclared parameter");
}
}
});
}
std::vector<shared_ptr<Node>> Function::get_ordered_ops(bool include_control_deps) const
std::vector<shared_ptr<Node>> Function::get_ordered_ops() const
{
vector<shared_ptr<Node>> nodes;
for (auto& r : get_results())
......@@ -114,7 +112,7 @@ std::vector<shared_ptr<Node>> Function::get_ordered_ops(bool include_control_dep
nodes.push_back(param);
}
return m_topological_sorter(nodes, include_control_deps);
return m_topological_sorter(nodes);
}
void Function::map_unordered_ops(std::function<void(Node*)> f) const
......@@ -229,10 +227,10 @@ shared_ptr<Node> Function::get_result() const
return m_results.at(0);
}
std::vector<shared_ptr<Node>> Function::get_ops(bool include_control_deps) const
std::vector<shared_ptr<Node>> Function::get_ops() const
{
std::vector<std::shared_ptr<Node>> ops;
traverse_nodes(this, [&](shared_ptr<Node> node) { ops.push_back(node); }, include_control_deps);
traverse_nodes(this, [&](shared_ptr<Node> node) { ops.push_back(node); });
return ops;
}
......
......@@ -92,8 +92,8 @@ namespace ngraph
/// \returns A const reference to the function's friendly name.
const std::string& get_friendly_name() const;
std::vector<std::shared_ptr<Node>> get_ops(bool include_control_deps = true) const;
std::vector<std::shared_ptr<Node>> get_ordered_ops(bool include_control_deps = true) const;
std::vector<std::shared_ptr<Node>> get_ops() const;
std::vector<std::shared_ptr<Node>> get_ordered_ops() const;
void map_unordered_ops(std::function<void(Node*)> f) const;
friend std::ostream& operator<<(std::ostream&, const Function&);
......@@ -127,7 +127,7 @@ namespace ngraph
const std::shared_ptr<op::Parameter>& parameter);
using topological_sort_t = std::function<std::vector<std::shared_ptr<Node>>(
const std::vector<std::shared_ptr<Node>>& root_nodes, bool include_control_deps)>;
const std::vector<std::shared_ptr<Node>>& root_nodes)>;
void set_topological_sort(topological_sort_t);
protected:
......
......@@ -38,15 +38,12 @@ using namespace std;
using namespace ngraph;
void ngraph::traverse_nodes(const std::shared_ptr<const Function> p,
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps)
std::function<void(std::shared_ptr<Node>)> f)
{
traverse_nodes(p.get(), f, include_control_deps);
traverse_nodes(p.get(), f);
}
void ngraph::traverse_nodes(const Function* p,
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps)
void ngraph::traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f)
{
NodeVector nodes;
......@@ -60,12 +57,11 @@ void ngraph::traverse_nodes(const Function* p,
nodes.push_back(param);
}
traverse_nodes(nodes, f, include_control_deps);
traverse_nodes(nodes, f);
}
void ngraph::traverse_nodes(const NodeVector& subgraph_results,
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps,
const NodeVector& subgraph_params)
{
std::unordered_set<Node*> instances_seen;
......@@ -91,17 +87,22 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results,
stack.push(n->get_input_node_ptr(i));
}
if (include_control_deps)
for (auto& cdep : n->get_control_dependencies())
{
for (auto& cdep : n->get_control_dependencies())
{
stack.push(cdep.get());
}
stack.push(cdep.get());
}
}
}
}
void ngraph::traverse_nodes(const NodeVector& subgraph_results,
std::function<void(std::shared_ptr<Node>)> f,
bool,
const NodeVector& subgraph_params)
{
traverse_nodes(subgraph_results, f, subgraph_params);
}
NodeVector ngraph::find_common_args(std::shared_ptr<Node> node1, std::shared_ptr<Node> node2)
{
std::unordered_set<std::shared_ptr<Node>> node1_args;
......@@ -110,7 +111,7 @@ NodeVector ngraph::find_common_args(std::shared_ptr<Node> node1, std::shared_ptr
node1_args.insert(node);
};
traverse_nodes({node1}, compute_node1_args, false, NodeVector{});
traverse_nodes({node1}, compute_node1_args, NodeVector{});
std::unordered_set<std::shared_ptr<Node>> node2_args;
......@@ -118,7 +119,7 @@ NodeVector ngraph::find_common_args(std::shared_ptr<Node> node1, std::shared_ptr
node2_args.insert(node);
};
traverse_nodes({node2}, compute_node2_args, false, NodeVector{});
traverse_nodes({node2}, compute_node2_args, NodeVector{});
NodeVector common_args;
for (auto e : node1_args)
......@@ -170,14 +171,14 @@ void ngraph::replace_node(std::shared_ptr<Node> target,
}
};
traverse_nodes({target}, set_replacement_prov, false, common_args);
traverse_nodes({target}, set_replacement_prov, common_args);
replacement->add_provenance_tags(removed_subgraph_tags);
auto set_prov_new_nodes = [&removed_subgraph_tags](std::shared_ptr<Node> node) {
node->add_provenance_tags(removed_subgraph_tags);
};
traverse_nodes({replacement}, set_prov_new_nodes, false, common_args);
traverse_nodes({replacement}, set_prov_new_nodes, common_args);
}
// For each of target's output O with replacement output O_rep:
......@@ -295,7 +296,7 @@ std::vector<std::shared_ptr<ngraph::Node>>
ngraph::clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map)
{
// for each node in topological order
auto sorted_nodes = topological_sort(nodes, true);
auto sorted_nodes = topological_sort(nodes);
for (auto node : sorted_nodes)
{
if (node_map.count(node.get()) == 0)
......@@ -349,7 +350,7 @@ std::list<std::shared_ptr<ngraph::Node>>
RawNodeOutputMap& output_map)
{
// for each node in topological order
auto sorted_nodes = topological_sort(nodes, true);
auto sorted_nodes = topological_sort(nodes);
std::list<shared_ptr<Node>> cloned_nodes;
for (auto node : sorted_nodes)
{
......@@ -418,7 +419,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
NodeMap& node_map)
{
// clone function operations
clone_nodes(func.get_ops(true), node_map);
clone_nodes(func.get_ops(), node_map);
// get cloned function results and parameters
ResultVector cloned_results;
......@@ -640,7 +641,7 @@ NodeVector ngraph::get_subgraph_outputs(const NodeVector& nodes,
NodeVector ngraph::extract_subgraph(const NodeVector& results, const NodeVector& args)
{
NodeVector subgraph;
traverse_nodes(results, [&](std::shared_ptr<Node> n) { subgraph.push_back(n); }, true, args);
traverse_nodes(results, [&](std::shared_ptr<Node> n) { subgraph.push_back(n); }, args);
return subgraph;
}
......
......@@ -48,18 +48,13 @@ namespace ngraph
}
void traverse_nodes(const std::shared_ptr<const Function> p,
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps = false);
std::function<void(std::shared_ptr<Node>)> f);
void traverse_nodes(const Function* p,
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps);
void traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f);
/// \brief Visit each node in a sub-graph of the entire graph
/// \param subgraph_results The output nodes of the sub-graph
/// \param f Function to execute at each node in the traversal
/// \param include_control_deps Whether to include control deps
/// while traversing the sub-graph
/// \param subgraph_params Input nodes of the sub-graph (optional)
///
/// Traverses a sub-graph starting from subgraph_results moving up
......@@ -71,9 +66,14 @@ namespace ngraph
/// subgraph relevant to the computation of certain outputs
void traverse_nodes(const NodeVector& subgraph_results,
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps,
const NodeVector& subgraph_params = {});
void traverse_nodes(const NodeVector& subgraph_results,
std::function<void(std::shared_ptr<Node>)> f,
bool,
const NodeVector& subgraph_params = {})
NGRAPH_DEPRECATED("Use traverse_nodes without control-deps option");
void traverse_functions(std::shared_ptr<Function> p,
std::function<void(std::shared_ptr<Function>)> f)
NGRAPH_DEPRECATED("Replace with f(p)");
......@@ -258,8 +258,7 @@ namespace ngraph
/// Topological sort of nodes needed to compute root_nodes
template <typename T>
std::vector<std::shared_ptr<Node>> topological_sort(T root_nodes,
bool include_control_deps = false)
std::vector<std::shared_ptr<Node>> topological_sort(T root_nodes)
{
std::stack<Node*, std::vector<Node*>> nodes_to_do;
std::unordered_set<Node*> nodes_done;
......@@ -285,16 +284,13 @@ namespace ngraph
nodes_to_do.push(dep);
}
}
if (include_control_deps)
for (auto& depptr : node->get_control_dependencies())
{
for (auto& depptr : node->get_control_dependencies())
Node* dep = depptr.get();
if (nodes_done.count(dep) == 0)
{
Node* dep = depptr.get();
if (nodes_done.count(dep) == 0)
{
can_add = false;
nodes_to_do.push(dep);
}
can_add = false;
nodes_to_do.push(dep);
}
}
if (can_add)
......@@ -314,8 +310,7 @@ namespace ngraph
/// Topological sort of just nodes
template <typename T>
std::vector<std::shared_ptr<Node>> subgraph_topological_sort(T nodes,
bool include_control_deps = false)
std::vector<std::shared_ptr<Node>> subgraph_topological_sort(T nodes)
{
std::stack<Node*, std::vector<Node*>> nodes_to_do;
std::unordered_set<Node*> nodes_done;
......@@ -345,16 +340,13 @@ namespace ngraph
nodes_to_do.push(dep);
}
}
if (include_control_deps)
for (auto& depptr : node->get_control_dependencies())
{
for (auto& depptr : node->get_control_dependencies())
Node* dep = depptr.get();
if (nodes_done.count(dep) == 0)
{
Node* dep = depptr.get();
if (nodes_done.count(dep) == 0)
{
can_add = false;
nodes_to_do.push(dep);
}
can_add = false;
nodes_to_do.push(dep);
}
}
if (can_add)
......
......@@ -497,14 +497,14 @@ void Node::transfer_provenance_tags(const shared_ptr<Node>& replacement)
}
};
traverse_nodes({shared_from_this()}, set_replacement_prov, false, common_args);
traverse_nodes({shared_from_this()}, set_replacement_prov, common_args);
replacement->add_provenance_tags(removed_subgraph_tags);
auto set_prov_new_nodes = [&removed_subgraph_tags](std::shared_ptr<Node> node) {
node->add_provenance_tags(removed_subgraph_tags);
};
traverse_nodes({replacement}, set_prov_new_nodes, false, common_args);
traverse_nodes({replacement}, set_prov_new_nodes, common_args);
}
std::shared_ptr<Node> Node::get_argument(size_t index) const
......
......@@ -453,7 +453,7 @@ json JSONSerializer::serialize_function(const Function& f)
}
json nodes;
for (shared_ptr<Node> node : f.get_ordered_ops(true))
for (shared_ptr<Node> node : f.get_ordered_ops())
{
nodes.push_back(serialize_node(*node));
}
......
......@@ -207,14 +207,12 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
// Traverse bprop to find all of the nodes in the bprop graph
std::set<Output<Node>> in_bprop;
ngraph::traverse_nodes(bprop,
[&in_bprop](std::shared_ptr<Node> node) {
for (auto value : node->outputs())
{
in_bprop.insert(value);
}
},
false /* no control dependencies */);
ngraph::traverse_nodes(bprop, [&in_bprop](std::shared_ptr<Node> node) {
for (auto value : node->outputs())
{
in_bprop.insert(value);
}
});
// Traverse fprop to make a map that stores parameters with the same
// shape and element type as the nodes in fprop iff they are in bprop
......@@ -299,8 +297,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
{
fprop_cache.fprop_output_nodes.push_back(inverted_node_map.at(Output<Node>(node)));
}
},
false /* no control dependencies */);
});
// create the new outputs for fprop and the new fprop function
ResultVector fprop_outputs = fprop->get_results();
......
......@@ -456,7 +456,7 @@ TEST(provenance, fused_decomposition_tag)
EXPECT_TRUE(tags.find(tag) != tags.end());
};
const auto decomposed_op = f->get_result()->input(0).get_source_output().get_node_shared_ptr();
traverse_nodes(as_node_vector(decomposed_op->outputs()), tag_check, false, {p1});
traverse_nodes(as_node_vector(decomposed_op->outputs()), tag_check, {p1});
}
TEST(provenance, topk_setk)
......@@ -579,10 +579,8 @@ TEST(provenance, opset1_upgrade_pass_topk)
auto tags = node->get_provenance_tags();
EXPECT_TRUE(tags.find(tag) != tags.end());
};
traverse_nodes(as_node_vector(topk_v1->outputs()),
tag_check,
false,
as_node_vector(topk_v0->input_values()));
traverse_nodes(
as_node_vector(topk_v1->outputs()), tag_check, as_node_vector(topk_v0->input_values()));
}
TEST(provenance, opset0_downgrade_pass_topk)
......@@ -614,10 +612,8 @@ TEST(provenance, opset0_downgrade_pass_topk)
auto tags = node->get_provenance_tags();
EXPECT_TRUE(tags.find(tag) != tags.end());
};
traverse_nodes(as_node_vector(topk_v0->outputs()),
tag_check,
false,
as_node_vector(topk_v1->input_values()));
traverse_nodes(
as_node_vector(topk_v0->outputs()), tag_check, as_node_vector(topk_v1->input_values()));
}
TEST(provenance, opset1_upgrade_pass_graph)
......
......@@ -390,7 +390,7 @@ TEST(graph_util, test_subgraph_topological_sort_control_dependencies)
add->add_control_dependency(E);
auto mul = C * add;
auto result = make_shared<op::Result>(mul);
auto sorted = ngraph::subgraph_topological_sort(NodeVector{mul, add, A, D}, true);
auto sorted = ngraph::subgraph_topological_sort(NodeVector{mul, add, A, D});
std::vector<std::shared_ptr<Node>> expected{A, D, add, mul};
ASSERT_EQ(expected, sorted);
}
......@@ -701,11 +701,11 @@ TEST(util, topological_sort_replace)
auto f = make_shared<Function>(A + B + C, ParameterVector{A, B, C});
bool custom_sorter_used = false;
f->set_topological_sort([&custom_sorter_used](
const std::vector<std::shared_ptr<Node>>& root_nodes, bool include_control_deps) {
custom_sorter_used = true;
return topological_sort(root_nodes, include_control_deps);
});
f->set_topological_sort(
[&custom_sorter_used](const std::vector<std::shared_ptr<Node>>& root_nodes) {
custom_sorter_used = true;
return topological_sort(root_nodes);
});
// Need to now call topological sort but don't care about the results
f->get_ordered_ops();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment