Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
ddcfbda8
Unverified
Commit
ddcfbda8
authored
Apr 06, 2018
by
Nick Korovaiko
Committed by
GitHub
Apr 06, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Remove m_arguments and m_users (#816)
* make Input descriptors node owners * rename src_node to m_src_node
parent
577d5c6c
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
17 additions
and
105 deletions
+17
-105
input.cpp
src/ngraph/descriptor/input.cpp
+2
-0
input.hpp
src/ngraph/descriptor/input.hpp
+2
-0
graph_util.cpp
src/ngraph/graph_util.cpp
+0
-54
graph_util.hpp
src/ngraph/graph_util.hpp
+0
-5
node.cpp
src/ngraph/node.cpp
+12
-10
node.hpp
src/ngraph/node.hpp
+1
-22
get_output_element_elimination.cpp
src/ngraph/pass/get_output_element_elimination.cpp
+0
-14
No files found.
src/ngraph/descriptor/input.cpp
View file @
ddcfbda8
...
...
@@ -27,6 +27,7 @@ Input::Input(Node* node, size_t index, Output& output)
,
m_index
(
index
)
,
m_output
(
&
output
)
{
m_src_node
=
std
::
shared_ptr
<
Node
>
(
output
.
get_node
());
output
.
add_input
(
this
);
}
...
...
@@ -35,6 +36,7 @@ void Input::replace_output(Output& new_output)
m_output
->
remove_input
(
this
);
new_output
.
add_input
(
this
);
m_output
=
&
new_output
;
m_src_node
=
std
::
shared_ptr
<
Node
>
(
new_output
.
get_node
());
}
void
Input
::
replace_output
(
std
::
shared_ptr
<
Node
>
node
,
size_t
i
)
...
...
src/ngraph/descriptor/input.hpp
View file @
ddcfbda8
...
...
@@ -76,6 +76,8 @@ namespace ngraph
const
element
::
Type
&
get_element_type
()
const
;
protected
:
//owner of an argument node (in lieu of m_arguments)
std
::
shared_ptr
<
Node
>
m_src_node
;
Node
*
m_node
;
// The node we are an input for
size_t
m_index
;
// Index into all input tensors
Output
*
m_output
;
...
...
src/ngraph/graph_util.cpp
View file @
ddcfbda8
...
...
@@ -105,18 +105,6 @@ void ngraph::traverse_functions(std::shared_ptr<ngraph::Function> p,
}
}
void
ngraph
::
free_nodes
(
shared_ptr
<
Function
>
p
)
{
std
::
deque
<
Node
*>
sorted_list
;
traverse_nodes
(
p
,
[
&
](
shared_ptr
<
Node
>
n
)
{
sorted_list
.
push_front
(
n
.
get
());
});
for
(
Node
*
n
:
sorted_list
)
{
n
->
clear_arguments
();
}
}
void
ngraph
::
replace_node
(
std
::
shared_ptr
<
Node
>
target
,
std
::
shared_ptr
<
Node
>
replacement
)
{
if
(
target
->
is_output
())
...
...
@@ -140,24 +128,6 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
input
->
replace_output
(
replacement
->
get_outputs
().
at
(
i
));
}
}
// Fix users and arguments
replace_node_users_arguments
(
target
,
replacement
);
}
void
ngraph
::
replace_node_users_arguments
(
std
::
shared_ptr
<
Node
>
target
,
std
::
shared_ptr
<
Node
>
replacement
)
{
for
(
auto
user
:
target
->
users
())
{
auto
&
args
=
const_cast
<
ngraph
::
NodeVector
&>
(
user
->
get_arguments_FOR_GRAPH_REWRITE_ONLY
());
auto
it
=
std
::
find
(
begin
(
args
),
end
(
args
),
target
);
assert
(
it
!=
end
(
args
));
it
=
args
.
erase
(
it
);
args
.
insert
(
it
,
replacement
);
const_cast
<
std
::
multiset
<
Node
*>&>
(
replacement
->
users
()).
insert
(
user
);
}
const_cast
<
std
::
multiset
<
Node
*>&>
(
target
->
users
()).
clear
();
}
std
::
list
<
std
::
shared_ptr
<
ngraph
::
Node
>>
...
...
@@ -338,18 +308,6 @@ pair<shared_ptr<op::Result>, shared_ptr<op::Parameter>>
src_output
->
remove_input
(
dst_input
);
// Remove [0]
dst_input
->
replace_output
(
par_node
,
0
);
// Remove [0] (again), add [8], remove [1], add [9]
// Fix user / argument among src, dst and par
const_cast
<
multiset
<
Node
*>&>
(
src_node
->
users
()).
erase
(
dst_node
.
get
());
// Remove [2]
const_cast
<
multiset
<
Node
*>&>
(
par_node
->
users
()).
insert
(
dst_node
.
get
());
// Add [10]
auto
&
dst_args
=
const_cast
<
NodeVector
&>
(
dst_node
->
get_arguments_FOR_GRAPH_REWRITE_ONLY
());
auto
it
=
find
(
dst_args
.
begin
(),
dst_args
.
end
(),
src_node
);
if
(
it
==
dst_args
.
end
())
{
throw
ngraph_error
(
"src_node is not an input to dst_node"
);
}
it
=
dst_args
.
erase
(
it
);
// Remove [3]
dst_args
.
insert
(
it
,
par_node
);
// Add [11]
// Add res node
shared_ptr
<
op
::
Result
>
res_node
=
make_shared
<
op
::
Result
>
(
src_node
);
// Add [4], [5], [6], [7]
res_node
->
set_placement
(
src_node
->
get_placement
());
...
...
@@ -406,18 +364,6 @@ void ngraph::insert_new_node_between(const shared_ptr<Node>& src_node,
descriptor
::
Output
*
src_output
=
src_node
->
get_output_to
(
dst_node
);
src_output
->
remove_input
(
dst_input
);
// Remove [0]
dst_input
->
replace_output
(
new_node
,
0
);
// Remove [0] (again), add [8], remove [1], add [9]
// Fix user / argument
const_cast
<
multiset
<
Node
*>&>
(
src_node
->
users
()).
erase
(
dst_node
.
get
());
// Remove [2]
const_cast
<
multiset
<
Node
*>&>
(
new_node
->
users
()).
insert
(
dst_node
.
get
());
// Add [10]
auto
&
dst_args
=
const_cast
<
NodeVector
&>
(
dst_node
->
get_arguments_FOR_GRAPH_REWRITE_ONLY
());
auto
it
=
find
(
dst_args
.
begin
(),
dst_args
.
end
(),
src_node
);
if
(
it
==
dst_args
.
end
())
{
throw
ngraph_error
(
"src_node is not an input to dst_node"
);
}
it
=
dst_args
.
erase
(
it
);
// Remove [3]
dst_args
.
insert
(
it
,
new_node
);
// Add [11]
}
// Assert that nodes in the function is colocated and return that placement
...
...
src/ngraph/graph_util.hpp
View file @
ddcfbda8
...
...
@@ -47,13 +47,8 @@ namespace ngraph
void
traverse_functions
(
std
::
shared_ptr
<
Function
>
p
,
std
::
function
<
void
(
std
::
shared_ptr
<
Function
>
)
>
f
);
void
free_nodes
(
std
::
shared_ptr
<
Function
>
);
void
replace_node
(
std
::
shared_ptr
<
Node
>
target
,
std
::
shared_ptr
<
Node
>
replacement
);
void
replace_node_users_arguments
(
std
::
shared_ptr
<
Node
>
target
,
std
::
shared_ptr
<
Node
>
replacement
);
std
::
list
<
std
::
shared_ptr
<
Node
>>
topological_sort
(
const
std
::
list
<
std
::
shared_ptr
<
Node
>>&
nodes
);
...
...
src/ngraph/node.cpp
View file @
ddcfbda8
...
...
@@ -39,13 +39,11 @@ Node::Node(const std::string& node_type, const NodeVector& arguments)
:
m_node_type
(
node_type
)
,
m_instance_id
(
m_next_instance_id
.
fetch_add
(
1
))
,
m_unique_name
(
description
()
+
"_"
+
to_string
(
m_instance_id
))
,
m_arguments
(
arguments
)
{
// Add this node as a user of each argument.
size_t
i
=
0
;
for
(
auto
arg
:
m_
arguments
)
for
(
auto
arg
:
arguments
)
{
arg
->
m_users
.
insert
(
this
);
for
(
descriptor
::
Output
&
output
:
arg
->
get_outputs
())
{
m_inputs
.
emplace_back
(
this
,
i
++
,
output
);
...
...
@@ -146,9 +144,9 @@ void Node::set_placement(Placement placement)
std
::
shared_ptr
<
Node
>
Node
::
get_input_op
(
size_t
index
)
{
for
(
auto
arg
:
m_arguments
)
for
(
auto
&
i
:
get_inputs
()
)
{
if
(
arg
->
get_outputs
().
size
()
!=
1
)
if
(
i
.
get_output
().
get_node
()
->
get_outputs
().
size
()
!=
1
)
{
throw
"get_input_op called on an argument w/ multiple outputs"
;
}
...
...
@@ -156,7 +154,15 @@ std::shared_ptr<Node> Node::get_input_op(size_t index)
return
m_inputs
.
at
(
index
).
get_output
().
get_node
();
}
NodeVector
Node
::
get_input_ops
()
//const
Node
::~
Node
()
{
for
(
auto
&
input
:
m_inputs
)
{
input
.
get_output
().
remove_input
(
&
input
);
}
}
NodeVector
Node
::
get_input_ops
()
{
NodeVector
result
;
for
(
auto
&
i
:
get_inputs
())
...
...
@@ -165,10 +171,6 @@ NodeVector Node::get_input_ops() //const
result
.
push_back
(
i
.
get_output
().
get_node
());
}
}
if
(
m_arguments
!=
result
)
{
throw
ngraph_error
(
"Arguments aren't equal: different values"
);
}
return
result
;
}
...
...
src/ngraph/node.hpp
View file @
ddcfbda8
...
...
@@ -79,17 +79,7 @@ namespace ngraph
protected
:
Node
(
const
std
::
string
&
node_type
,
const
NodeVector
&
arguments
);
virtual
~
Node
()
{
for
(
auto
arg
:
m_arguments
)
{
arg
->
m_users
.
erase
(
this
);
}
for
(
auto
&
input
:
m_inputs
)
{
input
.
get_output
().
remove_input
(
&
input
);
}
}
virtual
~
Node
();
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
{}
public
:
...
...
@@ -98,8 +88,6 @@ namespace ngraph
const
std
::
string
&
get_friendly_name
()
const
;
const
std
::
string
&
get_name
()
const
;
void
set_name
(
const
std
::
string
&
name
);
void
clear_arguments
()
{
m_arguments
.
clear
();
}
const
std
::
multiset
<
Node
*>&
users
()
const
{
return
m_users
;
}
/// Return true if this has the same implementing class as node. This
/// will be used by the pattern matcher when comparing a pattern
/// graph against the graph.
...
...
@@ -207,7 +195,6 @@ namespace ngraph
void
add_output
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
);
std
::
string
m_node_type
;
std
::
multiset
<
Node
*>
m_users
;
size_t
m_instance_id
;
std
::
string
m_name
;
const
std
::
string
m_unique_name
;
...
...
@@ -216,13 +203,5 @@ namespace ngraph
std
::
deque
<
descriptor
::
Output
>
m_outputs
;
std
::
unordered_map
<
Node
*
,
autodiff
::
Adjoints
>
m_adjoint_map
;
Placement
m_placement
=
Placement
::
DEFAULT
;
private
:
NodeVector
m_arguments
;
//m_arguments still needs to be kept in sync with i/o since get_input_ops
//is pretty ubiquitous and might be called after the original graph was modified.
//get_input_ops uses m_arguments to check if a node view reconstruction from i/o
//is correct.
NodeVector
&
get_arguments_FOR_GRAPH_REWRITE_ONLY
()
{
return
m_arguments
;
}
};
}
src/ngraph/pass/get_output_element_elimination.cpp
View file @
ddcfbda8
...
...
@@ -43,20 +43,6 @@ bool ngraph::pass::GetOutputElementElimination::run_on_function(std::shared_ptr<
{
auto
multi
=
goe
->
get_inputs
().
at
(
0
).
get_output
().
get_node
();
input
.
replace_output
(
goe
->
get_inputs
().
at
(
goe
->
get_n
()).
get_output
());
//fix node arguments
auto
&
n_args
=
const_cast
<
ngraph
::
NodeVector
&>
(
n
->
get_arguments_FOR_GRAPH_REWRITE_ONLY
());
auto
it
=
std
::
find
(
begin
(
n_args
),
end
(
n_args
),
goe
);
if
(
it
==
end
(
n_args
))
{
throw
ngraph_error
(
"Expected to find GetOutputElement in n's inputs"
);
}
*
it
=
multi
;
//fix multi's users
const_cast
<
std
::
multiset
<
Node
*>&>
(
multi
->
users
()).
insert
(
n
.
get
());
//we don't need to fix anything w.r.t GetOutputElement as it will become unreachable
optimized
=
true
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment