Unverified Commit ddcfbda8 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

Remove m_arguments and m_users (#816)

* make Input descriptors node owners

* rename src_node to m_src_node
parent 577d5c6c
......@@ -27,6 +27,7 @@ Input::Input(Node* node, size_t index, Output& output)
, m_index(index)
, m_output(&output)
{
m_src_node = std::shared_ptr<Node>(output.get_node());
output.add_input(this);
}
......@@ -35,6 +36,7 @@ void Input::replace_output(Output& new_output)
m_output->remove_input(this);
new_output.add_input(this);
m_output = &new_output;
m_src_node = std::shared_ptr<Node>(new_output.get_node());
}
void Input::replace_output(std::shared_ptr<Node> node, size_t i)
......
......@@ -76,6 +76,8 @@ namespace ngraph
const element::Type& get_element_type() const;
protected:
//owner of an argument node (in lieu of m_arguments)
std::shared_ptr<Node> m_src_node;
Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors
Output* m_output;
......
......@@ -105,18 +105,6 @@ void ngraph::traverse_functions(std::shared_ptr<ngraph::Function> p,
}
}
void ngraph::free_nodes(shared_ptr<Function> p)
{
std::deque<Node*> sorted_list;
traverse_nodes(p, [&](shared_ptr<Node> n) { sorted_list.push_front(n.get()); });
for (Node* n : sorted_list)
{
n->clear_arguments();
}
}
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
{
if (target->is_output())
......@@ -140,24 +128,6 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
input->replace_output(replacement->get_outputs().at(i));
}
}
// Fix users and arguments
replace_node_users_arguments(target, replacement);
}
void ngraph::replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement)
{
for (auto user : target->users())
{
auto& args = const_cast<ngraph::NodeVector&>(user->get_arguments_FOR_GRAPH_REWRITE_ONLY());
auto it = std::find(begin(args), end(args), target);
assert(it != end(args));
it = args.erase(it);
args.insert(it, replacement);
const_cast<std::multiset<Node*>&>(replacement->users()).insert(user);
}
const_cast<std::multiset<Node*>&>(target->users()).clear();
}
std::list<std::shared_ptr<ngraph::Node>>
......@@ -338,18 +308,6 @@ pair<shared_ptr<op::Result>, shared_ptr<op::Parameter>>
src_output->remove_input(dst_input); // Remove [0]
dst_input->replace_output(par_node, 0); // Remove [0] (again), add [8], remove [1], add [9]
// Fix user / argument among src, dst and par
const_cast<multiset<Node*>&>(src_node->users()).erase(dst_node.get()); // Remove [2]
const_cast<multiset<Node*>&>(par_node->users()).insert(dst_node.get()); // Add [10]
auto& dst_args = const_cast<NodeVector&>(dst_node->get_arguments_FOR_GRAPH_REWRITE_ONLY());
auto it = find(dst_args.begin(), dst_args.end(), src_node);
if (it == dst_args.end())
{
throw ngraph_error("src_node is not an input to dst_node");
}
it = dst_args.erase(it); // Remove [3]
dst_args.insert(it, par_node); // Add [11]
// Add res node
shared_ptr<op::Result> res_node = make_shared<op::Result>(src_node); // Add [4], [5], [6], [7]
res_node->set_placement(src_node->get_placement());
......@@ -406,18 +364,6 @@ void ngraph::insert_new_node_between(const shared_ptr<Node>& src_node,
descriptor::Output* src_output = src_node->get_output_to(dst_node);
src_output->remove_input(dst_input); // Remove [0]
dst_input->replace_output(new_node, 0); // Remove [0] (again), add [8], remove [1], add [9]
// Fix user / argument
const_cast<multiset<Node*>&>(src_node->users()).erase(dst_node.get()); // Remove [2]
const_cast<multiset<Node*>&>(new_node->users()).insert(dst_node.get()); // Add [10]
auto& dst_args = const_cast<NodeVector&>(dst_node->get_arguments_FOR_GRAPH_REWRITE_ONLY());
auto it = find(dst_args.begin(), dst_args.end(), src_node);
if (it == dst_args.end())
{
throw ngraph_error("src_node is not an input to dst_node");
}
it = dst_args.erase(it); // Remove [3]
dst_args.insert(it, new_node); // Add [11]
}
// Assert that nodes in the function is colocated and return that placement
......
......@@ -47,13 +47,8 @@ namespace ngraph
void traverse_functions(std::shared_ptr<Function> p,
std::function<void(std::shared_ptr<Function>)> f);
void free_nodes(std::shared_ptr<Function>);
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
void replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement);
std::list<std::shared_ptr<Node>>
topological_sort(const std::list<std::shared_ptr<Node>>& nodes);
......
......@@ -39,13 +39,11 @@ Node::Node(const std::string& node_type, const NodeVector& arguments)
: m_node_type(node_type)
, m_instance_id(m_next_instance_id.fetch_add(1))
, m_unique_name(description() + "_" + to_string(m_instance_id))
, m_arguments(arguments)
{
// Add this node as a user of each argument.
size_t i = 0;
for (auto arg : m_arguments)
for (auto arg : arguments)
{
arg->m_users.insert(this);
for (descriptor::Output& output : arg->get_outputs())
{
m_inputs.emplace_back(this, i++, output);
......@@ -146,9 +144,9 @@ void Node::set_placement(Placement placement)
std::shared_ptr<Node> Node::get_input_op(size_t index)
{
for (auto arg : m_arguments)
for (auto& i : get_inputs())
{
if (arg->get_outputs().size() != 1)
if (i.get_output().get_node()->get_outputs().size() != 1)
{
throw "get_input_op called on an argument w/ multiple outputs";
}
......@@ -156,7 +154,15 @@ std::shared_ptr<Node> Node::get_input_op(size_t index)
return m_inputs.at(index).get_output().get_node();
}
NodeVector Node::get_input_ops() //const
Node::~Node()
{
for (auto& input : m_inputs)
{
input.get_output().remove_input(&input);
}
}
NodeVector Node::get_input_ops()
{
NodeVector result;
for (auto& i : get_inputs())
......@@ -165,10 +171,6 @@ NodeVector Node::get_input_ops() //const
result.push_back(i.get_output().get_node());
}
}
if (m_arguments != result)
{
throw ngraph_error("Arguments aren't equal: different values");
}
return result;
}
......
......@@ -79,17 +79,7 @@ namespace ngraph
protected:
Node(const std::string& node_type, const NodeVector& arguments);
virtual ~Node()
{
for (auto arg : m_arguments)
{
arg->m_users.erase(this);
}
for (auto& input : m_inputs)
{
input.get_output().remove_input(&input);
}
}
virtual ~Node();
virtual void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) {}
public:
......@@ -98,8 +88,6 @@ namespace ngraph
const std::string& get_friendly_name() const;
const std::string& get_name() const;
void set_name(const std::string& name);
void clear_arguments() { m_arguments.clear(); }
const std::multiset<Node*>& users() const { return m_users; }
/// Return true if this has the same implementing class as node. This
/// will be used by the pattern matcher when comparing a pattern
/// graph against the graph.
......@@ -207,7 +195,6 @@ namespace ngraph
void add_output(const element::Type& element_type, const Shape& shape);
std::string m_node_type;
std::multiset<Node*> m_users;
size_t m_instance_id;
std::string m_name;
const std::string m_unique_name;
......@@ -216,13 +203,5 @@ namespace ngraph
std::deque<descriptor::Output> m_outputs;
std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map;
Placement m_placement = Placement::DEFAULT;
private:
NodeVector m_arguments;
//m_arguments still needs to be kept in sync with i/o since get_input_ops
//is pretty ubiquitous and might be called after the original graph was modified.
//get_input_ops uses m_arguments to check if a node view reconstruction from i/o
//is correct.
NodeVector& get_arguments_FOR_GRAPH_REWRITE_ONLY() { return m_arguments; }
};
}
......@@ -43,20 +43,6 @@ bool ngraph::pass::GetOutputElementElimination::run_on_function(std::shared_ptr<
{
auto multi = goe->get_inputs().at(0).get_output().get_node();
input.replace_output(goe->get_inputs().at(goe->get_n()).get_output());
//fix node arguments
auto& n_args =
const_cast<ngraph::NodeVector&>(n->get_arguments_FOR_GRAPH_REWRITE_ONLY());
auto it = std::find(begin(n_args), end(n_args), goe);
if (it == end(n_args))
{
throw ngraph_error("Expected to find GetOutputElement in n's inputs");
}
*it = multi;
//fix multi's users
const_cast<std::multiset<Node*>&>(multi->users()).insert(n.get());
//we don't need to fix anything w.r.t GetOutputElement as it will become unreachable
optimized = true;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment