Commit 9b3d5732 authored by Scott Cyphers's avatar Scott Cyphers

Fix top sort

parent de37f9d3
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "ngraph/descriptor/output.hpp" #include <algorithm>
#include "ngraph/descriptor/input.hpp" #include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
using namespace std; using namespace std;
...@@ -31,12 +33,20 @@ descriptor::Output::Output(Node* node, size_t index, const shared_ptr<Tensor>& t ...@@ -31,12 +33,20 @@ descriptor::Output::Output(Node* node, size_t index, const shared_ptr<Tensor>& t
// Add an input to the vector of inputs that use this output. // Add an input to the vector of inputs that use this output.
void descriptor::Output::add_input(Input* input) void descriptor::Output::add_input(Input* input)
{ {
m_inputs.insert(input); // Keep the inputs in insertion order to keep sorts deterministic
if (find(m_inputs.begin(), m_inputs.end(), input) == m_inputs.end())
{
m_inputs.push_back(input);
}
} }
void descriptor::Output::remove_input(Input* input) void descriptor::Output::remove_input(Input* input)
{ {
m_inputs.erase(input); auto it = find(m_inputs.begin(), m_inputs.end(), input);
if (it != m_inputs.end())
{
m_inputs.erase(it);
}
} }
shared_ptr<Node> descriptor::Output::get_node() const shared_ptr<Node> descriptor::Output::get_node() const
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <set> #include <vector>
#include "ngraph/descriptor/input.hpp" #include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
...@@ -48,7 +48,7 @@ namespace ngraph ...@@ -48,7 +48,7 @@ namespace ngraph
void set_tensor_ptr(const std::shared_ptr<Tensor>& tensor) { m_tensor = tensor; } void set_tensor_ptr(const std::shared_ptr<Tensor>& tensor) { m_tensor = tensor; }
void add_input(Input* input); void add_input(Input* input);
void remove_input(Input* input); void remove_input(Input* input);
const std::set<Input*>& get_inputs() const { return m_inputs; } const std::vector<Input*>& get_inputs() const { return m_inputs; }
Tensor& get_tensor() const; Tensor& get_tensor() const;
/// \return the shape of the output /// \return the shape of the output
...@@ -64,7 +64,7 @@ namespace ngraph ...@@ -64,7 +64,7 @@ namespace ngraph
Node* m_node; Node* m_node;
size_t m_index; size_t m_index;
std::shared_ptr<Tensor> m_tensor; std::shared_ptr<Tensor> m_tensor;
std::set<Input*> m_inputs; std::vector<Input*> m_inputs;
private: private:
Output(const Output&) = delete; Output(const Output&) = delete;
......
...@@ -81,12 +81,11 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results, ...@@ -81,12 +81,11 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results,
while (stack.size() > 0) while (stack.size() > 0)
{ {
std::shared_ptr<Node> n = stack.front(); std::shared_ptr<Node> n = stack.front();
stack.pop_front();
if (instances_seen.count(n) == 0) if (instances_seen.count(n) == 0)
{ {
instances_seen.insert(n); instances_seen.insert(n);
f(n); f(n);
}
stack.pop_front();
for (auto arg : n->get_arguments()) for (auto arg : n->get_arguments())
{ {
if (instances_seen.count(arg) == 0) if (instances_seen.count(arg) == 0)
...@@ -106,6 +105,7 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results, ...@@ -106,6 +105,7 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results,
} }
} }
} }
}
} }
void ngraph::traverse_functions(std::shared_ptr<ngraph::Function> p, void ngraph::traverse_functions(std::shared_ptr<ngraph::Function> p,
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <functional> #include <functional>
#include <list> #include <list>
#include <memory> #include <memory>
#include <stack>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
...@@ -81,66 +82,53 @@ namespace ngraph ...@@ -81,66 +82,53 @@ namespace ngraph
std::list<std::shared_ptr<Node>> topological_sort(const T& nodes, std::list<std::shared_ptr<Node>> topological_sort(const T& nodes,
bool include_control_deps = false) bool include_control_deps = false)
{ {
std::deque<ngraph::Node*> independent_nodes; std::stack<ngraph::Node*> nodes_to_do;
std::unordered_map<const ngraph::Node*, size_t> node_dependency_count; std::set<Node*> nodes_done;
std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>> node_map; std::list<std::shared_ptr<Node>> result;
std::unordered_map<ngraph::Node*, std::set<Node*>> control_deps_users;
for (auto node : nodes) for (auto node : nodes)
{ {
//build an equivalent of node->get_users() but for control dependencies nodes_to_do.push(node.get());
size_t control_deps_count = 0;
if (include_control_deps)
{
for (auto cd : node->get_control_dependencies())
{
control_deps_count++;
control_deps_users[cd.get()].insert(node.get());
}
} }
while (nodes_to_do.size() > 0)
node_map[node.get()] = node;
size_t deps_count = node->get_input_size() + control_deps_count;
node_dependency_count[node.get()] = deps_count;
if (deps_count == 0)
{ {
independent_nodes.push_back(node.get()); Node* node = nodes_to_do.top();
} if (nodes_done.count(node) != 0)
}
std::list<std::shared_ptr<ngraph::Node>> result_list;
while (independent_nodes.size() > 0)
{ {
auto independent_node = independent_nodes.front(); nodes_to_do.pop();
result_list.push_back(node_map[independent_node]); continue;
independent_nodes.pop_front(); }
bool can_add = true;
for (const std::shared_ptr<Node>& user : independent_node->get_users()) size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i)
{ {
if (--node_dependency_count[user.get()] == 0) Node* dep = node->input(arg_count - i - 1).get_source_output().get_node();
if (nodes_done.count(dep) == 0)
{ {
independent_nodes.push_back(user.get()); can_add = false;
nodes_to_do.push(dep);
} }
} }
if (include_control_deps) if (include_control_deps)
{ {
auto cdit = control_deps_users.find(independent_node); for (auto depptr : node->get_control_dependencies())
if (cdit != control_deps_users.end())
for (auto cd_user : cdit->second)
{ {
node_dependency_count[cd_user] -= 1; Node* dep = depptr.get();
size_t count = node_dependency_count[cd_user]; if (nodes_done.count(dep) == 0)
if (count == 0)
{ {
independent_nodes.push_back(cd_user); can_add = false;
nodes_to_do.push(dep);
} }
} }
} }
if (can_add)
{
result.push_back(node->shared_from_this());
nodes_to_do.pop();
nodes_done.insert(node);
} }
}
NGRAPH_CHECK(nodes.size() == result_list.size()); return result;
return result_list;
} }
// For cases, where `nodes` is a subset of the entire graph // For cases, where `nodes` is a subset of the entire graph
......
...@@ -344,7 +344,7 @@ shared_ptr<descriptor::Tensor> Node::get_output_tensor_ptr() const ...@@ -344,7 +344,7 @@ shared_ptr<descriptor::Tensor> Node::get_output_tensor_ptr() const
return m_outputs.at(0).get_tensor_ptr(); return m_outputs.at(0).get_tensor_ptr();
} }
const std::set<descriptor::Input*>& Node::get_output_inputs(size_t i) const const std::vector<descriptor::Input*>& Node::get_output_inputs(size_t i) const
{ {
return m_outputs.at(i).get_inputs(); return m_outputs.at(i).get_inputs();
} }
......
...@@ -257,7 +257,7 @@ namespace ngraph ...@@ -257,7 +257,7 @@ namespace ngraph
"output, or update calling code not to assume only one output"); "output, or update calling code not to assume only one output");
/// Returns the set of inputs using output i /// Returns the set of inputs using output i
const std::set<descriptor::Input*>& get_output_inputs(size_t i) const const std::vector<descriptor::Input*>& get_output_inputs(size_t i) const
NGRAPH_DEPRECATED("use node->output(i).get_target_inputs() instead"); NGRAPH_DEPRECATED("use node->output(i).get_target_inputs() instead");
/// Returns the number of inputs for the op /// Returns the number of inputs for the op
......
...@@ -87,8 +87,7 @@ TEST(control_dependencies, cdep_ops) ...@@ -87,8 +87,7 @@ TEST(control_dependencies, cdep_ops)
make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn}); make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn});
auto f = make_shared<Function>(cdop, ParameterVector{A, B}); auto f = make_shared<Function>(cdop, ParameterVector{A, B});
auto nodes = f->get_ordered_ops(true); test_ordered_ops(f);
ASSERT_EQ(nodes.back()->get_argument(0), cdop);
} }
TEST(control_dependencies, two_cdep_ops) TEST(control_dependencies, two_cdep_ops)
...@@ -102,8 +101,7 @@ TEST(control_dependencies, two_cdep_ops) ...@@ -102,8 +101,7 @@ TEST(control_dependencies, two_cdep_ops)
std::set<std::shared_ptr<Node>>{absn, absn_c}); std::set<std::shared_ptr<Node>>{absn, absn_c});
auto f = make_shared<Function>(cdop, ParameterVector{A, B, C}); auto f = make_shared<Function>(cdop, ParameterVector{A, B, C});
auto nodes = f->get_ordered_ops(true); test_ordered_ops(f);
ASSERT_EQ(nodes.back()->get_argument(0), cdop);
} }
TEST(control_dependencies, two_cdep_ops_op_on_top) TEST(control_dependencies, two_cdep_ops_op_on_top)
...@@ -117,8 +115,7 @@ TEST(control_dependencies, two_cdep_ops_op_on_top) ...@@ -117,8 +115,7 @@ TEST(control_dependencies, two_cdep_ops_op_on_top)
auto absn_cdop = make_shared<op::Abs>(cdop); auto absn_cdop = make_shared<op::Abs>(cdop);
auto f = make_shared<Function>(absn_cdop, ParameterVector{A, B}); auto f = make_shared<Function>(absn_cdop, ParameterVector{A, B});
auto nodes = f->get_ordered_ops(true); test_ordered_ops(f);
ASSERT_EQ(nodes.back()->get_argument(0), absn_cdop);
} }
TEST(control_dependencies, clone_function_cdop) TEST(control_dependencies, clone_function_cdop)
...@@ -129,6 +126,7 @@ TEST(control_dependencies, clone_function_cdop) ...@@ -129,6 +126,7 @@ TEST(control_dependencies, clone_function_cdop)
make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn}); make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn});
auto f = make_shared<Function>(cdop, ParameterVector{A}); auto f = make_shared<Function>(cdop, ParameterVector{A});
test_ordered_ops(f);
auto clone = ngraph::clone_function(*f.get()); auto clone = ngraph::clone_function(*f.get());
auto matcher = std::make_shared<pattern::Matcher>(cdop); auto matcher = std::make_shared<pattern::Matcher>(cdop);
auto cdop_clone = clone->get_results().at(0)->get_argument(0); auto cdop_clone = clone->get_results().at(0)->get_argument(0);
......
...@@ -313,3 +313,35 @@ std::shared_ptr<Function> make_function_from_file(const std::string& file_name) ...@@ -313,3 +313,35 @@ std::shared_ptr<Function> make_function_from_file(const std::string& file_name)
return func; return func;
} }
#endif #endif
::testing::AssertionResult test_ordered_ops(shared_ptr<Function> f)
{
set<shared_ptr<Node>> seen;
for (auto node : f->get_ordered_ops())
{
if (seen.count(node) > 0)
{
return ::testing::AssertionFailure() << "Duplication in ordered ops";
}
size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i)
{
shared_ptr<Node> dep = node->input(i).get_source_output().get_node_shared_ptr();
if (seen.count(dep) == 0)
{
return ::testing::AssertionFailure() << "Argument " << dep
<< " does not occur before op" << node;
}
}
for (shared_ptr<Node> dep : node->get_control_dependencies())
{
if (seen.count(dep) == 0)
{
return ::testing::AssertionFailure() << "Control dependency " << dep
<< " does not occur before op" << node;
}
}
seen.insert(node);
}
return ::testing::AssertionSuccess();
}
...@@ -25,6 +25,8 @@ ...@@ -25,6 +25,8 @@
#include <random> #include <random>
#include <vector> #include <vector>
#include "gtest/gtest.h"
#include "ngraph/descriptor/layout/tensor_layout.hpp" #include "ngraph/descriptor/layout/tensor_layout.hpp"
#include "ngraph/file_util.hpp" #include "ngraph/file_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
...@@ -276,3 +278,5 @@ std::vector<T> read_binary_file(const std::string& path) ...@@ -276,3 +278,5 @@ std::vector<T> read_binary_file(const std::string& path)
inputs_fs.read(reinterpret_cast<char*>(file_content.data()), size); inputs_fs.read(reinterpret_cast<char*>(file_content.data()), size);
return file_content; return file_content;
} }
testing::AssertionResult test_ordered_ops(std::shared_ptr<ngraph::Function> f);
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment