Commit 4aff3ec0 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #193 from NervanaSystems/bob/pass_ptr

change all passes to take shared_ptr rather than naked pointer
parents 96ae6fdb 14d74e83
......@@ -39,26 +39,26 @@ Function::Function(const std::shared_ptr<Node>& result,
parameter->assign_function(this, i++);
}
traverse_nodes(result, [&](Node* node) { m_ops.push_back(node); });
traverse_nodes(result, [&](shared_ptr<Node> node) { m_ops.push_back(node); });
}
void Function::set_ordered_ops(const std::list<Node*>& ordered_ops)
void Function::set_ordered_ops(const std::list<shared_ptr<Node>>& ordered_ops)
{
m_ordered_ops = ordered_ops;
m_ordered_ops_valid = true;
}
std::list<Node*>& Function::get_ops()
std::list<shared_ptr<Node>>& Function::get_ops()
{
return m_ops;
}
const std::list<Node*>& Function::get_ops() const
const std::list<shared_ptr<Node>>& Function::get_ops() const
{
return m_ops;
}
std::list<Node*>& Function::get_ordered_ops()
std::list<shared_ptr<Node>>& Function::get_ordered_ops()
{
if (!m_ordered_ops_valid)
{
......@@ -67,7 +67,7 @@ std::list<Node*>& Function::get_ordered_ops()
return m_ordered_ops;
}
const std::list<Node*>& Function::get_ordered_ops() const
const std::list<shared_ptr<Node>>& Function::get_ordered_ops() const
{
if (!m_ordered_ops_valid)
{
......
......@@ -47,11 +47,11 @@ namespace ngraph
const std::shared_ptr<ValueType> get_result_type() const { return m_result_type; }
std::string get_name() const;
void set_name(const std::string& name);
std::list<Node*>& get_ops();
const std::list<Node*>& get_ops() const;
std::list<Node*>& get_ordered_ops();
const std::list<Node*>& get_ordered_ops() const;
void set_ordered_ops(const std::list<Node*>&);
std::list<std::shared_ptr<Node>>& get_ops();
const std::list<std::shared_ptr<Node>>& get_ops() const;
std::list<std::shared_ptr<Node>>& get_ordered_ops();
const std::list<std::shared_ptr<Node>>& get_ordered_ops() const;
void set_ordered_ops(const std::list<std::shared_ptr<Node>>&);
void set_ordered_ops_valid() { m_ordered_ops_valid = true; }
void clear_ordered_ops_valid() { m_ordered_ops_valid = false; }
friend std::ostream& operator<<(std::ostream&, const Function&);
......@@ -62,8 +62,8 @@ namespace ngraph
std::string m_name;
std::shared_ptr<ValueType> m_result_type;
bool m_ordered_ops_valid;
std::list<Node*> m_ordered_ops;
std::list<Node*> m_ops;
std::list<std::shared_ptr<Node>> m_ordered_ops;
std::list<std::shared_ptr<Node>> m_ops;
private:
Function(const Function&) = delete;
......
......@@ -28,7 +28,7 @@ namespace ngraph
/// @param function The function to be called
/// @param args The function arguments
///
FunctionCall(const std::shared_ptr<Function>& function,
FunctionCall(std::shared_ptr<Function> function,
const std::vector<std::shared_ptr<Node>>& args)
: Builtin(args)
, m_function(function)
......
......@@ -25,15 +25,15 @@
using namespace std;
using namespace ngraph;
bool pass::AssignTensors::run_on_call_graph(list<Node*>& nodes)
bool pass::AssignTensors::run_on_call_graph(list<std::shared_ptr<Node>>& nodes)
{
for (Node* node : nodes)
for (shared_ptr<Node> node : nodes)
{
try
{
// We need to set the nodes is_output state prior to call assign_tensors
// so that the output state can be passes to the constructed tensors.
if (node == get_state().get_functions().at(0)->get_result().get())
if (node == get_state().get_functions().at(0)->get_result())
{
node->set_is_output();
}
......
......@@ -27,7 +27,7 @@ namespace ngraph
class ngraph::pass::AssignTensors : public CallGraphPass
{
public:
virtual bool run_on_call_graph(std::list<Node*>& nodes) override;
virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes) override;
private:
};
......@@ -24,22 +24,22 @@ using namespace std;
using namespace ngraph;
using namespace ngraph::pass;
bool CollectFunctions::run_on_function(ngraph::Function* func)
bool CollectFunctions::run_on_function(shared_ptr<ngraph::Function> func)
{
set<Function*> functions;
deque<Function*> stack;
set<shared_ptr<ngraph::Function>> functions;
deque<shared_ptr<ngraph::Function>> stack;
stack.push_back(func);
while (stack.empty() == false)
{
Function* f = stack.front();
shared_ptr<ngraph::Function> f = stack.front();
stack.pop_front();
functions.insert(f);
traverse_nodes(f->get_result(), [&](Node* node) {
op::FunctionCall* fc = dynamic_cast<op::FunctionCall*>(node);
traverse_nodes(f->get_result(), [&](shared_ptr<Node> node) {
shared_ptr<op::FunctionCall> fc = dynamic_pointer_cast<op::FunctionCall>(node);
if (fc)
{
stack.push_back(fc->get_function().get());
stack.push_back(fc->get_function());
}
});
}
......
......@@ -27,7 +27,7 @@ namespace ngraph
class ngraph::pass::CollectFunctions : public FunctionPass
{
public:
bool run_on_function(ngraph::Function*) override;
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
private:
};
......@@ -28,14 +28,14 @@ pass::DumpSorted::DumpSorted(const string& output_file)
{
}
bool pass::DumpSorted::run_on_module(vector<Function*>& functions)
bool pass::DumpSorted::run_on_module(vector<shared_ptr<ngraph::Function>>& functions)
{
ofstream out{m_output_file};
if (out)
{
for (Function* f : functions)
for (shared_ptr<Function> f : functions)
{
for (const Node* node : f->get_ordered_ops())
for (const shared_ptr<Node>& node : f->get_ordered_ops())
{
out << node->get_name() << "(";
vector<string> inputs;
......
......@@ -31,7 +31,7 @@ class ngraph::pass::DumpSorted : public ModulePass
public:
DumpSorted(const std::string& output_file);
virtual bool run_on_module(std::vector<Function*>&) override;
virtual bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;
private:
const std::string m_output_file;
......
......@@ -28,13 +28,13 @@ using namespace std;
using namespace ngraph;
using namespace ngraph::descriptor;
bool pass::Liveness::run_on_call_graph(list<Node*>& ops)
bool pass::Liveness::run_on_call_graph(list<shared_ptr<Node>>& ops)
{
unordered_set<Tensor*> currently_live;
for (auto it = ops.rbegin(); it != ops.rend(); it++)
{
Node* node = *it;
shared_ptr<Node> node = *it;
node->liveness_live_list.clear();
node->liveness_new_list.clear();
node->liveness_free_list.clear();
......@@ -91,7 +91,7 @@ bool pass::Liveness::run_on_call_graph(list<Node*>& ops)
// Add outputs to live_list and remove from free_list
unordered_set<Tensor*> outputs;
unordered_set<Tensor*> seen;
for (Node* node : ops)
for (shared_ptr<Node> node : ops)
{
for (Tensor* tensor : node->liveness_live_list)
{
......
......@@ -28,7 +28,7 @@ namespace ngraph
class ngraph::pass::Liveness : public CallGraphPass
{
public:
virtual bool run_on_call_graph(std::list<Node*>&) override;
virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>&) override;
private:
bool is_temporary(const descriptor::Tensor&);
......
......@@ -38,12 +38,7 @@ void ngraph::pass::Manager::initialize_default_passes()
void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
{
run_passes(func.get());
}
void ngraph::pass::Manager::run_passes(Function* func)
{
vector<Function*> fs = {func};
vector<shared_ptr<Function>> fs = {func};
get_state().set_functions(fs);
for (shared_ptr<PassBase> pass : m_pass_list)
......@@ -59,16 +54,16 @@ void ngraph::pass::Manager::run_passes(Function* func)
}
else if (function_pass)
{
for (Function* f : fs)
for (shared_ptr<Function> f : fs)
{
function_pass->run_on_function(f);
}
}
else if (node_pass)
{
for (Function* f : fs)
for (shared_ptr<Function> f : fs)
{
for (Node* n : f->get_ops())
for (shared_ptr<Node> n : f->get_ops())
{
node_pass->run_on_node(n);
}
......@@ -76,7 +71,7 @@ void ngraph::pass::Manager::run_passes(Function* func)
}
else if (call_graph_pass)
{
for (Function* f : fs)
for (shared_ptr<Function> f : fs)
{
call_graph_pass->run_on_call_graph(f->get_ordered_ops());
}
......
......@@ -47,7 +47,6 @@ public:
m_pass_list.push_back(pass_base);
}
void run_passes(Function*);
void run_passes(std::shared_ptr<Function>);
ManagerState& get_state();
......
......@@ -23,7 +23,7 @@
using namespace std;
using namespace ngraph;
vector<Function*>& ngraph::pass::ManagerState::get_functions()
vector<shared_ptr<Function>>& ngraph::pass::ManagerState::get_functions()
{
return m_function_list;
}
......
......@@ -30,7 +30,7 @@ namespace ngraph
class ngraph::pass::ManagerState
{
public:
std::vector<Function*>& get_functions();
std::vector<std::shared_ptr<Function>>& get_functions();
template <typename T>
void set_functions(const T& collection)
......@@ -44,5 +44,5 @@ public:
private:
size_t m_temporary_pool_size = 0;
std::vector<Function*> m_function_list;
std::vector<std::shared_ptr<Function>> m_function_list;
};
......@@ -26,10 +26,10 @@ using namespace std;
using namespace ngraph;
using namespace ngraph::descriptor;
bool pass::MemoryLayout::run_on_call_graph(std::list<Node*>& node_list)
bool pass::MemoryLayout::run_on_call_graph(std::list<std::shared_ptr<Node>>& node_list)
{
MemoryManager mm;
for (const Node* node : node_list)
for (shared_ptr<Node> node : node_list)
{
for (Tensor* tensor : node->liveness_new_list)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
class ngraph::pass::MemoryLayout : public CallGraphPass
{
public:
virtual bool run_on_call_graph(std::list<Node*>&) override;
virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>&) override;
private:
};
......
......@@ -32,13 +32,13 @@ pass::MemoryVisualize::MemoryVisualize(const string& filename)
{
}
bool pass::MemoryVisualize::run_on_module(vector<Function*>& functions)
bool pass::MemoryVisualize::run_on_module(vector<shared_ptr<ngraph::Function>>& functions)
{
ofstream file(m_filename);
{
for (const Function* f : functions)
for (shared_ptr<Function> f : functions)
{
const list<Node*> nodes = f->get_ordered_ops();
list<shared_ptr<Node>> nodes = f->get_ordered_ops();
file << "<!DOCTYPE html>\n<html>\n";
file << "<head>\n";
file << " <style>\n";
......@@ -62,7 +62,7 @@ bool pass::MemoryVisualize::run_on_module(vector<Function*>& functions)
file << "<body>\n";
unordered_set<descriptor::Tensor*> tensors;
size_t temp_max_size = 0;
for (Node* node : nodes)
for (shared_ptr<Node> node : nodes)
{
tensors.insert(node->liveness_live_list.begin(), node->liveness_live_list.end());
}
......@@ -96,11 +96,11 @@ bool pass::MemoryVisualize::run_on_module(vector<Function*>& functions)
return false;
}
const Node* pass::MemoryVisualize::find_largest_op(const list<Node*>& nodes)
shared_ptr<Node> pass::MemoryVisualize::find_largest_op(const list<shared_ptr<Node>>& nodes)
{
const Node* largest_op = nullptr;
shared_ptr<Node> largest_op = nullptr;
size_t largest_size = 0;
for (const Node* exop : nodes)
for (shared_ptr<Node> exop : nodes)
{
size_t size = 0;
for (const Tensor* tensor : exop->liveness_live_list)
......@@ -116,9 +116,9 @@ const Node* pass::MemoryVisualize::find_largest_op(const list<Node*>& nodes)
return largest_op;
}
void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>& nodes)
void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<shared_ptr<Node>>& nodes)
{
const Node* largest_op = find_largest_op(nodes);
shared_ptr<Node> largest_op = find_largest_op(nodes);
if (largest_op)
{
......@@ -130,7 +130,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>&
unordered_map<const Tensor*, size_t> age_list;
vector<const Tensor*> tensor_set;
unordered_map<const Tensor*, const Node*> generator_op;
unordered_map<const Tensor*, shared_ptr<Node>> generator_op;
file << "<table>\n";
file << " <tr>";
file << "<th align=\"left\">tensor</th>";
......@@ -139,7 +139,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>&
file << "<th align=\"right\">generator weight</th>";
file << "</tr>\n";
size_t i = 0;
for (const Node* exop : nodes)
for (shared_ptr<Node> exop : nodes)
{
for (const Tensor* tensor : exop->liveness_new_list)
{
......@@ -179,7 +179,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>&
}
}
void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nodes)
void pass::MemoryVisualize::draw_histogram(ostream& file, const list<shared_ptr<Node>>& nodes)
{
size_t stroke_width = 14;
size_t text_offset = 4;
......@@ -188,7 +188,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
size_t scale = width - offset;
size_t line_spacing = stroke_width * 1.5;
size_t line_count = 0;
for (const Node* node : nodes)
for (shared_ptr<Node> node : nodes)
{
(void)node;
line_count += 1;
......@@ -198,7 +198,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
file << "<svg viewBox=\"0 0 " << width << " " << height << "\">\n";
size_t y = 0;
for (const Node* node : nodes)
for (shared_ptr<Node> node : nodes)
{
float usage = float(MemoryVisualize::memory_usage(node));
float footprint = float(MemoryVisualize::memory_footprint(node));
......@@ -220,14 +220,14 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
file << "</svg>\n";
}
void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<Node*>& nodes)
void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<shared_ptr<Node>>& nodes)
{
file << "<table>\n";
file << " <tr>";
file << "<th align=\"left\">op</th>";
file << "<th align=\"right\">influence</th>";
file << "</tr>\n";
for (const Node* exop : nodes)
for (shared_ptr<Node> exop : nodes)
{
int weight = compute_op_weight(exop);
file << " <tr>";
......@@ -237,7 +237,7 @@ void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<Node*>&
}
}
int pass::MemoryVisualize::compute_op_weight(const Node* exop)
int pass::MemoryVisualize::compute_op_weight(const shared_ptr<Node> exop)
{
int mass = 0;
// for input_decl in exop.input_decls:
......@@ -265,17 +265,17 @@ int pass::MemoryVisualize::compute_op_weight(const Node* exop)
return mass;
}
size_t pass::MemoryVisualize::memory_usage(const Node* node)
size_t pass::MemoryVisualize::memory_usage(shared_ptr<Node> node)
{
return 0;
}
size_t pass::MemoryVisualize::memory_footprint(const Node* node)
size_t pass::MemoryVisualize::memory_footprint(shared_ptr<Node> node)
{
return 0;
}
size_t pass::MemoryVisualize::memory_footprint(const std::list<Node*>& nodes)
size_t pass::MemoryVisualize::memory_footprint(const std::list<shared_ptr<Node>>& nodes)
{
return 0;
}
......@@ -32,18 +32,18 @@ class ngraph::pass::MemoryVisualize : public ModulePass
{
public:
MemoryVisualize(const std::string& filename);
virtual bool run_on_module(std::vector<Function*>&) override;
virtual bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;
private:
const Node* find_largest_op(const std::list<Node*>& nodes);
void draw_tensor_weight(std::ostream& file, const std::list<Node*>& nodes);
void draw_histogram(std::ostream& file, const std::list<Node*>& nodes);
void draw_op_influence(std::ostream& file, const std::list<Node*>& nodes);
int compute_op_weight(const Node* exop);
static size_t memory_usage(const Node*);
static size_t memory_footprint(const Node*);
static size_t memory_footprint(const std::list<Node*>&);
std::shared_ptr<Node> find_largest_op(const std::list<std::shared_ptr<Node>>& nodes);
void draw_tensor_weight(std::ostream& file, const std::list<std::shared_ptr<Node>>& nodes);
void draw_histogram(std::ostream& file, const std::list<std::shared_ptr<Node>>& nodes);
void draw_op_influence(std::ostream& file, const std::list<std::shared_ptr<Node>>& nodes);
int compute_op_weight(std::shared_ptr<Node> exop);
static size_t memory_usage(std::shared_ptr<Node>);
static size_t memory_footprint(std::shared_ptr<Node>);
static size_t memory_footprint(const std::list<std::shared_ptr<Node>>&);
const std::string m_filename;
};
......@@ -53,26 +53,26 @@ class ngraph::pass::ModulePass : public PassBase
{
public:
virtual ~ModulePass() {}
virtual bool run_on_module(std::vector<ngraph::Function*>&) = 0;
virtual bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) = 0;
};
class ngraph::pass::FunctionPass : public PassBase
{
public:
virtual ~FunctionPass() {}
virtual bool run_on_function(ngraph::Function*) = 0;
virtual bool run_on_function(std::shared_ptr<ngraph::Function>) = 0;
};
class ngraph::pass::NodePass : public PassBase
{
public:
virtual ~NodePass() {}
virtual bool run_on_node(ngraph::Node*) = 0;
virtual bool run_on_node(std::shared_ptr<ngraph::Node>) = 0;
};
class ngraph::pass::CallGraphPass : public PassBase
{
public:
virtual ~CallGraphPass() {}
virtual bool run_on_call_graph(std::list<ngraph::Node*>&) = 0;
virtual bool run_on_call_graph(std::list<std::shared_ptr<ngraph::Node>>&) = 0;
};
......@@ -20,9 +20,9 @@
using namespace std;
using namespace ngraph;
bool pass::PropagateTypes::run_on_call_graph(list<Node*>& nodes)
bool pass::PropagateTypes::run_on_call_graph(list<shared_ptr<Node>>& nodes)
{
for (Node* node : nodes)
for (shared_ptr<Node> node : nodes)
{
try
{
......
......@@ -27,7 +27,7 @@ namespace ngraph
class ngraph::pass::PropagateTypes : public CallGraphPass
{
public:
virtual bool run_on_call_graph(std::list<Node*>&) override;
virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>&) override;
private:
};
......@@ -25,24 +25,26 @@
using namespace ngraph;
using namespace std;
bool ngraph::pass::TopologicalSort::run_on_function(ngraph::Function* func)
bool ngraph::pass::TopologicalSort::run_on_function(shared_ptr<ngraph::Function> func)
{
list<Node*> result_list;
list<shared_ptr<Node>> result_list;
deque<Node*> independent_nodes;
unordered_map<Node*, size_t> node_depencency_count;
unordered_map<const Node*, size_t> node_depencency_count;
unordered_map<Node*, shared_ptr<Node>> node_map;
traverse_nodes(func->get_result(), [&](Node* node) {
node_depencency_count[node] = node->get_arguments().size();
traverse_nodes(func->get_result(), [&](shared_ptr<Node> node) {
node_map[node.get()] = node;
node_depencency_count[node.get()] = node->get_arguments().size();
if (node->get_arguments().size() == 0)
{
independent_nodes.push_back(node);
independent_nodes.push_back(node.get());
}
});
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
result_list.push_back(independent_node);
result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front();
for (auto user : independent_node->users())
......
......@@ -31,5 +31,5 @@ class ngraph::pass::TopologicalSort : public FunctionPass
{
public:
TopologicalSort() {}
bool run_on_function(ngraph::Function*) override;
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
};
......@@ -23,15 +23,15 @@
using namespace ngraph;
using namespace std;
bool pass::VisualizeTree::run_on_module(vector<ngraph::Function*>& functions)
bool pass::VisualizeTree::run_on_module(vector<shared_ptr<ngraph::Function>>& functions)
{
for (Function* f : functions)
for (shared_ptr<Function> f : functions)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(f->get_result(), [&](Node* node) {
traverse_nodes(f->get_result(), [&](shared_ptr<Node> node) {
for (auto arg : node->get_arguments())
{
m_ss << add_attributes(arg.get());
m_ss << add_attributes(arg);
m_ss << add_attributes(node);
m_ss << " " << arg->get_name() << " -> " << node->get_name();
m_ss << ";\n";
......@@ -49,7 +49,7 @@ pass::VisualizeTree::VisualizeTree(const string& file_name)
{
}
std::string pass::VisualizeTree::add_attributes(const Node* node)
std::string pass::VisualizeTree::add_attributes(shared_ptr<Node> node)
{
string rc;
if (!contains(m_nodes_with_attributes, node))
......@@ -60,7 +60,7 @@ std::string pass::VisualizeTree::add_attributes(const Node* node)
return rc;
}
std::string pass::VisualizeTree::get_attributes(const Node* node)
std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
{
stringstream ss;
if (node->is_parameter())
......
......@@ -32,14 +32,14 @@ class ngraph::pass::VisualizeTree : public ModulePass
{
public:
VisualizeTree(const std::string& file_name);
bool run_on_module(std::vector<ngraph::Function*>&) override;
bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;
private:
std::string add_attributes(const Node* node);
std::string get_attributes(const Node* node);
std::string add_attributes(std::shared_ptr<Node> node);
std::string get_attributes(std::shared_ptr<Node> node);
void render() const;
std::stringstream m_ss;
std::string m_name;
std::set<const Node*> m_nodes_with_attributes;
std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
};
......@@ -961,7 +961,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Turn this into a pass
// Assign layouts
// For now, just make everyone row-major.
for (const Node* node : m_function->get_ordered_ops())
for (shared_ptr<Node> node : m_function->get_ordered_ops())
{
for (const descriptor::Output& output : node->get_outputs())
{
......@@ -998,7 +998,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
m_n_outputs = tensor_index.size() - m_n_inputs;
// All remaining tensor views
for (const Node* node : m_function->get_ordered_ops())
for (shared_ptr<Node> node : m_function->get_ordered_ops())
{
for (const descriptor::Output& output : node->get_outputs())
{
......@@ -1014,9 +1014,11 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Now we build the eigen-VM instructions
auto op_map = get_op_map();
for (const Node* node : m_function->get_ordered_ops())
for (shared_ptr<Node> node : m_function->get_ordered_ops())
{
auto handler_it = op_map.find(type_index(typeid(*node)));
auto& n = *node; // Work around a compiler warning (*node inside typeid may have effects
// with shared pointers, which is fine here but clang doesn't like it.)
auto handler_it = op_map.find(type_index(typeid(n)));
if (handler_it == op_map.end())
{
throw ngraph_error("Unhandled op during code generation");
......@@ -1034,7 +1036,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
auto tv = output.get_tensor_view();
out.push_back({tensor_index.at(tv), tv});
}
handler_it->second(node, this, function_map, in, out);
handler_it->second(node.get(), this, function_map, in, out);
}
m_instructions->push_back(make_shared<eigen::ReturnInstruction>());
m_is_compiled = true;
......
......@@ -137,15 +137,16 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list)
return seed;
}
void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p, std::function<void(Node*)> f)
void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p,
std::function<void(shared_ptr<Node>)> f)
{
std::unordered_set<Node*> instances_seen;
deque<Node*> stack;
stack.push_front(p.get());
std::unordered_set<shared_ptr<Node>> instances_seen;
deque<shared_ptr<Node>> stack;
stack.push_front(p);
while (stack.size() > 0)
{
Node* n = stack.front();
shared_ptr<Node> n = stack.front();
if (instances_seen.find(n) == instances_seen.end())
{
instances_seen.insert(n);
......@@ -154,7 +155,7 @@ void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p, std::functio
stack.pop_front();
for (auto arg : n->get_arguments())
{
stack.push_front(arg.get());
stack.push_front(arg);
}
}
}
......@@ -163,7 +164,7 @@ void ngraph::free_nodes(shared_ptr<Node> p)
{
std::deque<Node*> sorted_list;
traverse_nodes(p, [&](Node* n) { sorted_list.push_front(n); });
traverse_nodes(p, [&](shared_ptr<Node> n) { sorted_list.push_front(n.get()); });
for (Node* n : sorted_list)
{
......
......@@ -195,7 +195,8 @@ namespace ngraph
return a * b;
}
void traverse_nodes(const std::shared_ptr<Node>& p, std::function<void(Node*)> f);
void traverse_nodes(const std::shared_ptr<Node>& p,
std::function<void(std::shared_ptr<Node>)> f);
void free_nodes(std::shared_ptr<Node>);
} // end namespace ngraph
......@@ -50,7 +50,7 @@ TEST(pass, liveness)
pass_manager.register_pass<pass::DumpSorted>(dump_file);
shared_ptr<Function> func = make_test_graph();
pass_manager.run_passes(func.get());
pass_manager.run_passes(func);
auto sorted = func->get_ordered_ops();
// for (const Node* node : sorted)
......
......@@ -39,7 +39,7 @@ TEST(pass_manager, add)
auto graph = make_test_graph();
size_t node_count = get_node_count(graph->get_result());
pass_manager.run_passes(graph.get());
pass_manager.run_passes(graph);
auto sorted = graph->get_ordered_ops();
EXPECT_EQ(node_count, sorted.size());
EXPECT_TRUE(validate_list(sorted));
......
......@@ -23,7 +23,7 @@ using namespace ngraph;
// This function traverses the list of ops and verifies that each op's dependencies (its inputs)
// is located earlier in the list. That is enough to be valid
bool validate_list(const list<Node*>& nodes)
bool validate_list(const list<shared_ptr<Node>>& nodes)
{
bool rc = true;
for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
......@@ -39,7 +39,7 @@ bool validate_list(const list<Node*>& nodes)
for (; tmp != nodes.rend(); tmp++)
{
auto dep_tmp = *tmp;
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp);
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp.get());
if (found != dependencies.end())
{
dependencies.erase(found);
......@@ -82,6 +82,6 @@ shared_ptr<Function> make_test_graph()
size_t get_node_count(std::shared_ptr<Node> n)
{
size_t node_count = 0;
traverse_nodes(n, [&](const Node* node) { node_count++; });
traverse_nodes(n, [&](shared_ptr<Node> node) { node_count++; });
return node_count;
}
......@@ -23,6 +23,6 @@ namespace ngraph
class Function;
}
bool validate_list(const std::list<ngraph::Node*>& nodes);
bool validate_list(const std::list<std::shared_ptr<ngraph::Node>>& nodes);
std::shared_ptr<ngraph::Function> make_test_graph();
size_t get_node_count(std::shared_ptr<ngraph::Node> n);
......@@ -126,7 +126,7 @@ TEST(benchmark, topological_sort)
NGRAPH_INFO << "topological sort took " << timer.get_milliseconds() << "ms";
size_t node_count = 0;
traverse_nodes(result, [&](const Node* node) { node_count++; });
traverse_nodes(result, [&](shared_ptr<Node> node) { node_count++; });
NGRAPH_INFO << "node count " << node_count;
......@@ -135,6 +135,7 @@ TEST(benchmark, topological_sort)
timer.stop();
NGRAPH_INFO << "delete nodes took " << timer.get_milliseconds() << "ms";
}
TEST(topological_sort, collect_functions)
{
// First create "f(A,B,C) = (A+B)*C".
......@@ -174,7 +175,7 @@ TEST(topological_sort, collect_functions)
set<string> expected = {"f", "g", "h"};
auto functions = pass_manager.get_state().get_functions();
vector<string> fnames;
for (Function* func : functions)
for (shared_ptr<Function> func : functions)
{
fnames.push_back(func->get_name());
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment