Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
118e0679
Commit
118e0679
authored
Oct 06, 2017
by
Jai Menon
Committed by
GitHub
Oct 06, 2017
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into jmenon/codegen
parents
c2ff1508
c3dfdf5a
Show whitespace changes
Inline
Side-by-side
Showing
36 changed files
with
474 additions
and
403 deletions
+474
-403
CMakeLists.txt
src/ngraph/CMakeLists.txt
+2
-3
function.cpp
src/ngraph/function.cpp
+76
-2
function.hpp
src/ngraph/function.hpp
+25
-2
node.cpp
src/ngraph/node.cpp
+26
-0
node.hpp
src/ngraph/node.hpp
+2
-0
assign_tensors.cpp
src/ngraph/pass/assign_tensors.cpp
+3
-21
assign_tensors.hpp
src/ngraph/pass/assign_tensors.hpp
+3
-5
collect_functions.cpp
src/ngraph/pass/collect_functions.cpp
+34
-29
collect_functions.hpp
src/ngraph/pass/collect_functions.hpp
+19
-1
dump_sorted.cpp
src/ngraph/pass/dump_sorted.cpp
+6
-3
dump_sorted.hpp
src/ngraph/pass/dump_sorted.hpp
+3
-3
liveness.cpp
src/ngraph/pass/liveness.cpp
+1
-19
liveness.hpp
src/ngraph/pass/liveness.hpp
+3
-5
manager.cpp
src/ngraph/pass/manager.cpp
+70
-63
manager.hpp
src/ngraph/pass/manager.hpp
+7
-36
manager_state.cpp
src/ngraph/pass/manager_state.cpp
+17
-20
manager_state.hpp
src/ngraph/pass/manager_state.hpp
+16
-9
memory_layout.cpp
src/ngraph/pass/memory_layout.cpp
+1
-19
memory_layout.hpp
src/ngraph/pass/memory_layout.hpp
+3
-5
memory_visualize.cpp
src/ngraph/pass/memory_visualize.cpp
+8
-8
memory_visualize.hpp
src/ngraph/pass/memory_visualize.hpp
+3
-5
pass.cpp
src/ngraph/pass/pass.cpp
+2
-2
pass.hpp
src/ngraph/pass/pass.hpp
+42
-2
propagate_types.cpp
src/ngraph/pass/propagate_types.cpp
+2
-2
propagate_types.hpp
src/ngraph/pass/propagate_types.hpp
+3
-3
topological_sort.cpp
src/ngraph/pass/topological_sort.cpp
+7
-5
topological_sort.hpp
src/ngraph/pass/topological_sort.hpp
+3
-3
tree_pass.cpp
src/ngraph/pass/tree_pass.cpp
+0
-15
visualize_tree.cpp
src/ngraph/pass/visualize_tree.cpp
+10
-5
visualize_tree.hpp
src/ngraph/pass/visualize_tree.hpp
+3
-3
external_function.cpp
src/ngraph/runtime/external_function.cpp
+3
-3
visualize.cpp
src/ngraph/visualize.cpp
+0
-92
pass_liveness.cpp
test/pass_liveness.cpp
+1
-1
pass_manager.cpp
test/pass_manager.cpp
+18
-5
pass_memory_layout.cpp
test/pass_memory_layout.cpp
+1
-1
topological_sort.cpp
test/topological_sort.cpp
+51
-3
No files found.
src/ngraph/CMakeLists.txt
View file @
118e0679
...
...
@@ -39,16 +39,16 @@ set (SRC
ops/unary_elementwise_arithmetic.cpp
ops/unary_elementwise_builtin.cpp
pass/assign_tensors.cpp
pass/c
all_pas
s.cpp
pass/c
ollect_function
s.cpp
pass/dump_sorted.cpp
pass/liveness.cpp
pass/manager.cpp
pass/manager_state.cpp
pass/memory_layout.cpp
pass/memory_visualize.cpp
pass/pass.cpp
pass/propagate_types.cpp
pass/topological_sort.cpp
pass/tree_pass.cpp
pass/visualize_tree.cpp
runtime/call_frame.cpp
runtime/external_function.cpp
...
...
@@ -58,7 +58,6 @@ set (SRC
types/element_type.cpp
types/type.cpp
util.cpp
visualize.cpp
)
# find_program (GRAPHVIZ dot)
...
...
src/ngraph/function.cpp
View file @
118e0679
...
...
@@ -15,21 +15,95 @@
#include <memory>
#include "ngraph/function.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
using
namespace
ngraph
;
size_t
Function
::
m_next_instance_id
=
0
;
Function
::
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
shared_ptr
<
ValueType
>&
result_type
,
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
)
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
,
const
std
::
string
&
name
)
:
m_result
(
result
)
,
m_parameters
(
parameters
)
,
m_name
(
"Function"
)
,
m_name
(
name
)
,
m_result_type
(
result_type
)
,
m_ordered_ops_valid
(
false
)
,
m_instance_id
(
m_next_instance_id
++
)
{
size_t
i
=
0
;
for
(
auto
parameter
:
parameters
)
{
parameter
->
assign_function
(
this
,
i
++
);
}
traverse_nodes
(
result
,
[
&
](
Node
*
node
)
{
m_ops
.
push_back
(
node
);
});
}
void
Function
::
set_ordered_ops
(
const
std
::
list
<
Node
*>&
ordered_ops
)
{
m_ordered_ops
=
ordered_ops
;
m_ordered_ops_valid
=
true
;
}
std
::
list
<
Node
*>&
Function
::
get_ops
()
{
return
m_ops
;
}
const
std
::
list
<
Node
*>&
Function
::
get_ops
()
const
{
return
m_ops
;
}
std
::
list
<
Node
*>&
Function
::
get_ordered_ops
()
{
if
(
!
m_ordered_ops_valid
)
{
throw
ngraph_error
(
"Access to ordered ops invalid"
);
}
return
m_ordered_ops
;
}
const
std
::
list
<
Node
*>&
Function
::
get_ordered_ops
()
const
{
if
(
!
m_ordered_ops_valid
)
{
throw
ngraph_error
(
"Access to ordered ops invalid"
);
}
return
m_ordered_ops
;
}
std
::
string
Function
::
get_name
()
const
{
string
rc
;
if
(
m_name
.
empty
())
{
rc
=
"Function_"
+
to_string
(
m_instance_id
);
}
else
{
rc
=
m_name
;
}
return
rc
;
}
void
Function
::
set_name
(
const
string
&
name
)
{
if
(
m_name
.
empty
())
{
m_name
=
name
;
}
else
{
throw
ngraph_error
(
"Function name may be set exactly once"
);
}
}
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
Function
&
f
)
{
out
<<
"Function("
<<
f
.
get_name
()
<<
")"
;
return
out
;
}
src/ngraph/function.hpp
View file @
118e0679
...
...
@@ -15,11 +15,13 @@
#pragma once
#include <initializer_list>
#include <list>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/ops/parameter.hpp"
...
...
@@ -34,7 +36,8 @@ namespace ngraph
public
:
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
shared_ptr
<
ValueType
>&
result_type
,
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
);
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
,
const
std
::
string
&
name
=
""
);
std
::
shared_ptr
<
Node
>
get_result
()
{
return
m_result
;
}
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>
get_parameters
()
const
...
...
@@ -42,11 +45,31 @@ namespace ngraph
return
m_parameters
;
}
const
std
::
shared_ptr
<
ValueType
>
get_result_type
()
const
{
return
m_result_type
;
}
std
::
string
get_name
()
const
{
return
m_name
;
}
std
::
string
get_name
()
const
;
void
set_name
(
const
std
::
string
&
name
);
std
::
list
<
Node
*>&
get_ops
();
const
std
::
list
<
Node
*>&
get_ops
()
const
;
std
::
list
<
Node
*>&
get_ordered_ops
();
const
std
::
list
<
Node
*>&
get_ordered_ops
()
const
;
void
set_ordered_ops
(
const
std
::
list
<
Node
*>&
);
void
set_ordered_ops_valid
()
{
m_ordered_ops_valid
=
true
;
}
void
clear_ordered_ops_valid
()
{
m_ordered_ops_valid
=
false
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Function
&
);
protected
:
std
::
shared_ptr
<
Node
>
m_result
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
op
::
Parameter
>>
m_parameters
;
std
::
string
m_name
;
std
::
shared_ptr
<
ValueType
>
m_result_type
;
bool
m_ordered_ops_valid
;
std
::
list
<
Node
*>
m_ordered_ops
;
std
::
list
<
Node
*>
m_ops
;
private
:
Function
(
const
Function
&
)
=
delete
;
Function
(
const
Function
&&
)
=
delete
;
static
size_t
m_next_instance_id
;
size_t
m_instance_id
;
};
}
src/ngraph/node.cpp
View file @
118e0679
...
...
@@ -104,6 +104,32 @@ std::string Node::get_node_id() const
return
ss
.
str
();
}
std
::
string
Node
::
get_name
()
const
{
string
rc
;
if
(
m_name
.
empty
())
{
rc
=
description
()
+
"_"
+
to_string
(
m_instance_id
);
}
else
{
rc
=
m_name
;
}
return
rc
;
}
void
Node
::
set_name
(
const
string
&
name
)
{
if
(
m_name
.
empty
())
{
m_name
=
name
;
}
else
{
throw
ngraph_error
(
"Node name may be set exactly once"
);
}
}
namespace
ngraph
{
ostream
&
operator
<<
(
ostream
&
out
,
const
Node
&
node
)
...
...
src/ngraph/node.hpp
View file @
118e0679
...
...
@@ -55,6 +55,8 @@ namespace ngraph
public
:
/// The class name, must not contain spaces
virtual
std
::
string
description
()
const
=
0
;
std
::
string
get_name
()
const
;
void
set_name
(
const
std
::
string
&
name
);
/// Propagate types and check arguments for consistency
virtual
void
propagate_types
()
=
0
;
...
...
src/ngraph/pass/assign_tensors.cpp
View file @
118e0679
...
...
@@ -25,15 +25,15 @@
using
namespace
std
;
using
namespace
ngraph
;
bool
pass
::
AssignTensors
::
run_on_call_
list
(
std
::
list
<
Node
*>&
node_list
)
bool
pass
::
AssignTensors
::
run_on_call_
graph
(
list
<
Node
*>&
nodes
)
{
for
(
Node
*
node
:
node
_list
)
for
(
Node
*
node
:
node
s
)
{
try
{
// We need to set the nodes is_output state prior to call assign_tensors
// so that the output state can be passes to the constructed tensors.
if
(
node
==
get_state
().
get_function
(
)
->
get_result
().
get
())
if
(
node
==
get_state
().
get_function
s
().
at
(
0
)
->
get_result
().
get
())
{
node
->
set_is_output
();
}
...
...
@@ -50,21 +50,3 @@ bool pass::AssignTensors::run_on_call_list(std::list<Node*>& node_list)
}
return
false
;
}
void
pass
::
AssignTensors
::
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
registered_passes
)
const
{
bool
found_propagate_types
=
false
;
for
(
auto
pass
:
registered_passes
)
{
if
(
dynamic_pointer_cast
<
PropagateTypes
>
(
pass
))
{
found_propagate_types
=
true
;
}
}
if
(
!
found_propagate_types
)
{
throw
runtime_error
(
"Dependency 'PropagateTypes' not found for pass 'AssignTensors'"
);
}
}
src/ngraph/pass/assign_tensors.hpp
View file @
118e0679
...
...
@@ -14,7 +14,7 @@
#pragma once
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -25,12 +25,10 @@ namespace ngraph
class
Node
;
}
class
ngraph
::
pass
::
AssignTensors
:
public
Call
Base
class
ngraph
::
pass
::
AssignTensors
:
public
Call
GraphPass
{
public
:
virtual
bool
run_on_call_list
(
std
::
list
<
Node
*>&
)
override
;
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
)
const
override
;
virtual
bool
run_on_call_graph
(
std
::
list
<
Node
*>&
nodes
)
override
;
private
:
};
src/ngraph/
visualize.h
pp
→
src/ngraph/
pass/collect_functions.c
pp
View file @
118e0679
...
...
@@ -12,34 +12,39 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <functional>
#include <memory>
#include <set>
#include <sstream>
namespace
ngraph
#include "ngraph/pass/collect_functions.hpp"
#include "ngraph/function.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/function_call.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
::
pass
;
bool
CollectFunctions
::
run_on_function
(
ngraph
::
Function
*
func
)
{
class
Visualize
;
class
Node
;
using
node_ptr
=
std
::
shared_ptr
<
Node
>
;
set
<
Function
*>
functions
;
deque
<
Function
*>
stack
;
stack
.
push_back
(
func
);
while
(
stack
.
empty
()
==
false
)
{
Function
*
f
=
stack
.
front
();
stack
.
pop_front
();
functions
.
insert
(
f
);
traverse_nodes
(
f
->
get_result
(),
[
&
](
Node
*
node
)
{
op
::
FunctionCall
*
fc
=
dynamic_cast
<
op
::
FunctionCall
*>
(
node
);
if
(
fc
)
{
stack
.
push_back
(
fc
->
get_function
().
get
());
}
});
}
get_state
().
set_functions
(
functions
);
return
false
;
}
class
ngraph
::
Visualize
{
public
:
Visualize
(
const
std
::
string
&
name
=
"ngraph"
);
void
add
(
node_ptr
);
void
save_dot
(
const
std
::
string
&
path
)
const
;
private
:
std
::
string
add_attributes
(
const
Node
*
node
);
std
::
string
get_attributes
(
const
Node
*
node
);
std
::
stringstream
m_ss
;
std
::
string
m_name
;
std
::
set
<
const
Node
*>
m_nodes_with_attributes
;
};
src/ngraph/pass/c
all_pass.c
pp
→
src/ngraph/pass/c
ollect_functions.h
pp
View file @
118e0679
...
...
@@ -12,4 +12,22 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/pass/call_pass.hpp"
#pragma once
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
namespace
pass
{
class
CollectFunctions
;
}
}
class
ngraph
::
pass
::
CollectFunctions
:
public
FunctionPass
{
public
:
bool
run_on_function
(
ngraph
::
Function
*
)
override
;
private
:
};
src/ngraph/pass/dump_sorted.cpp
View file @
118e0679
...
...
@@ -27,14 +27,16 @@ pass::DumpSorted::DumpSorted(const string& output_file)
{
}
bool
pass
::
DumpSorted
::
run_on_
call_list
(
list
<
Node
*>&
node
s
)
bool
pass
::
DumpSorted
::
run_on_
module
(
vector
<
Function
*>&
function
s
)
{
ofstream
out
{
m_output_file
};
if
(
out
)
{
for
(
const
Node
*
node
:
node
s
)
for
(
Function
*
f
:
function
s
)
{
out
<<
node
->
get_node_id
()
<<
"("
;
for
(
const
Node
*
node
:
f
->
get_ordered_ops
())
{
out
<<
node
->
get_name
()
<<
"("
;
vector
<
string
>
inputs
;
for
(
const
Input
&
input
:
node
->
get_inputs
())
{
...
...
@@ -65,6 +67,7 @@ bool pass::DumpSorted::run_on_call_list(list<Node*>& nodes)
}
}
}
}
return
false
;
}
src/ngraph/pass/dump_sorted.hpp
View file @
118e0679
...
...
@@ -16,7 +16,7 @@
#include <string>
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -27,12 +27,12 @@ namespace ngraph
class
Node
;
}
class
ngraph
::
pass
::
DumpSorted
:
public
CallBase
class
ngraph
::
pass
::
DumpSorted
:
public
ModulePass
{
public
:
DumpSorted
(
const
std
::
string
&
output_file
);
virtual
bool
run_on_
call_list
(
std
::
list
<
Node
*>&
)
override
;
virtual
bool
run_on_
module
(
std
::
vector
<
Function
*>&
)
override
;
private
:
const
std
::
string
m_output_file
;
...
...
src/ngraph/pass/liveness.cpp
View file @
118e0679
...
...
@@ -27,7 +27,7 @@ using namespace std;
using
namespace
ngraph
;
using
namespace
ngraph
::
descriptor
;
bool
pass
::
Liveness
::
run_on_call_
list
(
list
<
Node
*>&
ops
)
bool
pass
::
Liveness
::
run_on_call_
graph
(
list
<
Node
*>&
ops
)
{
unordered_set
<
Tensor
*>
currently_live
;
...
...
@@ -123,24 +123,6 @@ bool pass::Liveness::run_on_call_list(list<Node*>& ops)
return
false
;
}
void
pass
::
Liveness
::
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
registered_passes
)
const
{
bool
found_propagate_types
=
false
;
for
(
auto
pass
:
registered_passes
)
{
if
(
dynamic_pointer_cast
<
AssignTensors
>
(
pass
))
{
found_propagate_types
=
true
;
}
}
if
(
!
found_propagate_types
)
{
throw
runtime_error
(
"Dependency 'PropagateTypes' not found for pass 'AssignTensors'"
);
}
}
bool
pass
::
Liveness
::
is_temporary
(
const
Tensor
&
tensor
)
{
return
tensor
.
is_persistent
()
==
false
&&
tensor
.
is_input
()
==
false
&&
...
...
src/ngraph/pass/liveness.hpp
View file @
118e0679
...
...
@@ -15,7 +15,7 @@
#pragma once
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -26,12 +26,10 @@ namespace ngraph
class
Node
;
}
class
ngraph
::
pass
::
Liveness
:
public
Call
Base
class
ngraph
::
pass
::
Liveness
:
public
Call
GraphPass
{
public
:
virtual
bool
run_on_call_list
(
std
::
list
<
Node
*>&
)
override
;
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
)
const
override
;
virtual
bool
run_on_call_graph
(
std
::
list
<
Node
*>&
)
override
;
private
:
bool
is_temporary
(
const
descriptor
::
Tensor
&
);
...
...
src/ngraph/pass/manager.cpp
View file @
118e0679
...
...
@@ -19,40 +19,11 @@
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp"
using
namespace
std
;
using
namespace
ngraph
;
Function
*
ngraph
::
pass
::
ManagerState
::
get_function
()
{
return
m_function
;
}
void
ngraph
::
pass
::
ManagerState
::
set_function
(
Function
*
func
)
{
m_function
=
func
;
}
size_t
ngraph
::
pass
::
ManagerState
::
get_temporary_pool_size
()
{
return
m_temporary_pool_size
;
}
void
ngraph
::
pass
::
ManagerState
::
set_temporary_pool_size
(
size_t
size
)
{
m_temporary_pool_size
=
size
;
}
std
::
list
<
Node
*>&
ngraph
::
pass
::
ManagerState
::
get_call_graph
()
{
return
m_call_graph
;
}
const
std
::
list
<
Node
*>&
ngraph
::
pass
::
ManagerState
::
get_call_graph
()
const
{
return
m_call_graph
;
}
ngraph
::
pass
::
Manager
::
Manager
()
{
}
...
...
@@ -65,26 +36,6 @@ void ngraph::pass::Manager::initialize_default_passes()
{
}
void
ngraph
::
pass
::
Manager
::
register_pass_ptr
(
std
::
shared_ptr
<
TreeBase
>
p
)
{
if
(
p
==
nullptr
)
{
throw
invalid_argument
(
"null pass registered"
);
}
p
->
check_dependencies
(
m_tree_passes
);
m_tree_passes
.
push_back
(
p
);
}
void
ngraph
::
pass
::
Manager
::
register_pass_ptr
(
std
::
shared_ptr
<
CallBase
>
p
)
{
if
(
p
==
nullptr
)
{
throw
invalid_argument
(
"null pass registered"
);
}
p
->
check_dependencies
(
m_call_passes
);
m_call_passes
.
push_back
(
p
);
}
void
ngraph
::
pass
::
Manager
::
run_passes
(
shared_ptr
<
Function
>
func
)
{
run_passes
(
func
.
get
());
...
...
@@ -92,23 +43,79 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
void
ngraph
::
pass
::
Manager
::
run_passes
(
Function
*
func
)
{
m_state
.
set_function
(
func
);
for
(
shared_ptr
<
TreeBase
>
p
:
m_tree_passes
)
vector
<
Function
*>
fs
=
{
func
};
get_state
().
set_functions
(
fs
);
for
(
shared_ptr
<
PassBase
>
pass
:
m_pass_list
)
{
pass
->
set_state
(
get_state
());
auto
module_pass
=
dynamic_pointer_cast
<
ModulePass
>
(
pass
);
auto
function_pass
=
dynamic_pointer_cast
<
FunctionPass
>
(
pass
);
auto
node_pass
=
dynamic_pointer_cast
<
NodePass
>
(
pass
);
auto
call_graph_pass
=
dynamic_pointer_cast
<
CallGraphPass
>
(
pass
);
if
(
module_pass
)
{
p
->
set_state
(
get_state
());
p
->
run_on_tree
(
func
->
get_result
());
module_pass
->
run_on_module
(
fs
);
}
for
(
shared_ptr
<
CallBase
>&
p
:
m_call_passes
)
else
if
(
function_pass
)
{
for
(
Function
*
f
:
fs
)
{
p
->
set_state
(
get_state
());
p
->
run_on_call_list
(
get_state
().
get_call_graph
());
function_pass
->
run_on_function
(
f
);
}
}
const
std
::
list
<
ngraph
::
Node
*>&
ngraph
::
pass
::
Manager
::
get_call_graph
()
const
{
return
m_state
.
get_call_graph
();
}
else
if
(
node_pass
)
{
for
(
Function
*
f
:
fs
)
{
for
(
Node
*
n
:
f
->
get_ops
())
{
node_pass
->
run_on_node
(
n
);
}
}
}
else
if
(
call_graph_pass
)
{
for
(
Function
*
f
:
fs
)
{
call_graph_pass
->
run_on_call_graph
(
f
->
get_ordered_ops
());
}
}
}
// for (shared_ptr<ModulePass>& p : m_module_passes)
// {
// p->set_state(get_state());
// p->run_on_module(fs);
// }
// for (Function* f : fs)
// {
// for (shared_ptr<FunctionPass> p : m_function_passes)
// {
// p->set_state(get_state());
// p->run_on_function(f);
// }
// }
// for (Function* f : fs)
// {
// NGRAPH_INFO;
// for (shared_ptr<NodePass> p : m_node_passes)
// {
// for (Node* node : f->get_ops())
// {
// NGRAPH_INFO;
// p->set_state(get_state());
// p->run_on_node(node);
// }
// }
// }
// for (shared_ptr<CallGraphPass>& p : m_call_graph_passes)
// {
// p->set_state(get_state());
// p->run_on_call_graph(func->get_ordered_ops());
// }
}
ngraph
::
pass
::
ManagerState
&
ngraph
::
pass
::
Manager
::
get_state
()
...
...
src/ngraph/pass/manager.hpp
View file @
118e0679
...
...
@@ -18,8 +18,8 @@
#include <memory>
#include <vector>
#include "ngraph/pass/
call_pass
.hpp"
#include "ngraph/pass/
tree_
pass.hpp"
#include "ngraph/pass/
manager_state
.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -33,24 +33,6 @@ namespace ngraph
class
Function
;
}
class
ngraph
::
pass
::
ManagerState
{
public
:
Function
*
get_function
();
void
set_function
(
Function
*
);
size_t
get_temporary_pool_size
();
void
set_temporary_pool_size
(
size_t
);
std
::
list
<
Node
*>&
get_call_graph
();
const
std
::
list
<
Node
*>&
get_call_graph
()
const
;
private
:
Function
*
m_function
=
nullptr
;
size_t
m_temporary_pool_size
=
0
;
std
::
list
<
Node
*>
m_call_graph
;
};
class
ngraph
::
pass
::
Manager
{
public
:
...
...
@@ -62,29 +44,18 @@ public:
template
<
typename
T
,
class
...
Args
>
void
register_pass
(
Args
...
args
)
{
static_assert
(
std
::
is_base_of
<
pass
::
Base
,
T
>::
value
,
"pass not derived from pass base"
);
if
(
std
::
is_base_of
<
TreeBase
,
T
>::
value
)
{
register_pass_ptr
(
std
::
make_shared
<
T
>
(
args
...));
}
else
if
(
std
::
is_base_of
<
CallBase
,
T
>::
value
)
{
register_pass_ptr
(
std
::
make_shared
<
T
>
(
args
...));
}
static_assert
(
std
::
is_base_of
<
pass
::
PassBase
,
T
>::
value
,
"pass not derived from pass base"
);
auto
pass
=
std
::
make_shared
<
T
>
(
args
...);
auto
pass_base
=
std
::
static_pointer_cast
<
PassBase
>
(
pass
);
m_pass_list
.
push_back
(
pass_base
);
}
void
run_passes
(
Function
*
);
void
run_passes
(
std
::
shared_ptr
<
Function
>
);
const
std
::
list
<
Node
*>&
get_call_graph
()
const
;
ManagerState
&
get_state
();
private
:
void
register_pass_ptr
(
std
::
shared_ptr
<
TreeBase
>
);
void
register_pass_ptr
(
std
::
shared_ptr
<
CallBase
>
);
std
::
vector
<
std
::
shared_ptr
<
TreeBase
>>
m_tree_passes
;
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>
m_call_passes
;
std
::
vector
<
std
::
shared_ptr
<
PassBase
>>
m_pass_list
;
ManagerState
m_state
;
};
src/ngraph/pass/
tree_pass.h
pp
→
src/ngraph/pass/
manager_state.c
pp
View file @
118e0679
...
...
@@ -12,31 +12,28 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <list>
#include <iostream>
#include <memory>
#include <vector>
#include "ngraph/pass/pass.hpp"
#include "ngraph/function.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/manager_state.hpp"
namespace
ngraph
{
namespace
pass
{
class
TreeBase
;
}
using
namespace
std
;
using
namespace
ngraph
;
class
Node
;
vector
<
Function
*>&
ngraph
::
pass
::
ManagerState
::
get_functions
()
{
return
m_function_list
;
}
class
ngraph
::
pass
::
TreeBase
:
public
Base
size_t
ngraph
::
pass
::
ManagerState
::
get_temporary_pool_size
()
{
public
:
virtual
~
TreeBase
()
{}
// return true if changes were made to the tree
virtual
bool
run_on_tree
(
std
::
shared_ptr
<
Node
>
)
=
0
;
return
m_temporary_pool_size
;
}
// derived class throws exception if its dependencies have not been met
virtual
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
TreeBase
>>&
)
const
{}
};
void
ngraph
::
pass
::
ManagerState
::
set_temporary_pool_size
(
size_t
size
)
{
m_temporary_pool_size
=
size
;
}
src/ngraph/pass/
call_pass
.hpp
→
src/ngraph/pass/
manager_state
.hpp
View file @
118e0679
...
...
@@ -14,29 +14,36 @@
#pragma once
#include <list>
#include <memory>
#include <vector>
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
namespace
pass
{
class
CallBas
e
;
class
ManagerStat
e
;
}
class
Node
;
class
Function
;
}
class
ngraph
::
pass
::
CallBase
:
public
Bas
e
class
ngraph
::
pass
::
ManagerStat
e
{
public
:
virtual
~
CallBase
()
{}
virtual
bool
run_on_call_list
(
std
::
list
<
Node
*>&
)
=
0
;
std
::
vector
<
Function
*>&
get_functions
();
template
<
typename
T
>
void
set_functions
(
const
T
&
collection
)
{
m_function_list
.
clear
();
m_function_list
.
insert
(
m_function_list
.
begin
(),
collection
.
begin
(),
collection
.
end
());
}
size_t
get_temporary_pool_size
();
void
set_temporary_pool_size
(
size_t
);
// derived class throws exception if its dependencies have not been met
virtual
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
)
const
{}
private
:
size_t
m_temporary_pool_size
=
0
;
std
::
vector
<
Function
*>
m_function_list
;
};
src/ngraph/pass/memory_layout.cpp
View file @
118e0679
...
...
@@ -27,7 +27,7 @@ using namespace std;
using
namespace
ngraph
;
using
namespace
ngraph
::
descriptor
;
bool
pass
::
MemoryLayout
::
run_on_call_
list
(
std
::
list
<
Node
*>&
node_list
)
bool
pass
::
MemoryLayout
::
run_on_call_
graph
(
std
::
list
<
Node
*>&
node_list
)
{
MemoryManager
mm
;
for
(
const
Node
*
node
:
node_list
)
...
...
@@ -47,24 +47,6 @@ bool pass::MemoryLayout::run_on_call_list(std::list<Node*>& node_list)
return
false
;
}
void
pass
::
MemoryLayout
::
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
registered_passes
)
const
{
bool
found_propagate_types
=
false
;
for
(
auto
pass
:
registered_passes
)
{
if
(
dynamic_pointer_cast
<
Liveness
>
(
pass
))
{
found_propagate_types
=
true
;
}
}
if
(
!
found_propagate_types
)
{
throw
runtime_error
(
"Dependency 'PropagateTypes' not found for pass 'AssignTensors'"
);
}
}
pass
::
MemoryManager
::
node
::
node
(
size_t
size
,
block_state
state
)
:
m_size
{
size
}
,
m_state
{
state
}
...
...
src/ngraph/pass/memory_layout.hpp
View file @
118e0679
...
...
@@ -18,7 +18,7 @@
#include <list>
#include <sstream>
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -31,12 +31,10 @@ namespace ngraph
class
Node
;
}
class
ngraph
::
pass
::
MemoryLayout
:
public
Call
Base
class
ngraph
::
pass
::
MemoryLayout
:
public
Call
GraphPass
{
public
:
virtual
bool
run_on_call_list
(
std
::
list
<
Node
*>&
)
override
;
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
)
const
override
;
virtual
bool
run_on_call_graph
(
std
::
list
<
Node
*>&
)
override
;
private
:
};
...
...
src/ngraph/pass/memory_visualize.cpp
View file @
118e0679
...
...
@@ -19,6 +19,7 @@
#include "memory_visualize.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/util.hpp"
...
...
@@ -31,11 +32,13 @@ pass::MemoryVisualize::MemoryVisualize(const string& filename)
{
}
bool
pass
::
MemoryVisualize
::
run_on_
call_list
(
list
<
Node
*>&
_node
s
)
bool
pass
::
MemoryVisualize
::
run_on_
module
(
vector
<
Function
*>&
function
s
)
{
const
list
<
Node
*>
nodes
=
_nodes
;
ofstream
file
(
m_filename
);
{
for
(
const
Function
*
f
:
functions
)
{
const
list
<
Node
*>
nodes
=
f
->
get_ordered_ops
();
file
<<
"<!DOCTYPE html>
\n
<html>
\n
"
;
file
<<
"<head>
\n
"
;
file
<<
" <style>
\n
"
;
...
...
@@ -89,13 +92,10 @@ bool pass::MemoryVisualize::run_on_call_list(list<Node*>& _nodes)
// file << "<hr>\n";
file
<<
"</body>
\n
</html>
\n
"
;
}
}
return
false
;
}
void
pass
::
MemoryVisualize
::
check_dependencies
(
const
vector
<
shared_ptr
<
CallBase
>>&
deps
)
const
{
}
const
Node
*
pass
::
MemoryVisualize
::
find_largest_op
(
const
list
<
Node
*>&
nodes
)
{
const
Node
*
largest_op
=
nullptr
;
...
...
@@ -207,7 +207,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
size_t
x2
=
((
usage
/
memory_footprint
)
*
scale
)
+
offset
;
file
<<
"<text x=
\"
"
<<
0
<<
"
\"
y=
\"
"
<<
y
+
text_offset
<<
"
\"
fill=
\"
"
<<
"black"
<<
"
\"
>"
<<
node
->
get_n
ode_id
()
<<
"</text>
\n
"
;
<<
"
\"
>"
<<
node
->
get_n
ame
()
<<
"</text>
\n
"
;
file
<<
"<line x1=
\"
"
<<
x1
<<
"
\"
y1=
\"
"
<<
y
<<
"
\"
x2=
\"
"
<<
x2
<<
"
\"
y2=
\"
"
<<
y
<<
"
\"
"
;
file
<<
" style=
\"
stroke:forestgreen;stroke-width:"
<<
stroke_width
<<
"
\"
/>
\n
"
;
...
...
@@ -231,7 +231,7 @@ void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<Node*>&
{
int
weight
=
compute_op_weight
(
exop
);
file
<<
" <tr>"
;
file
<<
"<td>"
<<
exop
->
get_n
ode_id
()
<<
"</td>"
;
file
<<
"<td>"
<<
exop
->
get_n
ame
()
<<
"</td>"
;
file
<<
"<td align=
\"
right
\"
>"
<<
weight
<<
"</td>"
;
file
<<
"</tr>
\n
"
;
}
...
...
src/ngraph/pass/memory_visualize.hpp
View file @
118e0679
...
...
@@ -18,7 +18,7 @@
#include <limits>
#include <list>
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -29,13 +29,11 @@ namespace ngraph
class
Node
;
}
class
ngraph
::
pass
::
MemoryVisualize
:
public
CallBase
class
ngraph
::
pass
::
MemoryVisualize
:
public
ModulePass
{
public
:
MemoryVisualize
(
const
std
::
string
&
filename
);
virtual
bool
run_on_call_list
(
std
::
list
<
Node
*>&
)
override
;
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
)
const
override
;
virtual
bool
run_on_module
(
std
::
vector
<
Function
*>&
)
override
;
private
:
const
Node
*
find_largest_op
(
const
std
::
list
<
Node
*>&
nodes
);
...
...
src/ngraph/pass/pass.cpp
View file @
118e0679
...
...
@@ -15,12 +15,12 @@
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/manager.hpp"
ngraph
::
pass
::
ManagerState
&
ngraph
::
pass
::
Base
::
get_state
()
ngraph
::
pass
::
ManagerState
&
ngraph
::
pass
::
Pass
Base
::
get_state
()
{
return
*
m_state
;
}
void
ngraph
::
pass
::
Base
::
set_state
(
ManagerState
&
state
)
void
ngraph
::
pass
::
Pass
Base
::
set_state
(
ManagerState
&
state
)
{
m_state
=
&
state
;
}
src/ngraph/pass/pass.hpp
View file @
118e0679
...
...
@@ -14,21 +14,33 @@
#pragma once
#include <list>
#include <memory>
#include <vector>
#include "ngraph/node.hpp"
namespace
ngraph
{
namespace
pass
{
class
Base
;
class
PassBase
;
class
ModulePass
;
class
FunctionPass
;
class
NodePass
;
class
CallGraphPass
;
class
Manager
;
class
ManagerState
;
}
class
Function
;
}
class
ngraph
::
pass
::
Base
class
ngraph
::
pass
::
Pass
Base
{
friend
class
Manager
;
public
:
virtual
~
PassBase
()
{}
protected
:
ManagerState
&
get_state
();
void
set_state
(
ManagerState
&
);
...
...
@@ -36,3 +48,31 @@ protected:
private
:
ManagerState
*
m_state
;
};
class
ngraph
::
pass
::
ModulePass
:
public
PassBase
{
public
:
virtual
~
ModulePass
()
{}
virtual
bool
run_on_module
(
std
::
vector
<
ngraph
::
Function
*>&
)
=
0
;
};
class
ngraph
::
pass
::
FunctionPass
:
public
PassBase
{
public
:
virtual
~
FunctionPass
()
{}
virtual
bool
run_on_function
(
ngraph
::
Function
*
)
=
0
;
};
class
ngraph
::
pass
::
NodePass
:
public
PassBase
{
public
:
virtual
~
NodePass
()
{}
virtual
bool
run_on_node
(
ngraph
::
Node
*
)
=
0
;
};
class
ngraph
::
pass
::
CallGraphPass
:
public
PassBase
{
public
:
virtual
~
CallGraphPass
()
{}
virtual
bool
run_on_call_graph
(
std
::
list
<
ngraph
::
Node
*>&
)
=
0
;
};
src/ngraph/pass/propagate_types.cpp
View file @
118e0679
...
...
@@ -20,9 +20,9 @@
using
namespace
std
;
using
namespace
ngraph
;
bool
pass
::
PropagateTypes
::
run_on_call_
list
(
std
::
list
<
Node
*>&
node_list
)
bool
pass
::
PropagateTypes
::
run_on_call_
graph
(
list
<
Node
*>&
nodes
)
{
for
(
Node
*
node
:
node
_list
)
for
(
Node
*
node
:
node
s
)
{
try
{
...
...
src/ngraph/pass/propagate_types.hpp
View file @
118e0679
...
...
@@ -14,7 +14,7 @@
#pragma once
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -25,10 +25,10 @@ namespace ngraph
class
Node
;
}
class
ngraph
::
pass
::
PropagateTypes
:
public
Call
Base
class
ngraph
::
pass
::
PropagateTypes
:
public
Call
GraphPass
{
public
:
virtual
bool
run_on_call_
list
(
std
::
list
<
Node
*>&
)
override
;
virtual
bool
run_on_call_
graph
(
std
::
list
<
Node
*>&
)
override
;
private
:
};
src/ngraph/pass/topological_sort.cpp
View file @
118e0679
...
...
@@ -15,6 +15,7 @@
#include <deque>
#include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/manager.hpp"
...
...
@@ -24,14 +25,13 @@
using
namespace
ngraph
;
using
namespace
std
;
bool
ngraph
::
pass
::
TopologicalSort
::
run_on_
tree
(
std
::
shared_ptr
<
Node
>
p
)
bool
ngraph
::
pass
::
TopologicalSort
::
run_on_
function
(
ngraph
::
Function
*
func
)
{
list
<
Node
*>&
sorted_list
=
get_state
().
get_call_graph
();
sorted_list
.
clear
();
list
<
Node
*>
result_list
;
deque
<
Node
*>
independent_nodes
;
unordered_map
<
Node
*
,
size_t
>
node_depencency_count
;
traverse_nodes
(
p
,
[
&
](
Node
*
node
)
{
traverse_nodes
(
func
->
get_result
()
,
[
&
](
Node
*
node
)
{
node_depencency_count
[
node
]
=
node
->
get_arguments
().
size
();
if
(
node
->
get_arguments
().
size
()
==
0
)
{
...
...
@@ -42,7 +42,7 @@ bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p)
while
(
independent_nodes
.
size
()
>
0
)
{
auto
independent_node
=
independent_nodes
.
front
();
sorted
_list
.
push_back
(
independent_node
);
result
_list
.
push_back
(
independent_node
);
independent_nodes
.
pop_front
();
for
(
auto
user
:
independent_node
->
users
())
...
...
@@ -56,5 +56,7 @@ bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p)
}
}
func
->
set_ordered_ops
(
result_list
);
return
false
;
}
src/ngraph/pass/topological_sort.hpp
View file @
118e0679
...
...
@@ -17,7 +17,7 @@
#include <list>
#include <memory>
#include "ngraph/pass/
tree_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -28,9 +28,9 @@ namespace ngraph
class
Node
;
}
class
ngraph
::
pass
::
TopologicalSort
:
public
TreeBase
class
ngraph
::
pass
::
TopologicalSort
:
public
FunctionPass
{
public
:
TopologicalSort
()
{}
bool
run_on_
tree
(
std
::
shared_ptr
<
Node
>
)
override
;
bool
run_on_
function
(
ngraph
::
Function
*
)
override
;
};
src/ngraph/pass/tree_pass.cpp
deleted
100644 → 0
View file @
c2ff1508
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/pass/tree_pass.hpp"
src/ngraph/pass/visualize_tree.cpp
View file @
118e0679
...
...
@@ -14,25 +14,30 @@
#include <fstream>
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp"
using
namespace
ngraph
;
using
namespace
std
;
bool
pass
::
VisualizeTree
::
run_on_
tree
(
std
::
shared_ptr
<
Node
>
base_node
)
bool
pass
::
VisualizeTree
::
run_on_
module
(
vector
<
ngraph
::
Function
*>&
functions
)
{
for
(
Function
*
f
:
functions
)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes
(
base_node
,
[
&
](
Node
*
node
)
{
traverse_nodes
(
f
->
get_result
()
,
[
&
](
Node
*
node
)
{
for
(
auto
arg
:
node
->
get_arguments
())
{
m_ss
<<
add_attributes
(
arg
.
get
());
m_ss
<<
add_attributes
(
node
);
m_ss
<<
" "
<<
arg
->
get_node_id
()
<<
" -> "
<<
node
->
get_node_id
();
m_ss
<<
" "
<<
arg
->
get_name
()
<<
" -> "
<<
node
->
get_name
();
m_ss
<<
";
\n
"
;
}
});
}
render
();
...
...
@@ -60,11 +65,11 @@ std::string pass::VisualizeTree::get_attributes(const Node* node)
stringstream
ss
;
if
(
node
->
is_parameter
())
{
ss
<<
" "
<<
node
->
get_n
ode_id
()
<<
" [shape=box color=blue]
\n
"
;
ss
<<
" "
<<
node
->
get_n
ame
()
<<
" [shape=box color=blue]
\n
"
;
}
else
{
ss
<<
" "
<<
node
->
get_n
ode_id
()
<<
" [shape=ellipse color=black]
\n
"
;
ss
<<
" "
<<
node
->
get_n
ame
()
<<
" [shape=ellipse color=black]
\n
"
;
}
return
ss
.
str
();
}
...
...
src/ngraph/pass/visualize_tree.hpp
View file @
118e0679
...
...
@@ -18,7 +18,7 @@
#include <sstream>
#include <string>
#include "ngraph/pass/
tree_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -29,11 +29,11 @@ namespace ngraph
class
Node
;
}
class
ngraph
::
pass
::
VisualizeTree
:
public
TreeBase
class
ngraph
::
pass
::
VisualizeTree
:
public
ModulePass
{
public
:
VisualizeTree
(
const
std
::
string
&
file_name
);
bool
run_on_
tree
(
std
::
shared_ptr
<
Node
>
)
override
;
bool
run_on_
module
(
std
::
vector
<
ngraph
::
Function
*>&
)
override
;
private
:
std
::
string
add_attributes
(
const
Node
*
node
);
...
...
src/ngraph/runtime/external_function.cpp
View file @
118e0679
...
...
@@ -659,7 +659,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Turn this into a pass
// Assign layouts
// For now, just make everyone row-major.
for
(
const
Node
*
node
:
pass_manager
.
get_call_graph
())
for
(
const
Node
*
node
:
m_function
->
get_ordered_ops
())
{
for
(
const
descriptor
::
Output
&
output
:
node
->
get_outputs
())
{
...
...
@@ -696,7 +696,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
m_n_outputs
=
tensor_index
.
size
()
-
m_n_inputs
;
// All remaining tensor views
for
(
const
Node
*
node
:
pass_manager
.
get_call_graph
())
for
(
const
Node
*
node
:
m_function
->
get_ordered_ops
())
{
for
(
const
descriptor
::
Output
&
output
:
node
->
get_outputs
())
{
...
...
@@ -712,7 +712,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Now we build the eigen-VM instructions
auto
op_map
=
get_op_map
();
for
(
const
Node
*
node
:
pass_manager
.
get_call_graph
())
for
(
const
Node
*
node
:
m_function
->
get_ordered_ops
())
{
auto
handler_it
=
op_map
.
find
(
type_index
(
typeid
(
*
node
)));
if
(
handler_it
==
op_map
.
end
())
...
...
src/ngraph/visualize.cpp
deleted
100644 → 0
View file @
c2ff1508
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cstdio>
#include <fstream>
#include <list>
#include "ngraph/node.hpp"
#include "ngraph/util.hpp"
#include "ngraph/visualize.hpp"
using
namespace
ngraph
;
using
namespace
std
;
Visualize
::
Visualize
(
const
string
&
name
)
:
m_name
{
name
}
{
}
void
Visualize
::
add
(
node_ptr
p
)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes
(
p
,
[
&
](
Node
*
node
)
{
for
(
auto
arg
:
node
->
get_arguments
())
{
m_ss
<<
add_attributes
(
arg
.
get
());
m_ss
<<
add_attributes
(
node
);
m_ss
<<
" "
<<
arg
->
get_node_id
()
<<
" -> "
<<
node
->
get_node_id
();
m_ss
<<
";
\n
"
;
}
});
}
std
::
string
Visualize
::
add_attributes
(
const
Node
*
node
)
{
string
rc
;
if
(
!
contains
(
m_nodes_with_attributes
,
node
))
{
m_nodes_with_attributes
.
insert
(
node
);
rc
=
get_attributes
(
node
);
}
return
rc
;
}
std
::
string
Visualize
::
get_attributes
(
const
Node
*
node
)
{
stringstream
ss
;
if
(
node
->
is_parameter
())
{
ss
<<
" "
<<
node
->
get_node_id
()
<<
" [shape=box color=blue]
\n
"
;
}
else
{
ss
<<
" "
<<
node
->
get_node_id
()
<<
" [shape=ellipse color=black]
\n
"
;
}
return
ss
.
str
();
}
void
Visualize
::
save_dot
(
const
string
&
path
)
const
{
#ifdef GRAPHVIZ_FOUND
auto
tmp_file
=
path
+
".tmp"
;
ofstream
out
(
tmp_file
);
if
(
out
)
{
out
<<
"digraph "
<<
m_name
<<
"
\n
{
\n
"
;
out
<<
m_ss
.
str
();
out
<<
"}
\n
"
;
out
.
close
();
stringstream
ss
;
ss
<<
"dot -Tpng "
<<
tmp_file
<<
" -o "
<<
path
;
auto
cmd
=
ss
.
str
();
auto
stream
=
popen
(
cmd
.
c_str
(),
"r"
);
pclose
(
stream
);
remove
(
tmp_file
.
c_str
());
}
#else
#endif
}
test/pass_liveness.cpp
View file @
118e0679
...
...
@@ -51,7 +51,7 @@ TEST(pass, liveness)
shared_ptr
<
Function
>
func
=
make_test_graph
();
pass_manager
.
run_passes
(
func
.
get
());
auto
sorted
=
pass_manager
.
get_call_graph
();
auto
sorted
=
func
->
get_ordered_ops
();
// for (const Node* node : sorted)
// {
...
...
test/pass_manager.cpp
View file @
118e0679
...
...
@@ -40,15 +40,28 @@ TEST(pass_manager, add)
auto
graph
=
make_test_graph
();
size_t
node_count
=
get_node_count
(
graph
->
get_result
());
pass_manager
.
run_passes
(
graph
.
get
());
auto
sorted
=
pass_manager
.
get_call_graph
();
auto
sorted
=
graph
->
get_ordered_ops
();
EXPECT_EQ
(
node_count
,
sorted
.
size
());
EXPECT_TRUE
(
validate_list
(
sorted
));
}
TEST
(
pass_manager
,
dependency
)
TEST
(
pass_manager
,
module_add_function
)
{
pass
::
Manager
pass_manager
;
// First create "f(A,B,C) = (A+B)*C".
auto
shape
=
Shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
C
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_f
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
f
=
make_shared
<
Function
>
((
A
+
B
)
*
C
,
rt_f
,
op
::
Parameters
{
A
,
B
,
C
});
pass_manager
.
register_pass
<
pass
::
TopologicalSort
>
();
EXPECT_THROW
(
pass_manager
.
register_pass
<
pass
::
AssignTensors
>
(),
runtime_error
);
// Now make "g(X,Y,Z) = f(X,Y,Z) + f(X,Y,Z)"
auto
X
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Y
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Z
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_g
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
g
=
make_shared
<
Function
>
(
make_shared
<
op
::
FunctionCall
>
(
f
,
Nodes
{
X
,
Y
,
Z
})
+
make_shared
<
op
::
FunctionCall
>
(
f
,
Nodes
{
X
,
Y
,
Z
}),
rt_g
,
op
::
Parameters
{
X
,
Y
,
Z
});
}
test/pass_memory_layout.cpp
View file @
118e0679
...
...
@@ -218,7 +218,7 @@ TEST(memory_layout, basic)
auto
graph
=
make_test_graph
();
pass_manager
.
run_passes
(
graph
);
auto
sorted
=
pass_manager
.
get_call_graph
();
auto
sorted
=
graph
->
get_ordered_ops
();
size_t
temporary_pool_size
=
pass_manager
.
get_state
().
get_temporary_pool_size
();
EXPECT_EQ
(
12
,
temporary_pool_size
);
}
test/topological_sort.cpp
View file @
118e0679
...
...
@@ -21,10 +21,10 @@
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/collect_functions.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/util.hpp"
#include "ngraph/visualize.hpp"
#include "test_tools.hpp"
using
namespace
std
;
...
...
@@ -69,7 +69,7 @@ TEST(topological_sort, basic)
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
TopologicalSort
>
();
pass_manager
.
run_passes
(
f0
);
auto
sorted_list
=
pass_manager
.
get_call_graph
();
auto
sorted_list
=
f0
->
get_ordered_ops
();
size_t
node_count
=
get_node_count
(
r0
);
...
...
@@ -121,7 +121,7 @@ TEST(benchmark, topological_sort)
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
TopologicalSort
>
();
pass_manager
.
run_passes
(
f0
);
auto
sorted_list
=
pass_manager
.
get_call_graph
();
auto
sorted_list
=
f0
->
get_ordered_ops
();
timer
.
stop
();
NGRAPH_INFO
<<
"topological sort took "
<<
timer
.
get_milliseconds
()
<<
"ms"
;
...
...
@@ -135,3 +135,51 @@ TEST(benchmark, topological_sort)
timer
.
stop
();
NGRAPH_INFO
<<
"delete nodes took "
<<
timer
.
get_milliseconds
()
<<
"ms"
;
}
TEST
(
topological_sort
,
collect_functions
)
{
// First create "f(A,B,C) = (A+B)*C".
auto
shape
=
Shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
C
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_f
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
f
=
make_shared
<
Function
>
((
A
+
B
)
*
C
,
rt_f
,
op
::
Parameters
{
A
,
B
,
C
},
"f"
);
// Now make "g(X,Y,Z) = f(X,Y,Z) + f(X,Y,Z)"
auto
X
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Y
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Z
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_g
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
g
=
make_shared
<
Function
>
(
make_shared
<
op
::
FunctionCall
>
(
f
,
Nodes
{
X
,
Y
,
Z
})
+
make_shared
<
op
::
FunctionCall
>
(
f
,
Nodes
{
X
,
Y
,
Z
}),
rt_g
,
op
::
Parameters
{
X
,
Y
,
Z
},
"g"
);
// Now make "h(X,Y,Z) = g(X,Y,Z) + g(X,Y,Z)"
auto
X1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Y1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Z1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_h
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
h
=
make_shared
<
Function
>
(
make_shared
<
op
::
FunctionCall
>
(
g
,
Nodes
{
X1
,
Y1
,
Z1
})
+
make_shared
<
op
::
FunctionCall
>
(
g
,
Nodes
{
X1
,
Y1
,
Z1
}),
rt_h
,
op
::
Parameters
{
X1
,
Y1
,
Z1
},
"h"
);
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
CollectFunctions
>
();
pass_manager
.
run_passes
(
h
);
set
<
string
>
expected
=
{
"f"
,
"g"
,
"h"
};
auto
functions
=
pass_manager
.
get_state
().
get_functions
();
vector
<
string
>
fnames
;
for
(
Function
*
func
:
functions
)
{
fnames
.
push_back
(
func
->
get_name
());
}
EXPECT_EQ
(
expected
.
size
(),
functions
.
size
());
EXPECT_TRUE
(
contains
(
fnames
,
"f"
));
EXPECT_TRUE
(
contains
(
fnames
,
"g"
));
EXPECT_TRUE
(
contains
(
fnames
,
"h"
));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment