Commit 9b3d5732 authored by Scott Cyphers's avatar Scott Cyphers

Fix top sort

parent de37f9d3
......@@ -14,8 +14,10 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/descriptor/output.hpp"
#include <algorithm>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/node.hpp"
using namespace std;
......@@ -31,12 +33,20 @@ descriptor::Output::Output(Node* node, size_t index, const shared_ptr<Tensor>& t
// Add an input to the vector of inputs that use this output.
void descriptor::Output::add_input(Input* input)
{
m_inputs.insert(input);
// Keep the inputs in insertion order to keep sorts deterministic
if (find(m_inputs.begin(), m_inputs.end(), input) == m_inputs.end())
{
m_inputs.push_back(input);
}
}
void descriptor::Output::remove_input(Input* input)
{
m_inputs.erase(input);
auto it = find(m_inputs.begin(), m_inputs.end(), input);
if (it != m_inputs.end())
{
m_inputs.erase(it);
}
}
shared_ptr<Node> descriptor::Output::get_node() const
......
......@@ -17,7 +17,7 @@
#pragma once
#include <memory>
#include <set>
#include <vector>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/tensor.hpp"
......@@ -48,7 +48,7 @@ namespace ngraph
void set_tensor_ptr(const std::shared_ptr<Tensor>& tensor) { m_tensor = tensor; }
void add_input(Input* input);
void remove_input(Input* input);
const std::set<Input*>& get_inputs() const { return m_inputs; }
const std::vector<Input*>& get_inputs() const { return m_inputs; }
Tensor& get_tensor() const;
/// \return the shape of the output
......@@ -64,7 +64,7 @@ namespace ngraph
Node* m_node;
size_t m_index;
std::shared_ptr<Tensor> m_tensor;
std::set<Input*> m_inputs;
std::vector<Input*> m_inputs;
private:
Output(const Output&) = delete;
......
......@@ -81,12 +81,11 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results,
while (stack.size() > 0)
{
std::shared_ptr<Node> n = stack.front();
stack.pop_front();
if (instances_seen.count(n) == 0)
{
instances_seen.insert(n);
f(n);
}
stack.pop_front();
for (auto arg : n->get_arguments())
{
if (instances_seen.count(arg) == 0)
......@@ -106,6 +105,7 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results,
}
}
}
}
}
void ngraph::traverse_functions(std::shared_ptr<ngraph::Function> p,
......
......@@ -20,6 +20,7 @@
#include <functional>
#include <list>
#include <memory>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
......@@ -81,66 +82,53 @@ namespace ngraph
std::list<std::shared_ptr<Node>> topological_sort(const T& nodes,
bool include_control_deps = false)
{
std::deque<ngraph::Node*> independent_nodes;
std::unordered_map<const ngraph::Node*, size_t> node_dependency_count;
std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>> node_map;
std::unordered_map<ngraph::Node*, std::set<Node*>> control_deps_users;
std::stack<ngraph::Node*> nodes_to_do;
std::set<Node*> nodes_done;
std::list<std::shared_ptr<Node>> result;
for (auto node : nodes)
{
//build an equivalent of node->get_users() but for control dependencies
size_t control_deps_count = 0;
if (include_control_deps)
{
for (auto cd : node->get_control_dependencies())
{
control_deps_count++;
control_deps_users[cd.get()].insert(node.get());
}
nodes_to_do.push(node.get());
}
node_map[node.get()] = node;
size_t deps_count = node->get_input_size() + control_deps_count;
node_dependency_count[node.get()] = deps_count;
if (deps_count == 0)
while (nodes_to_do.size() > 0)
{
independent_nodes.push_back(node.get());
}
}
std::list<std::shared_ptr<ngraph::Node>> result_list;
while (independent_nodes.size() > 0)
Node* node = nodes_to_do.top();
if (nodes_done.count(node) != 0)
{
auto independent_node = independent_nodes.front();
result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front();
for (const std::shared_ptr<Node>& user : independent_node->get_users())
nodes_to_do.pop();
continue;
}
bool can_add = true;
size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i)
{
if (--node_dependency_count[user.get()] == 0)
Node* dep = node->input(arg_count - i - 1).get_source_output().get_node();
if (nodes_done.count(dep) == 0)
{
independent_nodes.push_back(user.get());
can_add = false;
nodes_to_do.push(dep);
}
}
if (include_control_deps)
{
auto cdit = control_deps_users.find(independent_node);
if (cdit != control_deps_users.end())
for (auto cd_user : cdit->second)
for (auto depptr : node->get_control_dependencies())
{
node_dependency_count[cd_user] -= 1;
size_t count = node_dependency_count[cd_user];
if (count == 0)
Node* dep = depptr.get();
if (nodes_done.count(dep) == 0)
{
independent_nodes.push_back(cd_user);
can_add = false;
nodes_to_do.push(dep);
}
}
}
if (can_add)
{
result.push_back(node->shared_from_this());
nodes_to_do.pop();
nodes_done.insert(node);
}
NGRAPH_CHECK(nodes.size() == result_list.size());
return result_list;
}
return result;
}
// For cases, where `nodes` is a subset of the entire graph
......
......@@ -344,7 +344,7 @@ shared_ptr<descriptor::Tensor> Node::get_output_tensor_ptr() const
return m_outputs.at(0).get_tensor_ptr();
}
const std::set<descriptor::Input*>& Node::get_output_inputs(size_t i) const
const std::vector<descriptor::Input*>& Node::get_output_inputs(size_t i) const
{
return m_outputs.at(i).get_inputs();
}
......
......@@ -257,7 +257,7 @@ namespace ngraph
"output, or update calling code not to assume only one output");
/// Returns the set of inputs using output i
const std::set<descriptor::Input*>& get_output_inputs(size_t i) const
const std::vector<descriptor::Input*>& get_output_inputs(size_t i) const
NGRAPH_DEPRECATED("use node->output(i).get_target_inputs() instead");
/// Returns the number of inputs for the op
......
......@@ -87,8 +87,7 @@ TEST(control_dependencies, cdep_ops)
make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn});
auto f = make_shared<Function>(cdop, ParameterVector{A, B});
auto nodes = f->get_ordered_ops(true);
ASSERT_EQ(nodes.back()->get_argument(0), cdop);
test_ordered_ops(f);
}
TEST(control_dependencies, two_cdep_ops)
......@@ -102,8 +101,7 @@ TEST(control_dependencies, two_cdep_ops)
std::set<std::shared_ptr<Node>>{absn, absn_c});
auto f = make_shared<Function>(cdop, ParameterVector{A, B, C});
auto nodes = f->get_ordered_ops(true);
ASSERT_EQ(nodes.back()->get_argument(0), cdop);
test_ordered_ops(f);
}
TEST(control_dependencies, two_cdep_ops_op_on_top)
......@@ -117,8 +115,7 @@ TEST(control_dependencies, two_cdep_ops_op_on_top)
auto absn_cdop = make_shared<op::Abs>(cdop);
auto f = make_shared<Function>(absn_cdop, ParameterVector{A, B});
auto nodes = f->get_ordered_ops(true);
ASSERT_EQ(nodes.back()->get_argument(0), absn_cdop);
test_ordered_ops(f);
}
TEST(control_dependencies, clone_function_cdop)
......@@ -129,6 +126,7 @@ TEST(control_dependencies, clone_function_cdop)
make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn});
auto f = make_shared<Function>(cdop, ParameterVector{A});
test_ordered_ops(f);
auto clone = ngraph::clone_function(*f.get());
auto matcher = std::make_shared<pattern::Matcher>(cdop);
auto cdop_clone = clone->get_results().at(0)->get_argument(0);
......
......@@ -313,3 +313,35 @@ std::shared_ptr<Function> make_function_from_file(const std::string& file_name)
return func;
}
#endif
::testing::AssertionResult test_ordered_ops(shared_ptr<Function> f)
{
set<shared_ptr<Node>> seen;
for (auto node : f->get_ordered_ops())
{
if (seen.count(node) > 0)
{
return ::testing::AssertionFailure() << "Duplication in ordered ops";
}
size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i)
{
shared_ptr<Node> dep = node->input(i).get_source_output().get_node_shared_ptr();
if (seen.count(dep) == 0)
{
return ::testing::AssertionFailure() << "Argument " << dep
<< " does not occur before op" << node;
}
}
for (shared_ptr<Node> dep : node->get_control_dependencies())
{
if (seen.count(dep) == 0)
{
return ::testing::AssertionFailure() << "Control dependency " << dep
<< " does not occur before op" << node;
}
}
seen.insert(node);
}
return ::testing::AssertionSuccess();
}
......@@ -25,6 +25,8 @@
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/descriptor/layout/tensor_layout.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/log.hpp"
......@@ -276,3 +278,5 @@ std::vector<T> read_binary_file(const std::string& path)
inputs_fs.read(reinterpret_cast<char*>(file_content.data()), size);
return file_content;
}
testing::AssertionResult test_ordered_ops(std::shared_ptr<ngraph::Function> f);
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment