Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
4aff3ec0
Commit
4aff3ec0
authored
Oct 12, 2017
by
Robert Kimball
Committed by
GitHub
Oct 12, 2017
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #193 from NervanaSystems/bob/pass_ptr
change all passes to take shared_ptr rather than naked pointer
parents
96ae6fdb
14d74e83
Hide whitespace changes
Inline
Side-by-side
Showing
34 changed files
with
128 additions
and
127 deletions
+128
-127
function.cpp
src/ngraph/function.cpp
+6
-6
function.hpp
src/ngraph/function.hpp
+7
-7
function_call.hpp
src/ngraph/ops/function_call.hpp
+1
-1
assign_tensors.cpp
src/ngraph/pass/assign_tensors.cpp
+3
-3
assign_tensors.hpp
src/ngraph/pass/assign_tensors.hpp
+1
-1
collect_functions.cpp
src/ngraph/pass/collect_functions.cpp
+7
-7
collect_functions.hpp
src/ngraph/pass/collect_functions.hpp
+1
-1
dump_sorted.cpp
src/ngraph/pass/dump_sorted.cpp
+3
-3
dump_sorted.hpp
src/ngraph/pass/dump_sorted.hpp
+1
-1
liveness.cpp
src/ngraph/pass/liveness.cpp
+3
-3
liveness.hpp
src/ngraph/pass/liveness.hpp
+1
-1
manager.cpp
src/ngraph/pass/manager.cpp
+5
-10
manager.hpp
src/ngraph/pass/manager.hpp
+0
-1
manager_state.cpp
src/ngraph/pass/manager_state.cpp
+1
-1
manager_state.hpp
src/ngraph/pass/manager_state.hpp
+2
-2
memory_layout.cpp
src/ngraph/pass/memory_layout.cpp
+2
-2
memory_layout.hpp
src/ngraph/pass/memory_layout.hpp
+1
-1
memory_visualize.cpp
src/ngraph/pass/memory_visualize.cpp
+20
-20
memory_visualize.hpp
src/ngraph/pass/memory_visualize.hpp
+10
-10
pass.hpp
src/ngraph/pass/pass.hpp
+4
-4
propagate_types.cpp
src/ngraph/pass/propagate_types.cpp
+2
-2
propagate_types.hpp
src/ngraph/pass/propagate_types.hpp
+1
-1
topological_sort.cpp
src/ngraph/pass/topological_sort.cpp
+9
-7
topological_sort.hpp
src/ngraph/pass/topological_sort.hpp
+1
-1
visualize_tree.cpp
src/ngraph/pass/visualize_tree.cpp
+6
-6
visualize_tree.hpp
src/ngraph/pass/visualize_tree.hpp
+4
-4
external_function.cpp
src/ngraph/runtime/ngvm/external_function.cpp
+7
-5
util.cpp
src/ngraph/util.cpp
+8
-7
util.hpp
src/ngraph/util.hpp
+2
-1
pass_liveness.cpp
test/pass_liveness.cpp
+1
-1
pass_manager.cpp
test/pass_manager.cpp
+1
-1
test_tools.cpp
test/test_tools.cpp
+3
-3
test_tools.hpp
test/test_tools.hpp
+1
-1
topological_sort.cpp
test/topological_sort.cpp
+3
-2
No files found.
src/ngraph/function.cpp
View file @
4aff3ec0
...
@@ -39,26 +39,26 @@ Function::Function(const std::shared_ptr<Node>& result,
...
@@ -39,26 +39,26 @@ Function::Function(const std::shared_ptr<Node>& result,
parameter
->
assign_function
(
this
,
i
++
);
parameter
->
assign_function
(
this
,
i
++
);
}
}
traverse_nodes
(
result
,
[
&
](
Node
*
node
)
{
m_ops
.
push_back
(
node
);
});
traverse_nodes
(
result
,
[
&
](
shared_ptr
<
Node
>
node
)
{
m_ops
.
push_back
(
node
);
});
}
}
void
Function
::
set_ordered_ops
(
const
std
::
list
<
Node
*
>&
ordered_ops
)
void
Function
::
set_ordered_ops
(
const
std
::
list
<
shared_ptr
<
Node
>
>&
ordered_ops
)
{
{
m_ordered_ops
=
ordered_ops
;
m_ordered_ops
=
ordered_ops
;
m_ordered_ops_valid
=
true
;
m_ordered_ops_valid
=
true
;
}
}
std
::
list
<
Node
*
>&
Function
::
get_ops
()
std
::
list
<
shared_ptr
<
Node
>
>&
Function
::
get_ops
()
{
{
return
m_ops
;
return
m_ops
;
}
}
const
std
::
list
<
Node
*
>&
Function
::
get_ops
()
const
const
std
::
list
<
shared_ptr
<
Node
>
>&
Function
::
get_ops
()
const
{
{
return
m_ops
;
return
m_ops
;
}
}
std
::
list
<
Node
*
>&
Function
::
get_ordered_ops
()
std
::
list
<
shared_ptr
<
Node
>
>&
Function
::
get_ordered_ops
()
{
{
if
(
!
m_ordered_ops_valid
)
if
(
!
m_ordered_ops_valid
)
{
{
...
@@ -67,7 +67,7 @@ std::list<Node*>& Function::get_ordered_ops()
...
@@ -67,7 +67,7 @@ std::list<Node*>& Function::get_ordered_ops()
return
m_ordered_ops
;
return
m_ordered_ops
;
}
}
const
std
::
list
<
Node
*
>&
Function
::
get_ordered_ops
()
const
const
std
::
list
<
shared_ptr
<
Node
>
>&
Function
::
get_ordered_ops
()
const
{
{
if
(
!
m_ordered_ops_valid
)
if
(
!
m_ordered_ops_valid
)
{
{
...
...
src/ngraph/function.hpp
View file @
4aff3ec0
...
@@ -47,11 +47,11 @@ namespace ngraph
...
@@ -47,11 +47,11 @@ namespace ngraph
const
std
::
shared_ptr
<
ValueType
>
get_result_type
()
const
{
return
m_result_type
;
}
const
std
::
shared_ptr
<
ValueType
>
get_result_type
()
const
{
return
m_result_type
;
}
std
::
string
get_name
()
const
;
std
::
string
get_name
()
const
;
void
set_name
(
const
std
::
string
&
name
);
void
set_name
(
const
std
::
string
&
name
);
std
::
list
<
Node
*
>&
get_ops
();
std
::
list
<
std
::
shared_ptr
<
Node
>
>&
get_ops
();
const
std
::
list
<
Node
*
>&
get_ops
()
const
;
const
std
::
list
<
std
::
shared_ptr
<
Node
>
>&
get_ops
()
const
;
std
::
list
<
Node
*
>&
get_ordered_ops
();
std
::
list
<
std
::
shared_ptr
<
Node
>
>&
get_ordered_ops
();
const
std
::
list
<
Node
*
>&
get_ordered_ops
()
const
;
const
std
::
list
<
std
::
shared_ptr
<
Node
>
>&
get_ordered_ops
()
const
;
void
set_ordered_ops
(
const
std
::
list
<
Node
*
>&
);
void
set_ordered_ops
(
const
std
::
list
<
std
::
shared_ptr
<
Node
>
>&
);
void
set_ordered_ops_valid
()
{
m_ordered_ops_valid
=
true
;
}
void
set_ordered_ops_valid
()
{
m_ordered_ops_valid
=
true
;
}
void
clear_ordered_ops_valid
()
{
m_ordered_ops_valid
=
false
;
}
void
clear_ordered_ops_valid
()
{
m_ordered_ops_valid
=
false
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Function
&
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Function
&
);
...
@@ -62,8 +62,8 @@ namespace ngraph
...
@@ -62,8 +62,8 @@ namespace ngraph
std
::
string
m_name
;
std
::
string
m_name
;
std
::
shared_ptr
<
ValueType
>
m_result_type
;
std
::
shared_ptr
<
ValueType
>
m_result_type
;
bool
m_ordered_ops_valid
;
bool
m_ordered_ops_valid
;
std
::
list
<
Node
*
>
m_ordered_ops
;
std
::
list
<
std
::
shared_ptr
<
Node
>
>
m_ordered_ops
;
std
::
list
<
Node
*
>
m_ops
;
std
::
list
<
std
::
shared_ptr
<
Node
>
>
m_ops
;
private
:
private
:
Function
(
const
Function
&
)
=
delete
;
Function
(
const
Function
&
)
=
delete
;
...
...
src/ngraph/ops/function_call.hpp
View file @
4aff3ec0
...
@@ -28,7 +28,7 @@ namespace ngraph
...
@@ -28,7 +28,7 @@ namespace ngraph
/// @param function The function to be called
/// @param function The function to be called
/// @param args The function arguments
/// @param args The function arguments
///
///
FunctionCall
(
const
std
::
shared_ptr
<
Function
>&
function
,
FunctionCall
(
std
::
shared_ptr
<
Function
>
function
,
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
args
)
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
args
)
:
Builtin
(
args
)
:
Builtin
(
args
)
,
m_function
(
function
)
,
m_function
(
function
)
...
...
src/ngraph/pass/assign_tensors.cpp
View file @
4aff3ec0
...
@@ -25,15 +25,15 @@
...
@@ -25,15 +25,15 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
bool
pass
::
AssignTensors
::
run_on_call_graph
(
list
<
Node
*
>&
nodes
)
bool
pass
::
AssignTensors
::
run_on_call_graph
(
list
<
std
::
shared_ptr
<
Node
>
>&
nodes
)
{
{
for
(
Node
*
node
:
nodes
)
for
(
shared_ptr
<
Node
>
node
:
nodes
)
{
{
try
try
{
{
// We need to set the nodes is_output state prior to call assign_tensors
// We need to set the nodes is_output state prior to call assign_tensors
// so that the output state can be passes to the constructed tensors.
// so that the output state can be passes to the constructed tensors.
if
(
node
==
get_state
().
get_functions
().
at
(
0
)
->
get_result
()
.
get
()
)
if
(
node
==
get_state
().
get_functions
().
at
(
0
)
->
get_result
())
{
{
node
->
set_is_output
();
node
->
set_is_output
();
}
}
...
...
src/ngraph/pass/assign_tensors.hpp
View file @
4aff3ec0
...
@@ -27,7 +27,7 @@ namespace ngraph
...
@@ -27,7 +27,7 @@ namespace ngraph
class
ngraph
::
pass
::
AssignTensors
:
public
CallGraphPass
class
ngraph
::
pass
::
AssignTensors
:
public
CallGraphPass
{
{
public
:
public
:
virtual
bool
run_on_call_graph
(
std
::
list
<
Node
*
>&
nodes
)
override
;
virtual
bool
run_on_call_graph
(
std
::
list
<
std
::
shared_ptr
<
Node
>
>&
nodes
)
override
;
private
:
private
:
};
};
src/ngraph/pass/collect_functions.cpp
View file @
4aff3ec0
...
@@ -24,22 +24,22 @@ using namespace std;
...
@@ -24,22 +24,22 @@ using namespace std;
using
namespace
ngraph
;
using
namespace
ngraph
;
using
namespace
ngraph
::
pass
;
using
namespace
ngraph
::
pass
;
bool
CollectFunctions
::
run_on_function
(
ngraph
::
Function
*
func
)
bool
CollectFunctions
::
run_on_function
(
shared_ptr
<
ngraph
::
Function
>
func
)
{
{
set
<
Function
*
>
functions
;
set
<
shared_ptr
<
ngraph
::
Function
>
>
functions
;
deque
<
Function
*
>
stack
;
deque
<
shared_ptr
<
ngraph
::
Function
>
>
stack
;
stack
.
push_back
(
func
);
stack
.
push_back
(
func
);
while
(
stack
.
empty
()
==
false
)
while
(
stack
.
empty
()
==
false
)
{
{
Function
*
f
=
stack
.
front
();
shared_ptr
<
ngraph
::
Function
>
f
=
stack
.
front
();
stack
.
pop_front
();
stack
.
pop_front
();
functions
.
insert
(
f
);
functions
.
insert
(
f
);
traverse_nodes
(
f
->
get_result
(),
[
&
](
Node
*
node
)
{
traverse_nodes
(
f
->
get_result
(),
[
&
](
shared_ptr
<
Node
>
node
)
{
op
::
FunctionCall
*
fc
=
dynamic_cast
<
op
::
FunctionCall
*
>
(
node
);
shared_ptr
<
op
::
FunctionCall
>
fc
=
dynamic_pointer_cast
<
op
::
FunctionCall
>
(
node
);
if
(
fc
)
if
(
fc
)
{
{
stack
.
push_back
(
fc
->
get_function
()
.
get
()
);
stack
.
push_back
(
fc
->
get_function
());
}
}
});
});
}
}
...
...
src/ngraph/pass/collect_functions.hpp
View file @
4aff3ec0
...
@@ -27,7 +27,7 @@ namespace ngraph
...
@@ -27,7 +27,7 @@ namespace ngraph
class
ngraph
::
pass
::
CollectFunctions
:
public
FunctionPass
class
ngraph
::
pass
::
CollectFunctions
:
public
FunctionPass
{
{
public
:
public
:
bool
run_on_function
(
ngraph
::
Function
*
)
override
;
bool
run_on_function
(
std
::
shared_ptr
<
ngraph
::
Function
>
)
override
;
private
:
private
:
};
};
src/ngraph/pass/dump_sorted.cpp
View file @
4aff3ec0
...
@@ -28,14 +28,14 @@ pass::DumpSorted::DumpSorted(const string& output_file)
...
@@ -28,14 +28,14 @@ pass::DumpSorted::DumpSorted(const string& output_file)
{
{
}
}
bool
pass
::
DumpSorted
::
run_on_module
(
vector
<
Function
*
>&
functions
)
bool
pass
::
DumpSorted
::
run_on_module
(
vector
<
shared_ptr
<
ngraph
::
Function
>
>&
functions
)
{
{
ofstream
out
{
m_output_file
};
ofstream
out
{
m_output_file
};
if
(
out
)
if
(
out
)
{
{
for
(
Function
*
f
:
functions
)
for
(
shared_ptr
<
Function
>
f
:
functions
)
{
{
for
(
const
Node
*
node
:
f
->
get_ordered_ops
())
for
(
const
shared_ptr
<
Node
>&
node
:
f
->
get_ordered_ops
())
{
{
out
<<
node
->
get_name
()
<<
"("
;
out
<<
node
->
get_name
()
<<
"("
;
vector
<
string
>
inputs
;
vector
<
string
>
inputs
;
...
...
src/ngraph/pass/dump_sorted.hpp
View file @
4aff3ec0
...
@@ -31,7 +31,7 @@ class ngraph::pass::DumpSorted : public ModulePass
...
@@ -31,7 +31,7 @@ class ngraph::pass::DumpSorted : public ModulePass
public
:
public
:
DumpSorted
(
const
std
::
string
&
output_file
);
DumpSorted
(
const
std
::
string
&
output_file
);
virtual
bool
run_on_module
(
std
::
vector
<
Function
*
>&
)
override
;
virtual
bool
run_on_module
(
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Function
>
>&
)
override
;
private
:
private
:
const
std
::
string
m_output_file
;
const
std
::
string
m_output_file
;
...
...
src/ngraph/pass/liveness.cpp
View file @
4aff3ec0
...
@@ -28,13 +28,13 @@ using namespace std;
...
@@ -28,13 +28,13 @@ using namespace std;
using
namespace
ngraph
;
using
namespace
ngraph
;
using
namespace
ngraph
::
descriptor
;
using
namespace
ngraph
::
descriptor
;
bool
pass
::
Liveness
::
run_on_call_graph
(
list
<
Node
*
>&
ops
)
bool
pass
::
Liveness
::
run_on_call_graph
(
list
<
shared_ptr
<
Node
>
>&
ops
)
{
{
unordered_set
<
Tensor
*>
currently_live
;
unordered_set
<
Tensor
*>
currently_live
;
for
(
auto
it
=
ops
.
rbegin
();
it
!=
ops
.
rend
();
it
++
)
for
(
auto
it
=
ops
.
rbegin
();
it
!=
ops
.
rend
();
it
++
)
{
{
Node
*
node
=
*
it
;
shared_ptr
<
Node
>
node
=
*
it
;
node
->
liveness_live_list
.
clear
();
node
->
liveness_live_list
.
clear
();
node
->
liveness_new_list
.
clear
();
node
->
liveness_new_list
.
clear
();
node
->
liveness_free_list
.
clear
();
node
->
liveness_free_list
.
clear
();
...
@@ -91,7 +91,7 @@ bool pass::Liveness::run_on_call_graph(list<Node*>& ops)
...
@@ -91,7 +91,7 @@ bool pass::Liveness::run_on_call_graph(list<Node*>& ops)
// Add outputs to live_list and remove from free_list
// Add outputs to live_list and remove from free_list
unordered_set
<
Tensor
*>
outputs
;
unordered_set
<
Tensor
*>
outputs
;
unordered_set
<
Tensor
*>
seen
;
unordered_set
<
Tensor
*>
seen
;
for
(
Node
*
node
:
ops
)
for
(
shared_ptr
<
Node
>
node
:
ops
)
{
{
for
(
Tensor
*
tensor
:
node
->
liveness_live_list
)
for
(
Tensor
*
tensor
:
node
->
liveness_live_list
)
{
{
...
...
src/ngraph/pass/liveness.hpp
View file @
4aff3ec0
...
@@ -28,7 +28,7 @@ namespace ngraph
...
@@ -28,7 +28,7 @@ namespace ngraph
class
ngraph
::
pass
::
Liveness
:
public
CallGraphPass
class
ngraph
::
pass
::
Liveness
:
public
CallGraphPass
{
{
public
:
public
:
virtual
bool
run_on_call_graph
(
std
::
list
<
Node
*
>&
)
override
;
virtual
bool
run_on_call_graph
(
std
::
list
<
std
::
shared_ptr
<
Node
>
>&
)
override
;
private
:
private
:
bool
is_temporary
(
const
descriptor
::
Tensor
&
);
bool
is_temporary
(
const
descriptor
::
Tensor
&
);
...
...
src/ngraph/pass/manager.cpp
View file @
4aff3ec0
...
@@ -38,12 +38,7 @@ void ngraph::pass::Manager::initialize_default_passes()
...
@@ -38,12 +38,7 @@ void ngraph::pass::Manager::initialize_default_passes()
void
ngraph
::
pass
::
Manager
::
run_passes
(
shared_ptr
<
Function
>
func
)
void
ngraph
::
pass
::
Manager
::
run_passes
(
shared_ptr
<
Function
>
func
)
{
{
run_passes
(
func
.
get
());
vector
<
shared_ptr
<
Function
>>
fs
=
{
func
};
}
void
ngraph
::
pass
::
Manager
::
run_passes
(
Function
*
func
)
{
vector
<
Function
*>
fs
=
{
func
};
get_state
().
set_functions
(
fs
);
get_state
().
set_functions
(
fs
);
for
(
shared_ptr
<
PassBase
>
pass
:
m_pass_list
)
for
(
shared_ptr
<
PassBase
>
pass
:
m_pass_list
)
...
@@ -59,16 +54,16 @@ void ngraph::pass::Manager::run_passes(Function* func)
...
@@ -59,16 +54,16 @@ void ngraph::pass::Manager::run_passes(Function* func)
}
}
else
if
(
function_pass
)
else
if
(
function_pass
)
{
{
for
(
Function
*
f
:
fs
)
for
(
shared_ptr
<
Function
>
f
:
fs
)
{
{
function_pass
->
run_on_function
(
f
);
function_pass
->
run_on_function
(
f
);
}
}
}
}
else
if
(
node_pass
)
else
if
(
node_pass
)
{
{
for
(
Function
*
f
:
fs
)
for
(
shared_ptr
<
Function
>
f
:
fs
)
{
{
for
(
Node
*
n
:
f
->
get_ops
())
for
(
shared_ptr
<
Node
>
n
:
f
->
get_ops
())
{
{
node_pass
->
run_on_node
(
n
);
node_pass
->
run_on_node
(
n
);
}
}
...
@@ -76,7 +71,7 @@ void ngraph::pass::Manager::run_passes(Function* func)
...
@@ -76,7 +71,7 @@ void ngraph::pass::Manager::run_passes(Function* func)
}
}
else
if
(
call_graph_pass
)
else
if
(
call_graph_pass
)
{
{
for
(
Function
*
f
:
fs
)
for
(
shared_ptr
<
Function
>
f
:
fs
)
{
{
call_graph_pass
->
run_on_call_graph
(
f
->
get_ordered_ops
());
call_graph_pass
->
run_on_call_graph
(
f
->
get_ordered_ops
());
}
}
...
...
src/ngraph/pass/manager.hpp
View file @
4aff3ec0
...
@@ -47,7 +47,6 @@ public:
...
@@ -47,7 +47,6 @@ public:
m_pass_list
.
push_back
(
pass_base
);
m_pass_list
.
push_back
(
pass_base
);
}
}
void
run_passes
(
Function
*
);
void
run_passes
(
std
::
shared_ptr
<
Function
>
);
void
run_passes
(
std
::
shared_ptr
<
Function
>
);
ManagerState
&
get_state
();
ManagerState
&
get_state
();
...
...
src/ngraph/pass/manager_state.cpp
View file @
4aff3ec0
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
vector
<
Function
*
>&
ngraph
::
pass
::
ManagerState
::
get_functions
()
vector
<
shared_ptr
<
Function
>
>&
ngraph
::
pass
::
ManagerState
::
get_functions
()
{
{
return
m_function_list
;
return
m_function_list
;
}
}
...
...
src/ngraph/pass/manager_state.hpp
View file @
4aff3ec0
...
@@ -30,7 +30,7 @@ namespace ngraph
...
@@ -30,7 +30,7 @@ namespace ngraph
class
ngraph
::
pass
::
ManagerState
class
ngraph
::
pass
::
ManagerState
{
{
public
:
public
:
std
::
vector
<
Function
*
>&
get_functions
();
std
::
vector
<
std
::
shared_ptr
<
Function
>
>&
get_functions
();
template
<
typename
T
>
template
<
typename
T
>
void
set_functions
(
const
T
&
collection
)
void
set_functions
(
const
T
&
collection
)
...
@@ -44,5 +44,5 @@ public:
...
@@ -44,5 +44,5 @@ public:
private
:
private
:
size_t
m_temporary_pool_size
=
0
;
size_t
m_temporary_pool_size
=
0
;
std
::
vector
<
Function
*
>
m_function_list
;
std
::
vector
<
std
::
shared_ptr
<
Function
>
>
m_function_list
;
};
};
src/ngraph/pass/memory_layout.cpp
View file @
4aff3ec0
...
@@ -26,10 +26,10 @@ using namespace std;
...
@@ -26,10 +26,10 @@ using namespace std;
using
namespace
ngraph
;
using
namespace
ngraph
;
using
namespace
ngraph
::
descriptor
;
using
namespace
ngraph
::
descriptor
;
bool
pass
::
MemoryLayout
::
run_on_call_graph
(
std
::
list
<
Node
*
>&
node_list
)
bool
pass
::
MemoryLayout
::
run_on_call_graph
(
std
::
list
<
std
::
shared_ptr
<
Node
>
>&
node_list
)
{
{
MemoryManager
mm
;
MemoryManager
mm
;
for
(
const
Node
*
node
:
node_list
)
for
(
shared_ptr
<
Node
>
node
:
node_list
)
{
{
for
(
Tensor
*
tensor
:
node
->
liveness_new_list
)
for
(
Tensor
*
tensor
:
node
->
liveness_new_list
)
{
{
...
...
src/ngraph/pass/memory_layout.hpp
View file @
4aff3ec0
...
@@ -33,7 +33,7 @@ namespace ngraph
...
@@ -33,7 +33,7 @@ namespace ngraph
class
ngraph
::
pass
::
MemoryLayout
:
public
CallGraphPass
class
ngraph
::
pass
::
MemoryLayout
:
public
CallGraphPass
{
{
public
:
public
:
virtual
bool
run_on_call_graph
(
std
::
list
<
Node
*
>&
)
override
;
virtual
bool
run_on_call_graph
(
std
::
list
<
std
::
shared_ptr
<
Node
>
>&
)
override
;
private
:
private
:
};
};
...
...
src/ngraph/pass/memory_visualize.cpp
View file @
4aff3ec0
...
@@ -32,13 +32,13 @@ pass::MemoryVisualize::MemoryVisualize(const string& filename)
...
@@ -32,13 +32,13 @@ pass::MemoryVisualize::MemoryVisualize(const string& filename)
{
{
}
}
bool
pass
::
MemoryVisualize
::
run_on_module
(
vector
<
Function
*
>&
functions
)
bool
pass
::
MemoryVisualize
::
run_on_module
(
vector
<
shared_ptr
<
ngraph
::
Function
>
>&
functions
)
{
{
ofstream
file
(
m_filename
);
ofstream
file
(
m_filename
);
{
{
for
(
const
Function
*
f
:
functions
)
for
(
shared_ptr
<
Function
>
f
:
functions
)
{
{
const
list
<
Node
*
>
nodes
=
f
->
get_ordered_ops
();
list
<
shared_ptr
<
Node
>
>
nodes
=
f
->
get_ordered_ops
();
file
<<
"<!DOCTYPE html>
\n
<html>
\n
"
;
file
<<
"<!DOCTYPE html>
\n
<html>
\n
"
;
file
<<
"<head>
\n
"
;
file
<<
"<head>
\n
"
;
file
<<
" <style>
\n
"
;
file
<<
" <style>
\n
"
;
...
@@ -62,7 +62,7 @@ bool pass::MemoryVisualize::run_on_module(vector<Function*>& functions)
...
@@ -62,7 +62,7 @@ bool pass::MemoryVisualize::run_on_module(vector<Function*>& functions)
file
<<
"<body>
\n
"
;
file
<<
"<body>
\n
"
;
unordered_set
<
descriptor
::
Tensor
*>
tensors
;
unordered_set
<
descriptor
::
Tensor
*>
tensors
;
size_t
temp_max_size
=
0
;
size_t
temp_max_size
=
0
;
for
(
Node
*
node
:
nodes
)
for
(
shared_ptr
<
Node
>
node
:
nodes
)
{
{
tensors
.
insert
(
node
->
liveness_live_list
.
begin
(),
node
->
liveness_live_list
.
end
());
tensors
.
insert
(
node
->
liveness_live_list
.
begin
(),
node
->
liveness_live_list
.
end
());
}
}
...
@@ -96,11 +96,11 @@ bool pass::MemoryVisualize::run_on_module(vector<Function*>& functions)
...
@@ -96,11 +96,11 @@ bool pass::MemoryVisualize::run_on_module(vector<Function*>& functions)
return
false
;
return
false
;
}
}
const
Node
*
pass
::
MemoryVisualize
::
find_largest_op
(
const
list
<
Node
*
>&
nodes
)
shared_ptr
<
Node
>
pass
::
MemoryVisualize
::
find_largest_op
(
const
list
<
shared_ptr
<
Node
>
>&
nodes
)
{
{
const
Node
*
largest_op
=
nullptr
;
shared_ptr
<
Node
>
largest_op
=
nullptr
;
size_t
largest_size
=
0
;
size_t
largest_size
=
0
;
for
(
const
Node
*
exop
:
nodes
)
for
(
shared_ptr
<
Node
>
exop
:
nodes
)
{
{
size_t
size
=
0
;
size_t
size
=
0
;
for
(
const
Tensor
*
tensor
:
exop
->
liveness_live_list
)
for
(
const
Tensor
*
tensor
:
exop
->
liveness_live_list
)
...
@@ -116,9 +116,9 @@ const Node* pass::MemoryVisualize::find_largest_op(const list<Node*>& nodes)
...
@@ -116,9 +116,9 @@ const Node* pass::MemoryVisualize::find_largest_op(const list<Node*>& nodes)
return
largest_op
;
return
largest_op
;
}
}
void
pass
::
MemoryVisualize
::
draw_tensor_weight
(
ostream
&
file
,
const
list
<
Node
*
>&
nodes
)
void
pass
::
MemoryVisualize
::
draw_tensor_weight
(
ostream
&
file
,
const
list
<
shared_ptr
<
Node
>
>&
nodes
)
{
{
const
Node
*
largest_op
=
find_largest_op
(
nodes
);
shared_ptr
<
Node
>
largest_op
=
find_largest_op
(
nodes
);
if
(
largest_op
)
if
(
largest_op
)
{
{
...
@@ -130,7 +130,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>&
...
@@ -130,7 +130,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>&
unordered_map
<
const
Tensor
*
,
size_t
>
age_list
;
unordered_map
<
const
Tensor
*
,
size_t
>
age_list
;
vector
<
const
Tensor
*>
tensor_set
;
vector
<
const
Tensor
*>
tensor_set
;
unordered_map
<
const
Tensor
*
,
const
Node
*
>
generator_op
;
unordered_map
<
const
Tensor
*
,
shared_ptr
<
Node
>
>
generator_op
;
file
<<
"<table>
\n
"
;
file
<<
"<table>
\n
"
;
file
<<
" <tr>"
;
file
<<
" <tr>"
;
file
<<
"<th align=
\"
left
\"
>tensor</th>"
;
file
<<
"<th align=
\"
left
\"
>tensor</th>"
;
...
@@ -139,7 +139,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>&
...
@@ -139,7 +139,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>&
file
<<
"<th align=
\"
right
\"
>generator weight</th>"
;
file
<<
"<th align=
\"
right
\"
>generator weight</th>"
;
file
<<
"</tr>
\n
"
;
file
<<
"</tr>
\n
"
;
size_t
i
=
0
;
size_t
i
=
0
;
for
(
const
Node
*
exop
:
nodes
)
for
(
shared_ptr
<
Node
>
exop
:
nodes
)
{
{
for
(
const
Tensor
*
tensor
:
exop
->
liveness_new_list
)
for
(
const
Tensor
*
tensor
:
exop
->
liveness_new_list
)
{
{
...
@@ -179,7 +179,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>&
...
@@ -179,7 +179,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>&
}
}
}
}
void
pass
::
MemoryVisualize
::
draw_histogram
(
ostream
&
file
,
const
list
<
Node
*
>&
nodes
)
void
pass
::
MemoryVisualize
::
draw_histogram
(
ostream
&
file
,
const
list
<
shared_ptr
<
Node
>
>&
nodes
)
{
{
size_t
stroke_width
=
14
;
size_t
stroke_width
=
14
;
size_t
text_offset
=
4
;
size_t
text_offset
=
4
;
...
@@ -188,7 +188,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
...
@@ -188,7 +188,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
size_t
scale
=
width
-
offset
;
size_t
scale
=
width
-
offset
;
size_t
line_spacing
=
stroke_width
*
1.5
;
size_t
line_spacing
=
stroke_width
*
1.5
;
size_t
line_count
=
0
;
size_t
line_count
=
0
;
for
(
const
Node
*
node
:
nodes
)
for
(
shared_ptr
<
Node
>
node
:
nodes
)
{
{
(
void
)
node
;
(
void
)
node
;
line_count
+=
1
;
line_count
+=
1
;
...
@@ -198,7 +198,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
...
@@ -198,7 +198,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
file
<<
"<svg viewBox=
\"
0 0 "
<<
width
<<
" "
<<
height
<<
"
\"
>
\n
"
;
file
<<
"<svg viewBox=
\"
0 0 "
<<
width
<<
" "
<<
height
<<
"
\"
>
\n
"
;
size_t
y
=
0
;
size_t
y
=
0
;
for
(
const
Node
*
node
:
nodes
)
for
(
shared_ptr
<
Node
>
node
:
nodes
)
{
{
float
usage
=
float
(
MemoryVisualize
::
memory_usage
(
node
));
float
usage
=
float
(
MemoryVisualize
::
memory_usage
(
node
));
float
footprint
=
float
(
MemoryVisualize
::
memory_footprint
(
node
));
float
footprint
=
float
(
MemoryVisualize
::
memory_footprint
(
node
));
...
@@ -220,14 +220,14 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
...
@@ -220,14 +220,14 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
file
<<
"</svg>
\n
"
;
file
<<
"</svg>
\n
"
;
}
}
void
pass
::
MemoryVisualize
::
draw_op_influence
(
ostream
&
file
,
const
list
<
Node
*
>&
nodes
)
void
pass
::
MemoryVisualize
::
draw_op_influence
(
ostream
&
file
,
const
list
<
shared_ptr
<
Node
>
>&
nodes
)
{
{
file
<<
"<table>
\n
"
;
file
<<
"<table>
\n
"
;
file
<<
" <tr>"
;
file
<<
" <tr>"
;
file
<<
"<th align=
\"
left
\"
>op</th>"
;
file
<<
"<th align=
\"
left
\"
>op</th>"
;
file
<<
"<th align=
\"
right
\"
>influence</th>"
;
file
<<
"<th align=
\"
right
\"
>influence</th>"
;
file
<<
"</tr>
\n
"
;
file
<<
"</tr>
\n
"
;
for
(
const
Node
*
exop
:
nodes
)
for
(
shared_ptr
<
Node
>
exop
:
nodes
)
{
{
int
weight
=
compute_op_weight
(
exop
);
int
weight
=
compute_op_weight
(
exop
);
file
<<
" <tr>"
;
file
<<
" <tr>"
;
...
@@ -237,7 +237,7 @@ void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<Node*>&
...
@@ -237,7 +237,7 @@ void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<Node*>&
}
}
}
}
int
pass
::
MemoryVisualize
::
compute_op_weight
(
const
Node
*
exop
)
int
pass
::
MemoryVisualize
::
compute_op_weight
(
const
shared_ptr
<
Node
>
exop
)
{
{
int
mass
=
0
;
int
mass
=
0
;
// for input_decl in exop.input_decls:
// for input_decl in exop.input_decls:
...
@@ -265,17 +265,17 @@ int pass::MemoryVisualize::compute_op_weight(const Node* exop)
...
@@ -265,17 +265,17 @@ int pass::MemoryVisualize::compute_op_weight(const Node* exop)
return
mass
;
return
mass
;
}
}
size_t
pass
::
MemoryVisualize
::
memory_usage
(
const
Node
*
node
)
size_t
pass
::
MemoryVisualize
::
memory_usage
(
shared_ptr
<
Node
>
node
)
{
{
return
0
;
return
0
;
}
}
size_t
pass
::
MemoryVisualize
::
memory_footprint
(
const
Node
*
node
)
size_t
pass
::
MemoryVisualize
::
memory_footprint
(
shared_ptr
<
Node
>
node
)
{
{
return
0
;
return
0
;
}
}
size_t
pass
::
MemoryVisualize
::
memory_footprint
(
const
std
::
list
<
Node
*
>&
nodes
)
size_t
pass
::
MemoryVisualize
::
memory_footprint
(
const
std
::
list
<
shared_ptr
<
Node
>
>&
nodes
)
{
{
return
0
;
return
0
;
}
}
src/ngraph/pass/memory_visualize.hpp
View file @
4aff3ec0
...
@@ -32,18 +32,18 @@ class ngraph::pass::MemoryVisualize : public ModulePass
...
@@ -32,18 +32,18 @@ class ngraph::pass::MemoryVisualize : public ModulePass
{
{
public
:
public
:
MemoryVisualize
(
const
std
::
string
&
filename
);
MemoryVisualize
(
const
std
::
string
&
filename
);
virtual
bool
run_on_module
(
std
::
vector
<
Function
*
>&
)
override
;
virtual
bool
run_on_module
(
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Function
>
>&
)
override
;
private
:
private
:
const
Node
*
find_largest_op
(
const
std
::
list
<
Node
*
>&
nodes
);
std
::
shared_ptr
<
Node
>
find_largest_op
(
const
std
::
list
<
std
::
shared_ptr
<
Node
>
>&
nodes
);
void
draw_tensor_weight
(
std
::
ostream
&
file
,
const
std
::
list
<
Node
*
>&
nodes
);
void
draw_tensor_weight
(
std
::
ostream
&
file
,
const
std
::
list
<
std
::
shared_ptr
<
Node
>
>&
nodes
);
void
draw_histogram
(
std
::
ostream
&
file
,
const
std
::
list
<
Node
*
>&
nodes
);
void
draw_histogram
(
std
::
ostream
&
file
,
const
std
::
list
<
std
::
shared_ptr
<
Node
>
>&
nodes
);
void
draw_op_influence
(
std
::
ostream
&
file
,
const
std
::
list
<
Node
*
>&
nodes
);
void
draw_op_influence
(
std
::
ostream
&
file
,
const
std
::
list
<
std
::
shared_ptr
<
Node
>
>&
nodes
);
int
compute_op_weight
(
const
Node
*
exop
);
int
compute_op_weight
(
std
::
shared_ptr
<
Node
>
exop
);
static
size_t
memory_usage
(
const
Node
*
);
static
size_t
memory_usage
(
std
::
shared_ptr
<
Node
>
);
static
size_t
memory_footprint
(
const
Node
*
);
static
size_t
memory_footprint
(
std
::
shared_ptr
<
Node
>
);
static
size_t
memory_footprint
(
const
std
::
list
<
Node
*
>&
);
static
size_t
memory_footprint
(
const
std
::
list
<
std
::
shared_ptr
<
Node
>
>&
);
const
std
::
string
m_filename
;
const
std
::
string
m_filename
;
};
};
src/ngraph/pass/pass.hpp
View file @
4aff3ec0
...
@@ -53,26 +53,26 @@ class ngraph::pass::ModulePass : public PassBase
...
@@ -53,26 +53,26 @@ class ngraph::pass::ModulePass : public PassBase
{
{
public
:
public
:
virtual
~
ModulePass
()
{}
virtual
~
ModulePass
()
{}
virtual
bool
run_on_module
(
std
::
vector
<
ngraph
::
Function
*
>&
)
=
0
;
virtual
bool
run_on_module
(
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Function
>
>&
)
=
0
;
};
};
class
ngraph
::
pass
::
FunctionPass
:
public
PassBase
class
ngraph
::
pass
::
FunctionPass
:
public
PassBase
{
{
public
:
public
:
virtual
~
FunctionPass
()
{}
virtual
~
FunctionPass
()
{}
virtual
bool
run_on_function
(
ngraph
::
Function
*
)
=
0
;
virtual
bool
run_on_function
(
std
::
shared_ptr
<
ngraph
::
Function
>
)
=
0
;
};
};
class
ngraph
::
pass
::
NodePass
:
public
PassBase
class
ngraph
::
pass
::
NodePass
:
public
PassBase
{
{
public
:
public
:
virtual
~
NodePass
()
{}
virtual
~
NodePass
()
{}
virtual
bool
run_on_node
(
ngraph
::
Node
*
)
=
0
;
virtual
bool
run_on_node
(
std
::
shared_ptr
<
ngraph
::
Node
>
)
=
0
;
};
};
class
ngraph
::
pass
::
CallGraphPass
:
public
PassBase
class
ngraph
::
pass
::
CallGraphPass
:
public
PassBase
{
{
public
:
public
:
virtual
~
CallGraphPass
()
{}
virtual
~
CallGraphPass
()
{}
virtual
bool
run_on_call_graph
(
std
::
list
<
ngraph
::
Node
*
>&
)
=
0
;
virtual
bool
run_on_call_graph
(
std
::
list
<
std
::
shared_ptr
<
ngraph
::
Node
>
>&
)
=
0
;
};
};
src/ngraph/pass/propagate_types.cpp
View file @
4aff3ec0
...
@@ -20,9 +20,9 @@
...
@@ -20,9 +20,9 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
bool
pass
::
PropagateTypes
::
run_on_call_graph
(
list
<
Node
*
>&
nodes
)
bool
pass
::
PropagateTypes
::
run_on_call_graph
(
list
<
shared_ptr
<
Node
>
>&
nodes
)
{
{
for
(
Node
*
node
:
nodes
)
for
(
shared_ptr
<
Node
>
node
:
nodes
)
{
{
try
try
{
{
...
...
src/ngraph/pass/propagate_types.hpp
View file @
4aff3ec0
...
@@ -27,7 +27,7 @@ namespace ngraph
...
@@ -27,7 +27,7 @@ namespace ngraph
class
ngraph
::
pass
::
PropagateTypes
:
public
CallGraphPass
class
ngraph
::
pass
::
PropagateTypes
:
public
CallGraphPass
{
{
public
:
public
:
virtual
bool
run_on_call_graph
(
std
::
list
<
Node
*
>&
)
override
;
virtual
bool
run_on_call_graph
(
std
::
list
<
std
::
shared_ptr
<
Node
>
>&
)
override
;
private
:
private
:
};
};
src/ngraph/pass/topological_sort.cpp
View file @
4aff3ec0
...
@@ -25,24 +25,26 @@
...
@@ -25,24 +25,26 @@
using
namespace
ngraph
;
using
namespace
ngraph
;
using
namespace
std
;
using
namespace
std
;
bool
ngraph
::
pass
::
TopologicalSort
::
run_on_function
(
ngraph
::
Function
*
func
)
bool
ngraph
::
pass
::
TopologicalSort
::
run_on_function
(
shared_ptr
<
ngraph
::
Function
>
func
)
{
{
list
<
Node
*
>
result_list
;
list
<
shared_ptr
<
Node
>
>
result_list
;
deque
<
Node
*>
independent_nodes
;
deque
<
Node
*>
independent_nodes
;
unordered_map
<
Node
*
,
size_t
>
node_depencency_count
;
unordered_map
<
const
Node
*
,
size_t
>
node_depencency_count
;
unordered_map
<
Node
*
,
shared_ptr
<
Node
>>
node_map
;
traverse_nodes
(
func
->
get_result
(),
[
&
](
Node
*
node
)
{
traverse_nodes
(
func
->
get_result
(),
[
&
](
shared_ptr
<
Node
>
node
)
{
node_depencency_count
[
node
]
=
node
->
get_arguments
().
size
();
node_map
[
node
.
get
()]
=
node
;
node_depencency_count
[
node
.
get
()]
=
node
->
get_arguments
().
size
();
if
(
node
->
get_arguments
().
size
()
==
0
)
if
(
node
->
get_arguments
().
size
()
==
0
)
{
{
independent_nodes
.
push_back
(
node
);
independent_nodes
.
push_back
(
node
.
get
()
);
}
}
});
});
while
(
independent_nodes
.
size
()
>
0
)
while
(
independent_nodes
.
size
()
>
0
)
{
{
auto
independent_node
=
independent_nodes
.
front
();
auto
independent_node
=
independent_nodes
.
front
();
result_list
.
push_back
(
independent_node
);
result_list
.
push_back
(
node_map
[
independent_node
]
);
independent_nodes
.
pop_front
();
independent_nodes
.
pop_front
();
for
(
auto
user
:
independent_node
->
users
())
for
(
auto
user
:
independent_node
->
users
())
...
...
src/ngraph/pass/topological_sort.hpp
View file @
4aff3ec0
...
@@ -31,5 +31,5 @@ class ngraph::pass::TopologicalSort : public FunctionPass
...
@@ -31,5 +31,5 @@ class ngraph::pass::TopologicalSort : public FunctionPass
{
{
public
:
public
:
TopologicalSort
()
{}
TopologicalSort
()
{}
bool
run_on_function
(
ngraph
::
Function
*
)
override
;
bool
run_on_function
(
std
::
shared_ptr
<
ngraph
::
Function
>
)
override
;
};
};
src/ngraph/pass/visualize_tree.cpp
View file @
4aff3ec0
...
@@ -23,15 +23,15 @@
...
@@ -23,15 +23,15 @@
using
namespace
ngraph
;
using
namespace
ngraph
;
using
namespace
std
;
using
namespace
std
;
bool
pass
::
VisualizeTree
::
run_on_module
(
vector
<
ngraph
::
Function
*
>&
functions
)
bool
pass
::
VisualizeTree
::
run_on_module
(
vector
<
shared_ptr
<
ngraph
::
Function
>
>&
functions
)
{
{
for
(
Function
*
f
:
functions
)
for
(
shared_ptr
<
Function
>
f
:
functions
)
{
{
// map<size_t, list<node_ptr>> dependent_nodes;
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes
(
f
->
get_result
(),
[
&
](
Node
*
node
)
{
traverse_nodes
(
f
->
get_result
(),
[
&
](
shared_ptr
<
Node
>
node
)
{
for
(
auto
arg
:
node
->
get_arguments
())
for
(
auto
arg
:
node
->
get_arguments
())
{
{
m_ss
<<
add_attributes
(
arg
.
get
()
);
m_ss
<<
add_attributes
(
arg
);
m_ss
<<
add_attributes
(
node
);
m_ss
<<
add_attributes
(
node
);
m_ss
<<
" "
<<
arg
->
get_name
()
<<
" -> "
<<
node
->
get_name
();
m_ss
<<
" "
<<
arg
->
get_name
()
<<
" -> "
<<
node
->
get_name
();
m_ss
<<
";
\n
"
;
m_ss
<<
";
\n
"
;
...
@@ -49,7 +49,7 @@ pass::VisualizeTree::VisualizeTree(const string& file_name)
...
@@ -49,7 +49,7 @@ pass::VisualizeTree::VisualizeTree(const string& file_name)
{
{
}
}
std
::
string
pass
::
VisualizeTree
::
add_attributes
(
const
Node
*
node
)
std
::
string
pass
::
VisualizeTree
::
add_attributes
(
shared_ptr
<
Node
>
node
)
{
{
string
rc
;
string
rc
;
if
(
!
contains
(
m_nodes_with_attributes
,
node
))
if
(
!
contains
(
m_nodes_with_attributes
,
node
))
...
@@ -60,7 +60,7 @@ std::string pass::VisualizeTree::add_attributes(const Node* node)
...
@@ -60,7 +60,7 @@ std::string pass::VisualizeTree::add_attributes(const Node* node)
return
rc
;
return
rc
;
}
}
std
::
string
pass
::
VisualizeTree
::
get_attributes
(
const
Node
*
node
)
std
::
string
pass
::
VisualizeTree
::
get_attributes
(
shared_ptr
<
Node
>
node
)
{
{
stringstream
ss
;
stringstream
ss
;
if
(
node
->
is_parameter
())
if
(
node
->
is_parameter
())
...
...
src/ngraph/pass/visualize_tree.hpp
View file @
4aff3ec0
...
@@ -32,14 +32,14 @@ class ngraph::pass::VisualizeTree : public ModulePass
...
@@ -32,14 +32,14 @@ class ngraph::pass::VisualizeTree : public ModulePass
{
{
public
:
public
:
VisualizeTree
(
const
std
::
string
&
file_name
);
VisualizeTree
(
const
std
::
string
&
file_name
);
bool
run_on_module
(
std
::
vector
<
ngraph
::
Function
*
>&
)
override
;
bool
run_on_module
(
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Function
>
>&
)
override
;
private
:
private
:
std
::
string
add_attributes
(
const
Node
*
node
);
std
::
string
add_attributes
(
std
::
shared_ptr
<
Node
>
node
);
std
::
string
get_attributes
(
const
Node
*
node
);
std
::
string
get_attributes
(
std
::
shared_ptr
<
Node
>
node
);
void
render
()
const
;
void
render
()
const
;
std
::
stringstream
m_ss
;
std
::
stringstream
m_ss
;
std
::
string
m_name
;
std
::
string
m_name
;
std
::
set
<
const
Node
*
>
m_nodes_with_attributes
;
std
::
set
<
std
::
shared_ptr
<
Node
>
>
m_nodes_with_attributes
;
};
};
src/ngraph/runtime/ngvm/external_function.cpp
View file @
4aff3ec0
...
@@ -961,7 +961,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
...
@@ -961,7 +961,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Turn this into a pass
// Turn this into a pass
// Assign layouts
// Assign layouts
// For now, just make everyone row-major.
// For now, just make everyone row-major.
for
(
const
Node
*
node
:
m_function
->
get_ordered_ops
())
for
(
shared_ptr
<
Node
>
node
:
m_function
->
get_ordered_ops
())
{
{
for
(
const
descriptor
::
Output
&
output
:
node
->
get_outputs
())
for
(
const
descriptor
::
Output
&
output
:
node
->
get_outputs
())
{
{
...
@@ -998,7 +998,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
...
@@ -998,7 +998,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
m_n_outputs
=
tensor_index
.
size
()
-
m_n_inputs
;
m_n_outputs
=
tensor_index
.
size
()
-
m_n_inputs
;
// All remaining tensor views
// All remaining tensor views
for
(
const
Node
*
node
:
m_function
->
get_ordered_ops
())
for
(
shared_ptr
<
Node
>
node
:
m_function
->
get_ordered_ops
())
{
{
for
(
const
descriptor
::
Output
&
output
:
node
->
get_outputs
())
for
(
const
descriptor
::
Output
&
output
:
node
->
get_outputs
())
{
{
...
@@ -1014,9 +1014,11 @@ void ExternalFunction::compile(FunctionMap& function_map)
...
@@ -1014,9 +1014,11 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Now we build the eigen-VM instructions
// Now we build the eigen-VM instructions
auto
op_map
=
get_op_map
();
auto
op_map
=
get_op_map
();
for
(
const
Node
*
node
:
m_function
->
get_ordered_ops
())
for
(
shared_ptr
<
Node
>
node
:
m_function
->
get_ordered_ops
())
{
{
auto
handler_it
=
op_map
.
find
(
type_index
(
typeid
(
*
node
)));
auto
&
n
=
*
node
;
// Work around a compiler warning (*node inside typeid may have effects
// with shared pointers, which is fine here but clang doesn't like it.)
auto
handler_it
=
op_map
.
find
(
type_index
(
typeid
(
n
)));
if
(
handler_it
==
op_map
.
end
())
if
(
handler_it
==
op_map
.
end
())
{
{
throw
ngraph_error
(
"Unhandled op during code generation"
);
throw
ngraph_error
(
"Unhandled op during code generation"
);
...
@@ -1034,7 +1036,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
...
@@ -1034,7 +1036,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
auto
tv
=
output
.
get_tensor_view
();
auto
tv
=
output
.
get_tensor_view
();
out
.
push_back
({
tensor_index
.
at
(
tv
),
tv
});
out
.
push_back
({
tensor_index
.
at
(
tv
),
tv
});
}
}
handler_it
->
second
(
node
,
this
,
function_map
,
in
,
out
);
handler_it
->
second
(
node
.
get
()
,
this
,
function_map
,
in
,
out
);
}
}
m_instructions
->
push_back
(
make_shared
<
eigen
::
ReturnInstruction
>
());
m_instructions
->
push_back
(
make_shared
<
eigen
::
ReturnInstruction
>
());
m_is_compiled
=
true
;
m_is_compiled
=
true
;
...
...
src/ngraph/util.cpp
View file @
4aff3ec0
...
@@ -137,15 +137,16 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list)
...
@@ -137,15 +137,16 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list)
return
seed
;
return
seed
;
}
}
void
ngraph
::
traverse_nodes
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
p
,
std
::
function
<
void
(
Node
*
)
>
f
)
void
ngraph
::
traverse_nodes
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
p
,
std
::
function
<
void
(
shared_ptr
<
Node
>
)
>
f
)
{
{
std
::
unordered_set
<
Node
*
>
instances_seen
;
std
::
unordered_set
<
shared_ptr
<
Node
>
>
instances_seen
;
deque
<
Node
*
>
stack
;
deque
<
shared_ptr
<
Node
>
>
stack
;
stack
.
push_front
(
p
.
get
()
);
stack
.
push_front
(
p
);
while
(
stack
.
size
()
>
0
)
while
(
stack
.
size
()
>
0
)
{
{
Node
*
n
=
stack
.
front
();
shared_ptr
<
Node
>
n
=
stack
.
front
();
if
(
instances_seen
.
find
(
n
)
==
instances_seen
.
end
())
if
(
instances_seen
.
find
(
n
)
==
instances_seen
.
end
())
{
{
instances_seen
.
insert
(
n
);
instances_seen
.
insert
(
n
);
...
@@ -154,7 +155,7 @@ void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p, std::functio
...
@@ -154,7 +155,7 @@ void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p, std::functio
stack
.
pop_front
();
stack
.
pop_front
();
for
(
auto
arg
:
n
->
get_arguments
())
for
(
auto
arg
:
n
->
get_arguments
())
{
{
stack
.
push_front
(
arg
.
get
()
);
stack
.
push_front
(
arg
);
}
}
}
}
}
}
...
@@ -163,7 +164,7 @@ void ngraph::free_nodes(shared_ptr<Node> p)
...
@@ -163,7 +164,7 @@ void ngraph::free_nodes(shared_ptr<Node> p)
{
{
std
::
deque
<
Node
*>
sorted_list
;
std
::
deque
<
Node
*>
sorted_list
;
traverse_nodes
(
p
,
[
&
](
Node
*
n
)
{
sorted_list
.
push_front
(
n
);
});
traverse_nodes
(
p
,
[
&
](
shared_ptr
<
Node
>
n
)
{
sorted_list
.
push_front
(
n
.
get
()
);
});
for
(
Node
*
n
:
sorted_list
)
for
(
Node
*
n
:
sorted_list
)
{
{
...
...
src/ngraph/util.hpp
View file @
4aff3ec0
...
@@ -195,7 +195,8 @@ namespace ngraph
...
@@ -195,7 +195,8 @@ namespace ngraph
return
a
*
b
;
return
a
*
b
;
}
}
void
traverse_nodes
(
const
std
::
shared_ptr
<
Node
>&
p
,
std
::
function
<
void
(
Node
*
)
>
f
);
void
traverse_nodes
(
const
std
::
shared_ptr
<
Node
>&
p
,
std
::
function
<
void
(
std
::
shared_ptr
<
Node
>
)
>
f
);
void
free_nodes
(
std
::
shared_ptr
<
Node
>
);
void
free_nodes
(
std
::
shared_ptr
<
Node
>
);
}
// end namespace ngraph
}
// end namespace ngraph
test/pass_liveness.cpp
View file @
4aff3ec0
...
@@ -50,7 +50,7 @@ TEST(pass, liveness)
...
@@ -50,7 +50,7 @@ TEST(pass, liveness)
pass_manager
.
register_pass
<
pass
::
DumpSorted
>
(
dump_file
);
pass_manager
.
register_pass
<
pass
::
DumpSorted
>
(
dump_file
);
shared_ptr
<
Function
>
func
=
make_test_graph
();
shared_ptr
<
Function
>
func
=
make_test_graph
();
pass_manager
.
run_passes
(
func
.
get
()
);
pass_manager
.
run_passes
(
func
);
auto
sorted
=
func
->
get_ordered_ops
();
auto
sorted
=
func
->
get_ordered_ops
();
// for (const Node* node : sorted)
// for (const Node* node : sorted)
...
...
test/pass_manager.cpp
View file @
4aff3ec0
...
@@ -39,7 +39,7 @@ TEST(pass_manager, add)
...
@@ -39,7 +39,7 @@ TEST(pass_manager, add)
auto
graph
=
make_test_graph
();
auto
graph
=
make_test_graph
();
size_t
node_count
=
get_node_count
(
graph
->
get_result
());
size_t
node_count
=
get_node_count
(
graph
->
get_result
());
pass_manager
.
run_passes
(
graph
.
get
()
);
pass_manager
.
run_passes
(
graph
);
auto
sorted
=
graph
->
get_ordered_ops
();
auto
sorted
=
graph
->
get_ordered_ops
();
EXPECT_EQ
(
node_count
,
sorted
.
size
());
EXPECT_EQ
(
node_count
,
sorted
.
size
());
EXPECT_TRUE
(
validate_list
(
sorted
));
EXPECT_TRUE
(
validate_list
(
sorted
));
...
...
test/test_tools.cpp
View file @
4aff3ec0
...
@@ -23,7 +23,7 @@ using namespace ngraph;
...
@@ -23,7 +23,7 @@ using namespace ngraph;
// This function traverses the list of ops and verifies that each op's dependencies (its inputs)
// This function traverses the list of ops and verifies that each op's dependencies (its inputs)
// is located earlier in the list. That is enough to be valid
// is located earlier in the list. That is enough to be valid
bool
validate_list
(
const
list
<
Node
*
>&
nodes
)
bool
validate_list
(
const
list
<
shared_ptr
<
Node
>
>&
nodes
)
{
{
bool
rc
=
true
;
bool
rc
=
true
;
for
(
auto
it
=
nodes
.
rbegin
();
it
!=
nodes
.
rend
();
it
++
)
for
(
auto
it
=
nodes
.
rbegin
();
it
!=
nodes
.
rend
();
it
++
)
...
@@ -39,7 +39,7 @@ bool validate_list(const list<Node*>& nodes)
...
@@ -39,7 +39,7 @@ bool validate_list(const list<Node*>& nodes)
for
(;
tmp
!=
nodes
.
rend
();
tmp
++
)
for
(;
tmp
!=
nodes
.
rend
();
tmp
++
)
{
{
auto
dep_tmp
=
*
tmp
;
auto
dep_tmp
=
*
tmp
;
auto
found
=
find
(
dependencies
.
begin
(),
dependencies
.
end
(),
dep_tmp
);
auto
found
=
find
(
dependencies
.
begin
(),
dependencies
.
end
(),
dep_tmp
.
get
()
);
if
(
found
!=
dependencies
.
end
())
if
(
found
!=
dependencies
.
end
())
{
{
dependencies
.
erase
(
found
);
dependencies
.
erase
(
found
);
...
@@ -82,6 +82,6 @@ shared_ptr<Function> make_test_graph()
...
@@ -82,6 +82,6 @@ shared_ptr<Function> make_test_graph()
size_t
get_node_count
(
std
::
shared_ptr
<
Node
>
n
)
size_t
get_node_count
(
std
::
shared_ptr
<
Node
>
n
)
{
{
size_t
node_count
=
0
;
size_t
node_count
=
0
;
traverse_nodes
(
n
,
[
&
](
const
Node
*
node
)
{
node_count
++
;
});
traverse_nodes
(
n
,
[
&
](
shared_ptr
<
Node
>
node
)
{
node_count
++
;
});
return
node_count
;
return
node_count
;
}
}
test/test_tools.hpp
View file @
4aff3ec0
...
@@ -23,6 +23,6 @@ namespace ngraph
...
@@ -23,6 +23,6 @@ namespace ngraph
class
Function
;
class
Function
;
}
}
bool
validate_list
(
const
std
::
list
<
ngraph
::
Node
*
>&
nodes
);
bool
validate_list
(
const
std
::
list
<
std
::
shared_ptr
<
ngraph
::
Node
>
>&
nodes
);
std
::
shared_ptr
<
ngraph
::
Function
>
make_test_graph
();
std
::
shared_ptr
<
ngraph
::
Function
>
make_test_graph
();
size_t
get_node_count
(
std
::
shared_ptr
<
ngraph
::
Node
>
n
);
size_t
get_node_count
(
std
::
shared_ptr
<
ngraph
::
Node
>
n
);
test/topological_sort.cpp
View file @
4aff3ec0
...
@@ -126,7 +126,7 @@ TEST(benchmark, topological_sort)
...
@@ -126,7 +126,7 @@ TEST(benchmark, topological_sort)
NGRAPH_INFO
<<
"topological sort took "
<<
timer
.
get_milliseconds
()
<<
"ms"
;
NGRAPH_INFO
<<
"topological sort took "
<<
timer
.
get_milliseconds
()
<<
"ms"
;
size_t
node_count
=
0
;
size_t
node_count
=
0
;
traverse_nodes
(
result
,
[
&
](
const
Node
*
node
)
{
node_count
++
;
});
traverse_nodes
(
result
,
[
&
](
shared_ptr
<
Node
>
node
)
{
node_count
++
;
});
NGRAPH_INFO
<<
"node count "
<<
node_count
;
NGRAPH_INFO
<<
"node count "
<<
node_count
;
...
@@ -135,6 +135,7 @@ TEST(benchmark, topological_sort)
...
@@ -135,6 +135,7 @@ TEST(benchmark, topological_sort)
timer
.
stop
();
timer
.
stop
();
NGRAPH_INFO
<<
"delete nodes took "
<<
timer
.
get_milliseconds
()
<<
"ms"
;
NGRAPH_INFO
<<
"delete nodes took "
<<
timer
.
get_milliseconds
()
<<
"ms"
;
}
}
TEST
(
topological_sort
,
collect_functions
)
TEST
(
topological_sort
,
collect_functions
)
{
{
// First create "f(A,B,C) = (A+B)*C".
// First create "f(A,B,C) = (A+B)*C".
...
@@ -174,7 +175,7 @@ TEST(topological_sort, collect_functions)
...
@@ -174,7 +175,7 @@ TEST(topological_sort, collect_functions)
set
<
string
>
expected
=
{
"f"
,
"g"
,
"h"
};
set
<
string
>
expected
=
{
"f"
,
"g"
,
"h"
};
auto
functions
=
pass_manager
.
get_state
().
get_functions
();
auto
functions
=
pass_manager
.
get_state
().
get_functions
();
vector
<
string
>
fnames
;
vector
<
string
>
fnames
;
for
(
Function
*
func
:
functions
)
for
(
shared_ptr
<
Function
>
func
:
functions
)
{
{
fnames
.
push_back
(
func
->
get_name
());
fnames
.
push_back
(
func
->
get_name
());
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment