Unverified Commit ddcfbda8 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

Remove m_arguments and m_users (#816)

* make Input descriptors node owners

* rename src_node to m_src_node
parent 577d5c6c
...@@ -27,6 +27,7 @@ Input::Input(Node* node, size_t index, Output& output) ...@@ -27,6 +27,7 @@ Input::Input(Node* node, size_t index, Output& output)
, m_index(index) , m_index(index)
, m_output(&output) , m_output(&output)
{ {
m_src_node = std::shared_ptr<Node>(output.get_node());
output.add_input(this); output.add_input(this);
} }
...@@ -35,6 +36,7 @@ void Input::replace_output(Output& new_output) ...@@ -35,6 +36,7 @@ void Input::replace_output(Output& new_output)
m_output->remove_input(this); m_output->remove_input(this);
new_output.add_input(this); new_output.add_input(this);
m_output = &new_output; m_output = &new_output;
m_src_node = std::shared_ptr<Node>(new_output.get_node());
} }
void Input::replace_output(std::shared_ptr<Node> node, size_t i) void Input::replace_output(std::shared_ptr<Node> node, size_t i)
......
...@@ -76,6 +76,8 @@ namespace ngraph ...@@ -76,6 +76,8 @@ namespace ngraph
const element::Type& get_element_type() const; const element::Type& get_element_type() const;
protected: protected:
//owner of an argument node (in lieu of m_arguments)
std::shared_ptr<Node> m_src_node;
Node* m_node; // The node we are an input for Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors size_t m_index; // Index into all input tensors
Output* m_output; Output* m_output;
......
...@@ -105,18 +105,6 @@ void ngraph::traverse_functions(std::shared_ptr<ngraph::Function> p, ...@@ -105,18 +105,6 @@ void ngraph::traverse_functions(std::shared_ptr<ngraph::Function> p,
} }
} }
void ngraph::free_nodes(shared_ptr<Function> p)
{
std::deque<Node*> sorted_list;
traverse_nodes(p, [&](shared_ptr<Node> n) { sorted_list.push_front(n.get()); });
for (Node* n : sorted_list)
{
n->clear_arguments();
}
}
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement) void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
{ {
if (target->is_output()) if (target->is_output())
...@@ -140,24 +128,6 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re ...@@ -140,24 +128,6 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
input->replace_output(replacement->get_outputs().at(i)); input->replace_output(replacement->get_outputs().at(i));
} }
} }
// Fix users and arguments
replace_node_users_arguments(target, replacement);
}
void ngraph::replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement)
{
for (auto user : target->users())
{
auto& args = const_cast<ngraph::NodeVector&>(user->get_arguments_FOR_GRAPH_REWRITE_ONLY());
auto it = std::find(begin(args), end(args), target);
assert(it != end(args));
it = args.erase(it);
args.insert(it, replacement);
const_cast<std::multiset<Node*>&>(replacement->users()).insert(user);
}
const_cast<std::multiset<Node*>&>(target->users()).clear();
} }
std::list<std::shared_ptr<ngraph::Node>> std::list<std::shared_ptr<ngraph::Node>>
...@@ -338,18 +308,6 @@ pair<shared_ptr<op::Result>, shared_ptr<op::Parameter>> ...@@ -338,18 +308,6 @@ pair<shared_ptr<op::Result>, shared_ptr<op::Parameter>>
src_output->remove_input(dst_input); // Remove [0] src_output->remove_input(dst_input); // Remove [0]
dst_input->replace_output(par_node, 0); // Remove [0] (again), add [8], remove [1], add [9] dst_input->replace_output(par_node, 0); // Remove [0] (again), add [8], remove [1], add [9]
// Fix user / argument among src, dst and par
const_cast<multiset<Node*>&>(src_node->users()).erase(dst_node.get()); // Remove [2]
const_cast<multiset<Node*>&>(par_node->users()).insert(dst_node.get()); // Add [10]
auto& dst_args = const_cast<NodeVector&>(dst_node->get_arguments_FOR_GRAPH_REWRITE_ONLY());
auto it = find(dst_args.begin(), dst_args.end(), src_node);
if (it == dst_args.end())
{
throw ngraph_error("src_node is not an input to dst_node");
}
it = dst_args.erase(it); // Remove [3]
dst_args.insert(it, par_node); // Add [11]
// Add res node // Add res node
shared_ptr<op::Result> res_node = make_shared<op::Result>(src_node); // Add [4], [5], [6], [7] shared_ptr<op::Result> res_node = make_shared<op::Result>(src_node); // Add [4], [5], [6], [7]
res_node->set_placement(src_node->get_placement()); res_node->set_placement(src_node->get_placement());
...@@ -406,18 +364,6 @@ void ngraph::insert_new_node_between(const shared_ptr<Node>& src_node, ...@@ -406,18 +364,6 @@ void ngraph::insert_new_node_between(const shared_ptr<Node>& src_node,
descriptor::Output* src_output = src_node->get_output_to(dst_node); descriptor::Output* src_output = src_node->get_output_to(dst_node);
src_output->remove_input(dst_input); // Remove [0] src_output->remove_input(dst_input); // Remove [0]
dst_input->replace_output(new_node, 0); // Remove [0] (again), add [8], remove [1], add [9] dst_input->replace_output(new_node, 0); // Remove [0] (again), add [8], remove [1], add [9]
// Fix user / argument
const_cast<multiset<Node*>&>(src_node->users()).erase(dst_node.get()); // Remove [2]
const_cast<multiset<Node*>&>(new_node->users()).insert(dst_node.get()); // Add [10]
auto& dst_args = const_cast<NodeVector&>(dst_node->get_arguments_FOR_GRAPH_REWRITE_ONLY());
auto it = find(dst_args.begin(), dst_args.end(), src_node);
if (it == dst_args.end())
{
throw ngraph_error("src_node is not an input to dst_node");
}
it = dst_args.erase(it); // Remove [3]
dst_args.insert(it, new_node); // Add [11]
} }
// Assert that nodes in the function is colocated and return that placement // Assert that nodes in the function is colocated and return that placement
......
...@@ -47,13 +47,8 @@ namespace ngraph ...@@ -47,13 +47,8 @@ namespace ngraph
void traverse_functions(std::shared_ptr<Function> p, void traverse_functions(std::shared_ptr<Function> p,
std::function<void(std::shared_ptr<Function>)> f); std::function<void(std::shared_ptr<Function>)> f);
void free_nodes(std::shared_ptr<Function>);
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement); void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
void replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement);
std::list<std::shared_ptr<Node>> std::list<std::shared_ptr<Node>>
topological_sort(const std::list<std::shared_ptr<Node>>& nodes); topological_sort(const std::list<std::shared_ptr<Node>>& nodes);
......
...@@ -39,13 +39,11 @@ Node::Node(const std::string& node_type, const NodeVector& arguments) ...@@ -39,13 +39,11 @@ Node::Node(const std::string& node_type, const NodeVector& arguments)
: m_node_type(node_type) : m_node_type(node_type)
, m_instance_id(m_next_instance_id.fetch_add(1)) , m_instance_id(m_next_instance_id.fetch_add(1))
, m_unique_name(description() + "_" + to_string(m_instance_id)) , m_unique_name(description() + "_" + to_string(m_instance_id))
, m_arguments(arguments)
{ {
// Add this node as a user of each argument. // Add this node as a user of each argument.
size_t i = 0; size_t i = 0;
for (auto arg : m_arguments) for (auto arg : arguments)
{ {
arg->m_users.insert(this);
for (descriptor::Output& output : arg->get_outputs()) for (descriptor::Output& output : arg->get_outputs())
{ {
m_inputs.emplace_back(this, i++, output); m_inputs.emplace_back(this, i++, output);
...@@ -146,9 +144,9 @@ void Node::set_placement(Placement placement) ...@@ -146,9 +144,9 @@ void Node::set_placement(Placement placement)
std::shared_ptr<Node> Node::get_input_op(size_t index) std::shared_ptr<Node> Node::get_input_op(size_t index)
{ {
for (auto arg : m_arguments) for (auto& i : get_inputs())
{ {
if (arg->get_outputs().size() != 1) if (i.get_output().get_node()->get_outputs().size() != 1)
{ {
throw "get_input_op called on an argument w/ multiple outputs"; throw "get_input_op called on an argument w/ multiple outputs";
} }
...@@ -156,7 +154,15 @@ std::shared_ptr<Node> Node::get_input_op(size_t index) ...@@ -156,7 +154,15 @@ std::shared_ptr<Node> Node::get_input_op(size_t index)
return m_inputs.at(index).get_output().get_node(); return m_inputs.at(index).get_output().get_node();
} }
NodeVector Node::get_input_ops() //const Node::~Node()
{
for (auto& input : m_inputs)
{
input.get_output().remove_input(&input);
}
}
NodeVector Node::get_input_ops()
{ {
NodeVector result; NodeVector result;
for (auto& i : get_inputs()) for (auto& i : get_inputs())
...@@ -165,10 +171,6 @@ NodeVector Node::get_input_ops() //const ...@@ -165,10 +171,6 @@ NodeVector Node::get_input_ops() //const
result.push_back(i.get_output().get_node()); result.push_back(i.get_output().get_node());
} }
} }
if (m_arguments != result)
{
throw ngraph_error("Arguments aren't equal: different values");
}
return result; return result;
} }
......
...@@ -79,17 +79,7 @@ namespace ngraph ...@@ -79,17 +79,7 @@ namespace ngraph
protected: protected:
Node(const std::string& node_type, const NodeVector& arguments); Node(const std::string& node_type, const NodeVector& arguments);
virtual ~Node() virtual ~Node();
{
for (auto arg : m_arguments)
{
arg->m_users.erase(this);
}
for (auto& input : m_inputs)
{
input.get_output().remove_input(&input);
}
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) {} virtual void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) {}
public: public:
...@@ -98,8 +88,6 @@ namespace ngraph ...@@ -98,8 +88,6 @@ namespace ngraph
const std::string& get_friendly_name() const; const std::string& get_friendly_name() const;
const std::string& get_name() const; const std::string& get_name() const;
void set_name(const std::string& name); void set_name(const std::string& name);
void clear_arguments() { m_arguments.clear(); }
const std::multiset<Node*>& users() const { return m_users; }
/// Return true if this has the same implementing class as node. This /// Return true if this has the same implementing class as node. This
/// will be used by the pattern matcher when comparing a pattern /// will be used by the pattern matcher when comparing a pattern
/// graph against the graph. /// graph against the graph.
...@@ -207,7 +195,6 @@ namespace ngraph ...@@ -207,7 +195,6 @@ namespace ngraph
void add_output(const element::Type& element_type, const Shape& shape); void add_output(const element::Type& element_type, const Shape& shape);
std::string m_node_type; std::string m_node_type;
std::multiset<Node*> m_users;
size_t m_instance_id; size_t m_instance_id;
std::string m_name; std::string m_name;
const std::string m_unique_name; const std::string m_unique_name;
...@@ -216,13 +203,5 @@ namespace ngraph ...@@ -216,13 +203,5 @@ namespace ngraph
std::deque<descriptor::Output> m_outputs; std::deque<descriptor::Output> m_outputs;
std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map; std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map;
Placement m_placement = Placement::DEFAULT; Placement m_placement = Placement::DEFAULT;
private:
NodeVector m_arguments;
//m_arguments still needs to be kept in sync with i/o since get_input_ops
//is pretty ubiquitous and might be called after the original graph was modified.
//get_input_ops uses m_arguments to check if a node view reconstruction from i/o
//is correct.
NodeVector& get_arguments_FOR_GRAPH_REWRITE_ONLY() { return m_arguments; }
}; };
} }
...@@ -43,20 +43,6 @@ bool ngraph::pass::GetOutputElementElimination::run_on_function(std::shared_ptr< ...@@ -43,20 +43,6 @@ bool ngraph::pass::GetOutputElementElimination::run_on_function(std::shared_ptr<
{ {
auto multi = goe->get_inputs().at(0).get_output().get_node(); auto multi = goe->get_inputs().at(0).get_output().get_node();
input.replace_output(goe->get_inputs().at(goe->get_n()).get_output()); input.replace_output(goe->get_inputs().at(goe->get_n()).get_output());
//fix node arguments
auto& n_args =
const_cast<ngraph::NodeVector&>(n->get_arguments_FOR_GRAPH_REWRITE_ONLY());
auto it = std::find(begin(n_args), end(n_args), goe);
if (it == end(n_args))
{
throw ngraph_error("Expected to find GetOutputElement in n's inputs");
}
*it = multi;
//fix multi's users
const_cast<std::multiset<Node*>&>(multi->users()).insert(n.get());
//we don't need to fix anything w.r.t GetOutputElement as it will become unreachable //we don't need to fix anything w.r.t GetOutputElement as it will become unreachable
optimized = true; optimized = true;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment