Unverified Commit c488f12b authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Cyphers/topmaster (#3199)

* Stabilize node sorting and fix some bugs.

* Review comments

* Fix broken tests

* Implement traverse nodes with pointers

* Let sort gather nodes for get_ordered_ops

* Use stacks for stacks

* Keep control deps ordered

* Optimize subgraph sort

* Add unordered map over function ops

* Don't recheck children

* Use vectors in stacks, avoid std::list::size()
parent 8e6d8a99
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "ngraph/descriptor/output.hpp" #include <algorithm>
#include "ngraph/descriptor/input.hpp" #include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
using namespace std; using namespace std;
...@@ -31,12 +33,20 @@ descriptor::Output::Output(Node* node, size_t index, const shared_ptr<Tensor>& t ...@@ -31,12 +33,20 @@ descriptor::Output::Output(Node* node, size_t index, const shared_ptr<Tensor>& t
// Add an input to the vector of inputs that use this output. // Add an input to the vector of inputs that use this output.
void descriptor::Output::add_input(Input* input) void descriptor::Output::add_input(Input* input)
{ {
m_inputs.insert(input); // Keep the inputs in insertion order to keep sorts deterministic
if (find(m_inputs.begin(), m_inputs.end(), input) == m_inputs.end())
{
m_inputs.push_back(input);
}
} }
void descriptor::Output::remove_input(Input* input) void descriptor::Output::remove_input(Input* input)
{ {
m_inputs.erase(input); auto it = find(m_inputs.begin(), m_inputs.end(), input);
if (it != m_inputs.end())
{
m_inputs.erase(it);
}
} }
shared_ptr<Node> descriptor::Output::get_node() const shared_ptr<Node> descriptor::Output::get_node() const
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <set> #include <vector>
#include "ngraph/descriptor/input.hpp" #include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
...@@ -48,7 +48,7 @@ namespace ngraph ...@@ -48,7 +48,7 @@ namespace ngraph
void set_tensor_ptr(const std::shared_ptr<Tensor>& tensor) { m_tensor = tensor; } void set_tensor_ptr(const std::shared_ptr<Tensor>& tensor) { m_tensor = tensor; }
void add_input(Input* input); void add_input(Input* input);
void remove_input(Input* input); void remove_input(Input* input);
const std::set<Input*>& get_inputs() const { return m_inputs; } const std::vector<Input*>& get_inputs() const { return m_inputs; }
Tensor& get_tensor() const; Tensor& get_tensor() const;
/// \return the shape of the output /// \return the shape of the output
...@@ -64,7 +64,7 @@ namespace ngraph ...@@ -64,7 +64,7 @@ namespace ngraph
Node* m_node; Node* m_node;
size_t m_index; size_t m_index;
std::shared_ptr<Tensor> m_tensor; std::shared_ptr<Tensor> m_tensor;
std::set<Input*> m_inputs; std::vector<Input*> m_inputs;
private: private:
Output(const Output&) = delete; Output(const Output&) = delete;
......
...@@ -97,7 +97,48 @@ void Function::init() ...@@ -97,7 +97,48 @@ void Function::init()
std::list<shared_ptr<Node>> Function::get_ordered_ops(bool include_control_deps) const std::list<shared_ptr<Node>> Function::get_ordered_ops(bool include_control_deps) const
{ {
return topological_sort(get_ops(include_control_deps), include_control_deps); NodeVector nodes;
for (auto& r : get_results())
{
nodes.push_back(r);
}
for (auto& param : get_parameters())
{
nodes.push_back(param);
}
return topological_sort(nodes, include_control_deps);
}
void Function::map_unordered_ops(std::function<void(Node*)> f) const
{
std::unordered_set<Node*> unordered_ops;
std::stack<Node*, std::vector<Node*>> remaining_ops;
for (auto& r : get_results())
{
remaining_ops.push(r.get());
}
for (auto& param : get_parameters())
{
remaining_ops.push(param.get());
}
while (remaining_ops.size() > 0)
{
Node* op = remaining_ops.top();
remaining_ops.pop();
if (unordered_ops.insert(op).second)
{
f(op);
for (size_t i = 0; i < op->get_input_size(); ++i)
{
remaining_ops.push(op->input(i).get_source_output().get_node());
}
for (auto& cdep : op->get_control_dependencies())
{
remaining_ops.push(cdep.get());
}
}
}
} }
const std::string& Function::get_friendly_name() const const std::string& Function::get_friendly_name() const
......
...@@ -88,6 +88,8 @@ namespace ngraph ...@@ -88,6 +88,8 @@ namespace ngraph
std::list<std::shared_ptr<Node>> get_ops(bool include_control_deps = true) const; std::list<std::shared_ptr<Node>> get_ops(bool include_control_deps = true) const;
std::list<std::shared_ptr<Node>> get_ordered_ops(bool include_control_deps = true) const; std::list<std::shared_ptr<Node>> get_ordered_ops(bool include_control_deps = true) const;
void map_unordered_ops(std::function<void(Node*)> f) const;
friend std::ostream& operator<<(std::ostream&, const Function&); friend std::ostream& operator<<(std::ostream&, const Function&);
size_t get_instance_id() { return m_instance_id; } size_t get_instance_id() { return m_instance_id; }
size_t get_temporary_pool_size(); size_t get_temporary_pool_size();
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <deque>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
...@@ -69,42 +68,34 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results, ...@@ -69,42 +68,34 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results,
bool include_control_deps, bool include_control_deps,
const NodeVector& subgraph_params) const NodeVector& subgraph_params)
{ {
std::unordered_set<std::shared_ptr<Node>> instances_seen{subgraph_params.begin(), std::unordered_set<Node*> instances_seen;
subgraph_params.end()}; std::stack<Node*, std::vector<Node*>> stack;
std::deque<std::shared_ptr<Node>> stack; for (auto& node_ptr : subgraph_params)
for (auto r : subgraph_results)
{ {
if (instances_seen.count(r) == 0) instances_seen.insert(node_ptr.get());
{ }
stack.push_front(r); for (auto& node_ptr : subgraph_results)
} {
stack.push(node_ptr.get());
} }
while (stack.size() > 0) while (stack.size() > 0)
{ {
std::shared_ptr<Node> n = stack.front(); Node* n = stack.top();
if (instances_seen.count(n) == 0) stack.pop();
{ if (instances_seen.insert(n).second)
instances_seen.insert(n);
f(n);
}
stack.pop_front();
for (auto arg : n->get_arguments())
{ {
if (instances_seen.count(arg) == 0) f(n->shared_from_this());
for (auto& arg : n->get_arguments())
{ {
stack.push_front(arg); stack.push(arg.get());
} }
}
if (include_control_deps) if (include_control_deps)
{
for (auto cdep : n->get_control_dependencies())
{ {
if (instances_seen.count(cdep) == 0) for (auto& cdep : n->get_control_dependencies())
{ {
stack.push_front(cdep); stack.push(cdep.get());
} }
} }
} }
...@@ -182,25 +173,25 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re ...@@ -182,25 +173,25 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
bool ngraph::is_post_dominated(Node* X, Node* Y) bool ngraph::is_post_dominated(Node* X, Node* Y)
{ {
std::unordered_set<Node*> visited; std::unordered_set<Node*> visited;
std::deque<Node*> stack; std::stack<Node*, std::vector<Node*>> stack;
stack.push_front(X); stack.push(X);
while (stack.size() > 0) while (stack.size() > 0)
{ {
ngraph::Node* curr = stack.front(); ngraph::Node* curr = stack.top();
visited.insert(curr); visited.insert(curr);
if (curr->is_output()) if (curr->is_output())
{ {
return false; return false;
} }
stack.pop_front(); stack.pop();
if (curr != Y) if (curr != Y)
{ {
for (const auto& next : curr->get_users()) for (const auto& next : curr->get_users())
{ {
if (visited.count(next.get()) == 0) if (visited.count(next.get()) == 0)
{ {
stack.push_front(next.get()); stack.push(next.get());
} }
} }
} }
...@@ -503,12 +494,12 @@ NodeVector ngraph::extract_subgraph(const NodeVector& results, const NodeVector& ...@@ -503,12 +494,12 @@ NodeVector ngraph::extract_subgraph(const NodeVector& results, const NodeVector&
bool ngraph::is_used(Node* node) bool ngraph::is_used(Node* node)
{ {
std::unordered_set<Node*> instances_seen; std::unordered_set<Node*> instances_seen;
std::deque<Node*> stack; std::stack<Node*, std::vector<Node*>> stack;
stack.push_front(node); stack.push(node);
while (stack.size() > 0) while (stack.size() > 0)
{ {
ngraph::Node* n = stack.front(); ngraph::Node* n = stack.top();
if (instances_seen.count(n) == 0) if (instances_seen.count(n) == 0)
{ {
if (n->is_output()) if (n->is_output())
...@@ -517,12 +508,12 @@ bool ngraph::is_used(Node* node) ...@@ -517,12 +508,12 @@ bool ngraph::is_used(Node* node)
} }
instances_seen.insert(n); instances_seen.insert(n);
} }
stack.pop_front(); stack.pop();
for (const auto& arg : n->get_users()) for (const auto& arg : n->get_users())
{ {
if (instances_seen.count(arg.get()) == 0) if (instances_seen.count(arg.get()) == 0)
{ {
stack.push_front(arg.get()); stack.push(arg.get());
} }
} }
} }
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <functional> #include <functional>
#include <list> #include <list>
#include <memory> #include <memory>
#include <stack>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
...@@ -81,154 +82,131 @@ namespace ngraph ...@@ -81,154 +82,131 @@ namespace ngraph
NodeVector find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement); NodeVector find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
/// Topological sort of nodes needed to compute root_nodes
template <typename T> template <typename T>
std::list<std::shared_ptr<Node>> topological_sort(const T& nodes, std::list<std::shared_ptr<Node>> topological_sort(T root_nodes,
bool include_control_deps = false) bool include_control_deps = false)
{ {
std::deque<ngraph::Node*> independent_nodes; std::stack<Node*, std::vector<Node*>> nodes_to_do;
std::unordered_map<const ngraph::Node*, size_t> node_dependency_count; std::unordered_set<Node*> nodes_done;
std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>> node_map; std::list<std::shared_ptr<Node>> result;
std::unordered_map<ngraph::Node*, std::set<Node*>> control_deps_users;
for (auto node : nodes) for (auto& node : root_nodes)
{ {
//build an equivalent of node->get_users() but for control dependencies nodes_to_do.push(node.get());
size_t control_deps_count = 0;
if (include_control_deps)
{
for (auto cd : node->get_control_dependencies())
{
control_deps_count++;
control_deps_users[cd.get()].insert(node.get());
}
}
node_map[node.get()] = node;
size_t deps_count = node->get_input_size() + control_deps_count;
node_dependency_count[node.get()] = deps_count;
if (deps_count == 0)
{
independent_nodes.push_back(node.get());
}
} }
while (nodes_to_do.size() > 0)
std::list<std::shared_ptr<ngraph::Node>> result_list;
while (independent_nodes.size() > 0)
{ {
auto independent_node = independent_nodes.front(); Node* node = nodes_to_do.top();
result_list.push_back(node_map[independent_node]); if (nodes_done.count(node) == 0)
independent_nodes.pop_front();
for (const std::shared_ptr<Node>& user : independent_node->get_users())
{ {
if (--node_dependency_count[user.get()] == 0) bool can_add = true;
size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i)
{ {
independent_nodes.push_back(user.get()); Node* dep = node->input(arg_count - i - 1).get_source_output().get_node();
if (nodes_done.count(dep) == 0)
{
can_add = false;
nodes_to_do.push(dep);
}
} }
} if (include_control_deps)
{
if (include_control_deps) for (auto& depptr : node->get_control_dependencies())
{
auto cdit = control_deps_users.find(independent_node);
if (cdit != control_deps_users.end())
for (auto cd_user : cdit->second)
{ {
node_dependency_count[cd_user] -= 1; Node* dep = depptr.get();
size_t count = node_dependency_count[cd_user]; if (nodes_done.count(dep) == 0)
if (count == 0)
{ {
independent_nodes.push_back(cd_user); can_add = false;
nodes_to_do.push(dep);
} }
} }
}
if (can_add)
{
result.push_back(node->shared_from_this());
nodes_to_do.pop();
nodes_done.insert(node);
}
}
else
{
nodes_to_do.pop();
} }
} }
return result;
NGRAPH_CHECK(nodes.size() == result_list.size());
return result_list;
} }
// For cases, where `nodes` is a subset of the entire graph /// Topological sort of just nodes
template <typename T> template <typename T>
std::list<std::shared_ptr<Node>> subgraph_topological_sort(const T& nodes, std::list<std::shared_ptr<Node>> subgraph_topological_sort(T nodes,
bool include_control_deps = false) bool include_control_deps = false)
{ {
std::deque<ngraph::Node*> independent_nodes; std::stack<Node*, std::vector<Node*>> nodes_to_do;
std::unordered_map<const ngraph::Node*, size_t> node_dependency_count; std::unordered_set<Node*> nodes_done;
std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>> node_map; std::unordered_set<Node*> nodes_to_emit;
std::unordered_map<ngraph::Node*, std::set<Node*>> control_deps_users; std::list<std::shared_ptr<Node>> result;
std::unordered_set<std::shared_ptr<ngraph::Node>> nodes_set(nodes.begin(), nodes.end());
for (auto node : nodes) for (auto& node : nodes)
{
nodes_to_emit.insert(node.get());
nodes_to_do.push(node.get());
}
// NB: Some centos versions implement std::list::size() by counting elements
size_t nodes_remaining = nodes_to_emit.size();
while (nodes_to_do.size() > 0 && nodes_remaining > 0)
{ {
//build an equivalent of node->get_users() but for control dependencies Node* node = nodes_to_do.top();
size_t deps_count = 0; if (nodes_done.count(node) == 0)
if (include_control_deps)
{ {
for (auto cd : node->get_control_dependencies()) bool can_add = true;
size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i)
{ {
if (nodes_set.count(cd) != 0) Node* dep = node->input(arg_count - i - 1).get_source_output().get_node();
if (nodes_done.count(dep) == 0)
{ {
control_deps_users[cd.get()].insert(node.get()); can_add = false;
deps_count++; nodes_to_do.push(dep);
} }
} }
} if (include_control_deps)
node_map[node.get()] = node;
for (auto arg : node->get_arguments())
{
if (nodes_set.count(arg) != 0)
{ {
deps_count++; for (auto& depptr : node->get_control_dependencies())
{
Node* dep = depptr.get();
if (nodes_done.count(dep) == 0)
{
can_add = false;
nodes_to_do.push(dep);
}
}
} }
} if (can_add)
node_dependency_count[node.get()] = deps_count;
if (deps_count == 0)
{
independent_nodes.push_back(node.get());
}
}
std::list<std::shared_ptr<ngraph::Node>> result_list;
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front();
for (const std::shared_ptr<Node>& user : independent_node->get_users())
{
if (--node_dependency_count[user.get()] == 0)
{ {
independent_nodes.push_back(user.get()); if (nodes_to_emit.count(node) != 0)
{
result.push_back(node->shared_from_this());
nodes_remaining--;
}
nodes_to_do.pop();
nodes_done.insert(node);
} }
} }
if (include_control_deps) else
{ {
auto cdit = control_deps_users.find(independent_node); nodes_to_do.pop();
if (cdit != control_deps_users.end())
for (auto cd_user : cdit->second)
{
node_dependency_count[cd_user] -= 1;
size_t count = node_dependency_count[cd_user];
if (count == 0)
{
independent_nodes.push_back(cd_user);
}
}
} }
} }
return result;
NGRAPH_CHECK(nodes.size() == result_list.size());
return result_list;
} }
template <typename T> template <typename T>
void validate_nodes_and_infer_types(const T& nodes) void validate_nodes_and_infer_types(const T& nodes)
{ {
for (auto node : subgraph_topological_sort(nodes)) for (auto& node : subgraph_topological_sort(nodes))
{ {
node->revalidate_and_infer_types(); node->revalidate_and_infer_types();
} }
......
...@@ -329,14 +329,18 @@ NodeVector Node::get_arguments() const ...@@ -329,14 +329,18 @@ NodeVector Node::get_arguments() const
return result; return result;
} }
const std::set<std::shared_ptr<Node>>& Node::get_control_dependencies() const const std::vector<std::shared_ptr<Node>>& Node::get_control_dependencies() const
{ {
return m_control_dependencies; return m_control_dependencies;
} }
void Node::add_control_dependency(std::shared_ptr<Node> node) void Node::add_control_dependency(std::shared_ptr<Node> node)
{ {
m_control_dependencies.insert(node); if (find(m_control_dependencies.begin(), m_control_dependencies.end(), node) ==
m_control_dependencies.end())
{
m_control_dependencies.push_back(node);
}
} }
namespace ngraph namespace ngraph
...@@ -444,7 +448,7 @@ shared_ptr<descriptor::Tensor> Node::get_output_tensor_ptr() const ...@@ -444,7 +448,7 @@ shared_ptr<descriptor::Tensor> Node::get_output_tensor_ptr() const
return m_outputs.at(0).get_tensor_ptr(); return m_outputs.at(0).get_tensor_ptr();
} }
const std::set<descriptor::Input*>& Node::get_output_inputs(size_t i) const const std::vector<descriptor::Input*>& Node::get_output_inputs(size_t i) const
{ {
return m_outputs.at(i).get_inputs(); return m_outputs.at(i).get_inputs();
} }
......
...@@ -236,13 +236,17 @@ namespace ngraph ...@@ -236,13 +236,17 @@ namespace ngraph
NGRAPH_DEPRECATED("use outputs() instead"); NGRAPH_DEPRECATED("use outputs() instead");
/// Get control dependencies registered on the node /// Get control dependencies registered on the node
const std::set<std::shared_ptr<Node>>& get_control_dependencies() const; const std::vector<std::shared_ptr<Node>>& get_control_dependencies() const;
void add_control_dependency(std::shared_ptr<Node> node); void add_control_dependency(std::shared_ptr<Node> node);
void remove_control_dependency(std::shared_ptr<Node> node) void remove_control_dependency(std::shared_ptr<Node> node)
{ {
m_control_dependencies.erase(node); auto it = find(m_control_dependencies.begin(), m_control_dependencies.end(), node);
if (it != m_control_dependencies.end())
{
m_control_dependencies.erase(it);
}
} }
/// Returns the number of outputs from the node. /// Returns the number of outputs from the node.
...@@ -295,7 +299,7 @@ namespace ngraph ...@@ -295,7 +299,7 @@ namespace ngraph
"output, or update calling code not to assume only one output"); "output, or update calling code not to assume only one output");
/// Returns the set of inputs using output i /// Returns the set of inputs using output i
const std::set<descriptor::Input*>& get_output_inputs(size_t i) const const std::vector<descriptor::Input*>& get_output_inputs(size_t i) const
NGRAPH_DEPRECATED("use node->output(i).get_target_inputs() instead"); NGRAPH_DEPRECATED("use node->output(i).get_target_inputs() instead");
/// Returns the number of inputs for the op /// Returns the number of inputs for the op
...@@ -390,7 +394,7 @@ namespace ngraph ...@@ -390,7 +394,7 @@ namespace ngraph
descriptor::Input& get_input_descriptor(size_t position); descriptor::Input& get_input_descriptor(size_t position);
descriptor::Output& get_output_descriptor(size_t position); descriptor::Output& get_output_descriptor(size_t position);
std::set<std::shared_ptr<Node>> m_control_dependencies; std::vector<std::shared_ptr<Node>> m_control_dependencies;
const std::string m_node_type; const std::string m_node_type;
size_t m_instance_id{m_next_instance_id.fetch_add(1)}; size_t m_instance_id{m_next_instance_id.fetch_add(1)};
std::string m_friendly_name; std::string m_friendly_name;
......
...@@ -87,8 +87,7 @@ TEST(control_dependencies, cdep_ops) ...@@ -87,8 +87,7 @@ TEST(control_dependencies, cdep_ops)
make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn}); make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn});
auto f = make_shared<Function>(cdop, ParameterVector{A, B}); auto f = make_shared<Function>(cdop, ParameterVector{A, B});
auto nodes = f->get_ordered_ops(true); test_ordered_ops(f, NodeVector{absn});
ASSERT_EQ(nodes.back()->get_argument(0), cdop);
} }
TEST(control_dependencies, two_cdep_ops) TEST(control_dependencies, two_cdep_ops)
...@@ -102,8 +101,7 @@ TEST(control_dependencies, two_cdep_ops) ...@@ -102,8 +101,7 @@ TEST(control_dependencies, two_cdep_ops)
std::set<std::shared_ptr<Node>>{absn, absn_c}); std::set<std::shared_ptr<Node>>{absn, absn_c});
auto f = make_shared<Function>(cdop, ParameterVector{A, B, C}); auto f = make_shared<Function>(cdop, ParameterVector{A, B, C});
auto nodes = f->get_ordered_ops(true); test_ordered_ops(f, NodeVector{absn, absn_c});
ASSERT_EQ(nodes.back()->get_argument(0), cdop);
} }
TEST(control_dependencies, two_cdep_ops_op_on_top) TEST(control_dependencies, two_cdep_ops_op_on_top)
...@@ -117,8 +115,7 @@ TEST(control_dependencies, two_cdep_ops_op_on_top) ...@@ -117,8 +115,7 @@ TEST(control_dependencies, two_cdep_ops_op_on_top)
auto absn_cdop = make_shared<op::Abs>(cdop); auto absn_cdop = make_shared<op::Abs>(cdop);
auto f = make_shared<Function>(absn_cdop, ParameterVector{A, B}); auto f = make_shared<Function>(absn_cdop, ParameterVector{A, B});
auto nodes = f->get_ordered_ops(true); test_ordered_ops(f, NodeVector{absn, absn_b});
ASSERT_EQ(nodes.back()->get_argument(0), absn_cdop);
} }
TEST(control_dependencies, clone_function_cdop) TEST(control_dependencies, clone_function_cdop)
...@@ -129,6 +126,7 @@ TEST(control_dependencies, clone_function_cdop) ...@@ -129,6 +126,7 @@ TEST(control_dependencies, clone_function_cdop)
make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn}); make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn});
auto f = make_shared<Function>(cdop, ParameterVector{A}); auto f = make_shared<Function>(cdop, ParameterVector{A});
test_ordered_ops(f, NodeVector{absn});
auto clone = ngraph::clone_function(*f.get()); auto clone = ngraph::clone_function(*f.get());
auto matcher = std::make_shared<pattern::Matcher>(cdop); auto matcher = std::make_shared<pattern::Matcher>(cdop);
auto cdop_clone = clone->get_results().at(0)->get_argument(0); auto cdop_clone = clone->get_results().at(0)->get_argument(0);
......
...@@ -540,7 +540,6 @@ TEST(util, enum_mask_operators) ...@@ -540,7 +540,6 @@ TEST(util, enum_mask_operators)
TEST(graph, huge) TEST(graph, huge)
{ {
Function* f;
std::vector<std::weak_ptr<Node>> weak_nodes; std::vector<std::weak_ptr<Node>> weak_nodes;
{ {
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 3}); auto param = make_shared<op::Parameter>(element::f32, Shape{3, 3});
...@@ -549,16 +548,12 @@ TEST(graph, huge) ...@@ -549,16 +548,12 @@ TEST(graph, huge)
{ {
n = make_shared<op::Negative>(n); n = make_shared<op::Negative>(n);
} }
f = new Function(NodeVector{n}, ParameterVector{param}); auto f = make_shared<Function>(NodeVector{n}, ParameterVector{param});
for (auto node : f->get_ops()) f->map_unordered_ops(
{ [&weak_nodes](Node* node) { weak_nodes.push_back(node->shared_from_this()); });
weak_nodes.push_back(node);
}
} }
delete f; for (auto& weak_node : weak_nodes)
for (auto weak_node : weak_nodes)
{ {
EXPECT_TRUE(weak_node.expired()); EXPECT_TRUE(weak_node.expired());
} }
......
...@@ -313,3 +313,44 @@ std::shared_ptr<Function> make_function_from_file(const std::string& file_name) ...@@ -313,3 +313,44 @@ std::shared_ptr<Function> make_function_from_file(const std::string& file_name)
return func; return func;
} }
#endif #endif
::testing::AssertionResult test_ordered_ops(shared_ptr<Function> f, const NodeVector& required_ops)
{
unordered_set<Node*> seen;
for (auto& node_ptr : f->get_ordered_ops())
{
Node* node = node_ptr.get();
if (seen.count(node) > 0)
{
return ::testing::AssertionFailure() << "Duplication in ordered ops";
}
size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i)
{
Node* dep = node->input(i).get_source_output().get_node();
if (seen.count(dep) == 0)
{
return ::testing::AssertionFailure() << "Argument " << *dep
<< " does not occur before op" << *node;
}
}
for (auto& dep_ptr : node->get_control_dependencies())
{
if (seen.count(dep_ptr.get()) == 0)
{
return ::testing::AssertionFailure() << "Control dependency " << *dep_ptr
<< " does not occur before op" << *node;
}
}
seen.insert(node);
}
for (auto& node_ptr : required_ops)
{
if (seen.count(node_ptr.get()) == 0)
{
return ::testing::AssertionFailure() << "Required op " << *node_ptr
<< "does not occur in ordered ops";
}
}
return ::testing::AssertionSuccess();
}
...@@ -25,6 +25,8 @@ ...@@ -25,6 +25,8 @@
#include <random> #include <random>
#include <vector> #include <vector>
#include "gtest/gtest.h"
#include "ngraph/descriptor/layout/tensor_layout.hpp" #include "ngraph/descriptor/layout/tensor_layout.hpp"
#include "ngraph/file_util.hpp" #include "ngraph/file_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
...@@ -277,3 +279,6 @@ std::vector<T> read_binary_file(const std::string& path) ...@@ -277,3 +279,6 @@ std::vector<T> read_binary_file(const std::string& path)
inputs_fs.read(reinterpret_cast<char*>(file_content.data()), size); inputs_fs.read(reinterpret_cast<char*>(file_content.data()), size);
return file_content; return file_content;
} }
testing::AssertionResult test_ordered_ops(std::shared_ptr<ngraph::Function> f,
const ngraph::NodeVector& required_ops);
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment