Unverified Commit 0096def7 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Convert Function::get_ops/get_ordered_ops to return vector instead of list (#4207)

* Make passes work with list<nodes>

* Convert ops returned by Function from list to vector

* fix build error
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent eb2445a6
...@@ -96,7 +96,7 @@ void Function::init() ...@@ -96,7 +96,7 @@ void Function::init()
true /*include control dependencies*/); true /*include control dependencies*/);
} }
std::list<shared_ptr<Node>> Function::get_ordered_ops(bool include_control_deps) const std::vector<shared_ptr<Node>> Function::get_ordered_ops(bool include_control_deps) const
{ {
NodeVector nodes; NodeVector nodes;
for (auto& r : get_results()) for (auto& r : get_results())
...@@ -223,9 +223,9 @@ shared_ptr<Node> Function::get_result() const ...@@ -223,9 +223,9 @@ shared_ptr<Node> Function::get_result() const
return m_results.at(0); return m_results.at(0);
} }
std::list<shared_ptr<Node>> Function::get_ops(bool include_control_deps) const std::vector<shared_ptr<Node>> Function::get_ops(bool include_control_deps) const
{ {
std::list<std::shared_ptr<Node>> ops; std::vector<std::shared_ptr<Node>> ops;
traverse_nodes(this, [&](shared_ptr<Node> node) { ops.push_back(node); }, include_control_deps); traverse_nodes(this, [&](shared_ptr<Node> node) { ops.push_back(node); }, include_control_deps);
return ops; return ops;
} }
......
...@@ -92,8 +92,8 @@ namespace ngraph ...@@ -92,8 +92,8 @@ namespace ngraph
/// \returns A const reference to the function's friendly name. /// \returns A const reference to the function's friendly name.
const std::string& get_friendly_name() const; const std::string& get_friendly_name() const;
std::list<std::shared_ptr<Node>> get_ops(bool include_control_deps = true) const; std::vector<std::shared_ptr<Node>> get_ops(bool include_control_deps = true) const;
std::list<std::shared_ptr<Node>> get_ordered_ops(bool include_control_deps = true) const; std::vector<std::shared_ptr<Node>> get_ordered_ops(bool include_control_deps = true) const;
void map_unordered_ops(std::function<void(Node*)> f) const; void map_unordered_ops(std::function<void(Node*)> f) const;
friend std::ostream& operator<<(std::ostream&, const Function&); friend std::ostream& operator<<(std::ostream&, const Function&);
......
...@@ -291,8 +291,8 @@ bool ngraph::is_post_dominated(Node* X, Node* Y) ...@@ -291,8 +291,8 @@ bool ngraph::is_post_dominated(Node* X, Node* Y)
return true; return true;
} }
std::list<std::shared_ptr<ngraph::Node>> std::vector<std::shared_ptr<ngraph::Node>>
ngraph::clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map) ngraph::clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map)
{ {
// for each node in topological order // for each node in topological order
auto sorted_nodes = topological_sort(nodes, true); auto sorted_nodes = topological_sort(nodes, true);
...@@ -334,9 +334,9 @@ std::list<std::shared_ptr<ngraph::Node>> ...@@ -334,9 +334,9 @@ std::list<std::shared_ptr<ngraph::Node>>
} }
} }
// create and return list of cloned nodes // create and return vector of cloned nodes
// order matches input list (not necessarily topological) // order matches input vector (not necessarily topological)
std::list<std::shared_ptr<ngraph::Node>> cloned_nodes; std::vector<std::shared_ptr<ngraph::Node>> cloned_nodes;
for (auto node : nodes) for (auto node : nodes)
{ {
cloned_nodes.push_back(node_map.at(node.get())); cloned_nodes.push_back(node_map.at(node.get()));
...@@ -345,7 +345,7 @@ std::list<std::shared_ptr<ngraph::Node>> ...@@ -345,7 +345,7 @@ std::list<std::shared_ptr<ngraph::Node>>
} }
std::list<std::shared_ptr<ngraph::Node>> std::list<std::shared_ptr<ngraph::Node>>
ngraph::clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, ngraph::clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes,
RawNodeOutputMap& output_map) RawNodeOutputMap& output_map)
{ {
// for each node in topological order // for each node in topological order
......
...@@ -258,12 +258,12 @@ namespace ngraph ...@@ -258,12 +258,12 @@ namespace ngraph
/// Topological sort of nodes needed to compute root_nodes /// Topological sort of nodes needed to compute root_nodes
template <typename T> template <typename T>
std::list<std::shared_ptr<Node>> topological_sort(T root_nodes, std::vector<std::shared_ptr<Node>> topological_sort(T root_nodes,
bool include_control_deps = false) bool include_control_deps = false)
{ {
std::stack<Node*, std::vector<Node*>> nodes_to_do; std::stack<Node*, std::vector<Node*>> nodes_to_do;
std::unordered_set<Node*> nodes_done; std::unordered_set<Node*> nodes_done;
std::list<std::shared_ptr<Node>> result; std::vector<std::shared_ptr<Node>> result;
for (auto& node : root_nodes) for (auto& node : root_nodes)
{ {
...@@ -314,13 +314,13 @@ namespace ngraph ...@@ -314,13 +314,13 @@ namespace ngraph
/// Topological sort of just nodes /// Topological sort of just nodes
template <typename T> template <typename T>
std::list<std::shared_ptr<Node>> subgraph_topological_sort(T nodes, std::vector<std::shared_ptr<Node>> subgraph_topological_sort(T nodes,
bool include_control_deps = false) bool include_control_deps = false)
{ {
std::stack<Node*, std::vector<Node*>> nodes_to_do; std::stack<Node*, std::vector<Node*>> nodes_to_do;
std::unordered_set<Node*> nodes_done; std::unordered_set<Node*> nodes_done;
std::unordered_set<Node*> nodes_to_emit; std::unordered_set<Node*> nodes_to_emit;
std::list<std::shared_ptr<Node>> result; std::vector<std::shared_ptr<Node>> result;
for (auto& node : nodes) for (auto& node : nodes)
{ {
...@@ -394,14 +394,14 @@ namespace ngraph ...@@ -394,14 +394,14 @@ namespace ngraph
// input nodes are cloned and returned // input nodes are cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes // NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned nodes // NodeMap output (by reference) fully maps input and cloned nodes
std::list<std::shared_ptr<ngraph::Node>> std::vector<std::shared_ptr<ngraph::Node>>
clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map); clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map);
// input nodes are cloned and returned // input nodes are cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes // NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned nodes // NodeMap output (by reference) fully maps input and cloned nodes
std::list<std::shared_ptr<ngraph::Node>> std::list<std::shared_ptr<ngraph::Node>>
clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes,
RawNodeOutputMap& node_map); RawNodeOutputMap& node_map);
// input function is cloned and returned // input function is cloned and returned
......
...@@ -35,7 +35,7 @@ using namespace ngraph; ...@@ -35,7 +35,7 @@ using namespace ngraph;
bool pass::Liveness::run_on_function(shared_ptr<Function> function) bool pass::Liveness::run_on_function(shared_ptr<Function> function)
{ {
list<shared_ptr<Node>> ops = function->get_ordered_ops(); auto ops = function->get_ordered_ops();
unordered_set<descriptor::Tensor*> persistent_tensors; unordered_set<descriptor::Tensor*> persistent_tensors;
unordered_set<descriptor::Tensor*> output_tensors; unordered_set<descriptor::Tensor*> output_tensors;
......
...@@ -40,7 +40,7 @@ bool pass::MemoryVisualize::run_on_module(vector<shared_ptr<Function>>& function ...@@ -40,7 +40,7 @@ bool pass::MemoryVisualize::run_on_module(vector<shared_ptr<Function>>& function
{ {
for (shared_ptr<Function> f : functions) for (shared_ptr<Function> f : functions)
{ {
list<shared_ptr<Node>> nodes = f->get_ordered_ops(); vector<shared_ptr<Node>> nodes = f->get_ordered_ops();
file << "<!DOCTYPE html>\n<html>\n"; file << "<!DOCTYPE html>\n<html>\n";
file << "<head>\n"; file << "<head>\n";
file << " <style>\n"; file << " <style>\n";
...@@ -96,7 +96,7 @@ bool pass::MemoryVisualize::run_on_module(vector<shared_ptr<Function>>& function ...@@ -96,7 +96,7 @@ bool pass::MemoryVisualize::run_on_module(vector<shared_ptr<Function>>& function
} }
unordered_set<const descriptor::Tensor*> unordered_set<const descriptor::Tensor*>
pass::MemoryVisualize::find_largest_op(const list<shared_ptr<Node>>& nodes) pass::MemoryVisualize::find_largest_op(const vector<shared_ptr<Node>>& nodes)
{ {
size_t largest_size = 0; size_t largest_size = 0;
unordered_set<const descriptor::Tensor*> liveness_list; unordered_set<const descriptor::Tensor*> liveness_list;
...@@ -122,7 +122,7 @@ unordered_set<const descriptor::Tensor*> ...@@ -122,7 +122,7 @@ unordered_set<const descriptor::Tensor*>
return largest_live_list; return largest_live_list;
} }
void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<shared_ptr<Node>>& nodes) void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const vector<shared_ptr<Node>>& nodes)
{ {
unordered_set<const descriptor::Tensor*> largest_live_list = find_largest_op(nodes); unordered_set<const descriptor::Tensor*> largest_live_list = find_largest_op(nodes);
...@@ -178,7 +178,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<shared_ ...@@ -178,7 +178,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<shared_
file << "</table>\n"; file << "</table>\n";
} }
void pass::MemoryVisualize::draw_histogram(ostream& file, const list<shared_ptr<Node>>& nodes) void pass::MemoryVisualize::draw_histogram(ostream& file, const vector<shared_ptr<Node>>& nodes)
{ {
size_t stroke_width = 14; size_t stroke_width = 14;
size_t text_offset = 4; size_t text_offset = 4;
...@@ -219,7 +219,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<shared_ptr< ...@@ -219,7 +219,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<shared_ptr<
file << "</svg>\n"; file << "</svg>\n";
} }
void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<shared_ptr<Node>>& nodes) void pass::MemoryVisualize::draw_op_influence(ostream& file, const vector<shared_ptr<Node>>& nodes)
{ {
file << "<table>\n"; file << "<table>\n";
file << " <tr>"; file << " <tr>";
...@@ -260,7 +260,7 @@ size_t pass::MemoryVisualize::memory_footprint(shared_ptr<Node> /* node */) ...@@ -260,7 +260,7 @@ size_t pass::MemoryVisualize::memory_footprint(shared_ptr<Node> /* node */)
return 0; return 0;
} }
size_t pass::MemoryVisualize::memory_footprint(const std::list<shared_ptr<Node>>& /* nodes */) size_t pass::MemoryVisualize::memory_footprint(const std::vector<shared_ptr<Node>>& /* nodes */)
{ {
return 0; return 0;
} }
...@@ -38,15 +38,15 @@ public: ...@@ -38,15 +38,15 @@ public:
private: private:
std::unordered_set<const descriptor::Tensor*> std::unordered_set<const descriptor::Tensor*>
find_largest_op(const std::list<std::shared_ptr<Node>>& nodes); find_largest_op(const std::vector<std::shared_ptr<Node>>& nodes);
void draw_tensor_weight(std::ostream& file, const std::list<std::shared_ptr<Node>>& nodes); void draw_tensor_weight(std::ostream& file, const std::vector<std::shared_ptr<Node>>& nodes);
void draw_histogram(std::ostream& file, const std::list<std::shared_ptr<Node>>& nodes); void draw_histogram(std::ostream& file, const std::vector<std::shared_ptr<Node>>& nodes);
void draw_op_influence(std::ostream& file, const std::list<std::shared_ptr<Node>>& nodes); void draw_op_influence(std::ostream& file, const std::vector<std::shared_ptr<Node>>& nodes);
int compute_op_weight(std::shared_ptr<Node> exop); int compute_op_weight(std::shared_ptr<Node> exop);
static size_t memory_usage(std::shared_ptr<Node>); static size_t memory_usage(std::shared_ptr<Node>);
static size_t memory_footprint(std::shared_ptr<Node>); static size_t memory_footprint(std::shared_ptr<Node>);
static size_t memory_footprint(const std::list<std::shared_ptr<Node>>&); static size_t memory_footprint(const std::vector<std::shared_ptr<Node>>&);
const std::string m_filename; const std::string m_filename;
}; };
...@@ -69,3 +69,13 @@ pass::NodePass::~NodePass() ...@@ -69,3 +69,13 @@ pass::NodePass::~NodePass()
pass::CallGraphPass::~CallGraphPass() pass::CallGraphPass::~CallGraphPass()
{ {
} }
bool pass::CallGraphPass::run_on_call_graph(const std::vector<std::shared_ptr<ngraph::Node>>& nodes)
{
list<shared_ptr<Node>> node_list;
for (auto op : nodes)
{
node_list.push_back(op);
}
return run_on_call_graph(node_list);
}
...@@ -115,4 +115,5 @@ class NGRAPH_API ngraph::pass::CallGraphPass : public PassBase ...@@ -115,4 +115,5 @@ class NGRAPH_API ngraph::pass::CallGraphPass : public PassBase
public: public:
virtual ~CallGraphPass(); virtual ~CallGraphPass();
virtual bool run_on_call_graph(const std::list<std::shared_ptr<ngraph::Node>>&) = 0; virtual bool run_on_call_graph(const std::list<std::shared_ptr<ngraph::Node>>&) = 0;
virtual bool run_on_call_graph(const std::vector<std::shared_ptr<ngraph::Node>>&);
}; };
...@@ -202,10 +202,10 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions) ...@@ -202,10 +202,10 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions)
} }
auto nodes = topological_sort(f->get_ops()); auto nodes = topological_sort(f->get_ops());
nodes.reverse();
for (auto& node : nodes) for (auto it = nodes.rbegin(); it != nodes.rend(); ++it)
{ {
auto& node = *it;
for (auto& output : node->outputs()) for (auto& output : node->outputs())
{ {
for (auto& input : output.get_target_inputs()) for (auto& input : output.get_target_inputs())
......
...@@ -511,7 +511,7 @@ void runtime::cpu::CPU_ExternalFunction::compile(ngraph::pass::PassConfig& pass_ ...@@ -511,7 +511,7 @@ void runtime::cpu::CPU_ExternalFunction::compile(ngraph::pass::PassConfig& pass_
femitter, node_function_map, common_function_string); femitter, node_function_map, common_function_string);
pass_manager.run_passes(m_function); pass_manager.run_passes(m_function);
list<shared_ptr<Node>> ordered_ops = m_function->get_ordered_ops(); auto ordered_ops = m_function->get_ordered_ops();
CodeWriter writer; CodeWriter writer;
......
...@@ -56,7 +56,7 @@ size_t runtime::cpu::pass::CPUMemoryAssignment::get_bufferID(descriptor::Tensor* ...@@ -56,7 +56,7 @@ size_t runtime::cpu::pass::CPUMemoryAssignment::get_bufferID(descriptor::Tensor*
} }
void runtime::cpu::pass::CPUMemoryAssignment::process_in_place_concat( void runtime::cpu::pass::CPUMemoryAssignment::process_in_place_concat(
std::list<std::shared_ptr<Node>> nodes) std::vector<std::shared_ptr<Node>> nodes)
{ {
for (shared_ptr<Node> node : nodes) for (shared_ptr<Node> node : nodes)
{ {
...@@ -215,7 +215,7 @@ void runtime::cpu::pass::CPUMemoryAssignment::propagate_in_place_concat(const Ou ...@@ -215,7 +215,7 @@ void runtime::cpu::pass::CPUMemoryAssignment::propagate_in_place_concat(const Ou
// slice // slice
void runtime::cpu::pass::CPUMemoryAssignment::process_in_place_slice( void runtime::cpu::pass::CPUMemoryAssignment::process_in_place_slice(
std::list<std::shared_ptr<Node>> nodes) std::vector<std::shared_ptr<Node>> nodes)
{ {
for (shared_ptr<Node>& node : nodes) for (shared_ptr<Node>& node : nodes)
{ {
...@@ -328,7 +328,7 @@ void runtime::cpu::pass::CPUMemoryAssignment::propagate_in_place_slice(const Inp ...@@ -328,7 +328,7 @@ void runtime::cpu::pass::CPUMemoryAssignment::propagate_in_place_slice(const Inp
// new set is created. bufferID_to_tensorSets maps bufferID to the pair of TensorRole and buffer // new set is created. bufferID_to_tensorSets maps bufferID to the pair of TensorRole and buffer
// set. TensorRole is INPUT, CONSTANT, OUTPUT, or INTERMEDIATE, which tells from where the memory // set. TensorRole is INPUT, CONSTANT, OUTPUT, or INTERMEDIATE, which tells from where the memory
// buffer comes. tensor_to_bufferID maps tensor to the ID of the buffer set it belongs to. // buffer comes. tensor_to_bufferID maps tensor to the ID of the buffer set it belongs to.
void runtime::cpu::pass::CPUMemoryAssignment::build_buffer_sets_maps(list<shared_ptr<Node>>& ops) void runtime::cpu::pass::CPUMemoryAssignment::build_buffer_sets_maps(vector<shared_ptr<Node>>& ops)
{ {
unordered_set<descriptor::Tensor*> in_place_slice_chain; unordered_set<descriptor::Tensor*> in_place_slice_chain;
size_t count = 0; size_t count = 0;
...@@ -545,7 +545,7 @@ void runtime::cpu::pass::CPUMemoryAssignment::build_buffer_sets_maps(list<shared ...@@ -545,7 +545,7 @@ void runtime::cpu::pass::CPUMemoryAssignment::build_buffer_sets_maps(list<shared
} }
void runtime::cpu::pass::CPUMemoryAssignment::liveness_analysis( void runtime::cpu::pass::CPUMemoryAssignment::liveness_analysis(
std::list<std::shared_ptr<Node>>& ops) std::vector<std::shared_ptr<Node>>& ops)
{ {
auto find_role = [](TensorRole tensor_role) -> string { auto find_role = [](TensorRole tensor_role) -> string {
switch (tensor_role) switch (tensor_role)
...@@ -620,7 +620,7 @@ void runtime::cpu::pass::CPUMemoryAssignment::liveness_analysis( ...@@ -620,7 +620,7 @@ void runtime::cpu::pass::CPUMemoryAssignment::liveness_analysis(
bool runtime::cpu::pass::CPUMemoryAssignment::run_on_function(shared_ptr<ngraph::Function> function) bool runtime::cpu::pass::CPUMemoryAssignment::run_on_function(shared_ptr<ngraph::Function> function)
{ {
list<shared_ptr<Node>> ops = function->get_ordered_ops(); auto ops = function->get_ordered_ops();
build_buffer_sets_maps(ops); build_buffer_sets_maps(ops);
liveness_analysis(ops); liveness_analysis(ops);
......
...@@ -50,22 +50,22 @@ public: ...@@ -50,22 +50,22 @@ public:
private: private:
// Find in-place concat ops and set appropriate memory pool offset for its arguments // Find in-place concat ops and set appropriate memory pool offset for its arguments
void process_in_place_concat(std::list<std::shared_ptr<Node>> nodes); void process_in_place_concat(std::vector<std::shared_ptr<Node>> nodes);
// For a chain of concat ops, propagate memory pool offsets // For a chain of concat ops, propagate memory pool offsets
void propagate_in_place_concat(const ngraph::Output<ngraph::Node>& concat); void propagate_in_place_concat(const ngraph::Output<ngraph::Node>& concat);
// Find in-place slice ops and set appropriate memory pool offset for its output // Find in-place slice ops and set appropriate memory pool offset for its output
void process_in_place_slice(std::list<std::shared_ptr<Node>> nodes); void process_in_place_slice(std::vector<std::shared_ptr<Node>> nodes);
// propagate slice when its arg comes from function input // propagate slice when its arg comes from function input
void propagate_in_place_slice(const ngraph::Input<ngraph::Node>& input); void propagate_in_place_slice(const ngraph::Input<ngraph::Node>& input);
// build buffer sets maps // build buffer sets maps
void build_buffer_sets_maps(std::list<std::shared_ptr<Node>>& ops); void build_buffer_sets_maps(std::vector<std::shared_ptr<Node>>& ops);
// liveness analysis to build new and free list for each node // liveness analysis to build new and free list for each node
void liveness_analysis(std::list<std::shared_ptr<Node>>& ops); void liveness_analysis(std::vector<std::shared_ptr<Node>>& ops);
size_t get_bufferID(descriptor::Tensor* tensor); size_t get_bufferID(descriptor::Tensor* tensor);
......
...@@ -192,7 +192,7 @@ public: ...@@ -192,7 +192,7 @@ public:
std::shared_ptr<Node> AplusBtimesC = AplusB * C; std::shared_ptr<Node> AplusBtimesC = AplusB * C;
NodeMap node_map; NodeMap node_map;
std::list<std::shared_ptr<ngraph::Node>> nodes; std::vector<std::shared_ptr<ngraph::Node>> nodes;
std::shared_ptr<Function> func = std::shared_ptr<Function> func =
make_shared<Function>(AplusBtimesC, ParameterVector{A, B, C}, "f"); make_shared<Function>(AplusBtimesC, ParameterVector{A, B, C}, "f");
...@@ -205,8 +205,8 @@ public: ...@@ -205,8 +205,8 @@ public:
nodes.push_back(C); nodes.push_back(C);
} }
bool CompareNodeVector(const std::list<std::shared_ptr<ngraph::Node>>& orig, bool CompareNodeVector(const std::vector<std::shared_ptr<ngraph::Node>>& orig,
const std::list<std::shared_ptr<ngraph::Node>>& clone, const std::vector<std::shared_ptr<ngraph::Node>>& clone,
const NodeMap& nm) const NodeMap& nm)
{ {
if (orig.size() != clone.size()) if (orig.size() != clone.size())
...@@ -373,7 +373,7 @@ TEST(graph_util, test_subgraph_topological_sort) ...@@ -373,7 +373,7 @@ TEST(graph_util, test_subgraph_topological_sort)
auto mul = C * add; auto mul = C * add;
auto result = make_shared<op::Result>(mul); auto result = make_shared<op::Result>(mul);
auto sorted = ngraph::subgraph_topological_sort(NodeVector{mul, add, A}); auto sorted = ngraph::subgraph_topological_sort(NodeVector{mul, add, A});
std::list<std::shared_ptr<Node>> expected{A, add, mul}; std::vector<std::shared_ptr<Node>> expected{A, add, mul};
ASSERT_EQ(expected, sorted); ASSERT_EQ(expected, sorted);
} }
...@@ -391,7 +391,7 @@ TEST(graph_util, test_subgraph_topological_sort_control_dependencies) ...@@ -391,7 +391,7 @@ TEST(graph_util, test_subgraph_topological_sort_control_dependencies)
auto mul = C * add; auto mul = C * add;
auto result = make_shared<op::Result>(mul); auto result = make_shared<op::Result>(mul);
auto sorted = ngraph::subgraph_topological_sort(NodeVector{mul, add, A, D}, true); auto sorted = ngraph::subgraph_topological_sort(NodeVector{mul, add, A, D}, true);
std::list<std::shared_ptr<Node>> expected{A, D, add, mul}; std::vector<std::shared_ptr<Node>> expected{A, D, add, mul};
ASSERT_EQ(expected, sorted); ASSERT_EQ(expected, sorted);
} }
......
...@@ -23,9 +23,9 @@ ...@@ -23,9 +23,9 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
// This function traverses the list of ops and verifies that each op's dependencies (its inputs) // This function traverses the vector of ops and verifies that each op's dependencies (its inputs)
// is located earlier in the list. That is enough to be valid // is located earlier in the vector. That is enough to be valid
bool validate_list(const list<shared_ptr<Node>>& nodes) bool validate_list(const vector<shared_ptr<Node>>& nodes)
{ {
bool rc = true; bool rc = true;
for (auto it = nodes.rbegin(); it != nodes.rend(); it++) for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
......
...@@ -45,7 +45,7 @@ namespace ngraph ...@@ -45,7 +45,7 @@ namespace ngraph
class Function; class Function;
} }
bool validate_list(const std::list<std::shared_ptr<ngraph::Node>>& nodes); bool validate_list(const std::vector<std::shared_ptr<ngraph::Node>>& nodes);
std::shared_ptr<ngraph::Function> make_test_graph(); std::shared_ptr<ngraph::Function> make_test_graph();
#ifndef NGRAPH_JSON_DISABLE #ifndef NGRAPH_JSON_DISABLE
std::shared_ptr<ngraph::Function> make_function_from_file(const std::string& file_name); std::shared_ptr<ngraph::Function> make_function_from_file(const std::string& file_name);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment