Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
118e0679
Commit
118e0679
authored
7 years ago
by
Jai Menon
Committed by
GitHub
7 years ago
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into jmenon/codegen
parents
c2ff1508
c3dfdf5a
Show whitespace changes
Inline
Side-by-side
Showing
36 changed files
with
474 additions
and
403 deletions
+474
-403
CMakeLists.txt
src/ngraph/CMakeLists.txt
+2
-3
function.cpp
src/ngraph/function.cpp
+76
-2
function.hpp
src/ngraph/function.hpp
+25
-2
node.cpp
src/ngraph/node.cpp
+26
-0
node.hpp
src/ngraph/node.hpp
+2
-0
assign_tensors.cpp
src/ngraph/pass/assign_tensors.cpp
+3
-21
assign_tensors.hpp
src/ngraph/pass/assign_tensors.hpp
+3
-5
collect_functions.cpp
src/ngraph/pass/collect_functions.cpp
+34
-29
collect_functions.hpp
src/ngraph/pass/collect_functions.hpp
+19
-1
dump_sorted.cpp
src/ngraph/pass/dump_sorted.cpp
+6
-3
dump_sorted.hpp
src/ngraph/pass/dump_sorted.hpp
+3
-3
liveness.cpp
src/ngraph/pass/liveness.cpp
+1
-19
liveness.hpp
src/ngraph/pass/liveness.hpp
+3
-5
manager.cpp
src/ngraph/pass/manager.cpp
+70
-63
manager.hpp
src/ngraph/pass/manager.hpp
+7
-36
manager_state.cpp
src/ngraph/pass/manager_state.cpp
+17
-20
manager_state.hpp
src/ngraph/pass/manager_state.hpp
+16
-9
memory_layout.cpp
src/ngraph/pass/memory_layout.cpp
+1
-19
memory_layout.hpp
src/ngraph/pass/memory_layout.hpp
+3
-5
memory_visualize.cpp
src/ngraph/pass/memory_visualize.cpp
+8
-8
memory_visualize.hpp
src/ngraph/pass/memory_visualize.hpp
+3
-5
pass.cpp
src/ngraph/pass/pass.cpp
+2
-2
pass.hpp
src/ngraph/pass/pass.hpp
+42
-2
propagate_types.cpp
src/ngraph/pass/propagate_types.cpp
+2
-2
propagate_types.hpp
src/ngraph/pass/propagate_types.hpp
+3
-3
topological_sort.cpp
src/ngraph/pass/topological_sort.cpp
+7
-5
topological_sort.hpp
src/ngraph/pass/topological_sort.hpp
+3
-3
tree_pass.cpp
src/ngraph/pass/tree_pass.cpp
+0
-15
visualize_tree.cpp
src/ngraph/pass/visualize_tree.cpp
+10
-5
visualize_tree.hpp
src/ngraph/pass/visualize_tree.hpp
+3
-3
external_function.cpp
src/ngraph/runtime/external_function.cpp
+3
-3
visualize.cpp
src/ngraph/visualize.cpp
+0
-92
pass_liveness.cpp
test/pass_liveness.cpp
+1
-1
pass_manager.cpp
test/pass_manager.cpp
+18
-5
pass_memory_layout.cpp
test/pass_memory_layout.cpp
+1
-1
topological_sort.cpp
test/topological_sort.cpp
+51
-3
No files found.
src/ngraph/CMakeLists.txt
View file @
118e0679
...
...
@@ -39,16 +39,16 @@ set (SRC
ops/unary_elementwise_arithmetic.cpp
ops/unary_elementwise_builtin.cpp
pass/assign_tensors.cpp
pass/c
all_pas
s.cpp
pass/c
ollect_function
s.cpp
pass/dump_sorted.cpp
pass/liveness.cpp
pass/manager.cpp
pass/manager_state.cpp
pass/memory_layout.cpp
pass/memory_visualize.cpp
pass/pass.cpp
pass/propagate_types.cpp
pass/topological_sort.cpp
pass/tree_pass.cpp
pass/visualize_tree.cpp
runtime/call_frame.cpp
runtime/external_function.cpp
...
...
@@ -58,7 +58,6 @@ set (SRC
types/element_type.cpp
types/type.cpp
util.cpp
visualize.cpp
)
# find_program (GRAPHVIZ dot)
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/function.cpp
View file @
118e0679
...
...
@@ -15,21 +15,95 @@
#include <memory>
#include "ngraph/function.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
using
namespace
ngraph
;
size_t
Function
::
m_next_instance_id
=
0
;
Function
::
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
shared_ptr
<
ValueType
>&
result_type
,
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
)
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
,
const
std
::
string
&
name
)
:
m_result
(
result
)
,
m_parameters
(
parameters
)
,
m_name
(
"Function"
)
,
m_name
(
name
)
,
m_result_type
(
result_type
)
,
m_ordered_ops_valid
(
false
)
,
m_instance_id
(
m_next_instance_id
++
)
{
size_t
i
=
0
;
for
(
auto
parameter
:
parameters
)
{
parameter
->
assign_function
(
this
,
i
++
);
}
traverse_nodes
(
result
,
[
&
](
Node
*
node
)
{
m_ops
.
push_back
(
node
);
});
}
void
Function
::
set_ordered_ops
(
const
std
::
list
<
Node
*>&
ordered_ops
)
{
m_ordered_ops
=
ordered_ops
;
m_ordered_ops_valid
=
true
;
}
std
::
list
<
Node
*>&
Function
::
get_ops
()
{
return
m_ops
;
}
const
std
::
list
<
Node
*>&
Function
::
get_ops
()
const
{
return
m_ops
;
}
std
::
list
<
Node
*>&
Function
::
get_ordered_ops
()
{
if
(
!
m_ordered_ops_valid
)
{
throw
ngraph_error
(
"Access to ordered ops invalid"
);
}
return
m_ordered_ops
;
}
const
std
::
list
<
Node
*>&
Function
::
get_ordered_ops
()
const
{
if
(
!
m_ordered_ops_valid
)
{
throw
ngraph_error
(
"Access to ordered ops invalid"
);
}
return
m_ordered_ops
;
}
std
::
string
Function
::
get_name
()
const
{
string
rc
;
if
(
m_name
.
empty
())
{
rc
=
"Function_"
+
to_string
(
m_instance_id
);
}
else
{
rc
=
m_name
;
}
return
rc
;
}
void
Function
::
set_name
(
const
string
&
name
)
{
if
(
m_name
.
empty
())
{
m_name
=
name
;
}
else
{
throw
ngraph_error
(
"Function name may be set exactly once"
);
}
}
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
Function
&
f
)
{
out
<<
"Function("
<<
f
.
get_name
()
<<
")"
;
return
out
;
}
This diff is collapsed.
Click to expand it.
src/ngraph/function.hpp
View file @
118e0679
...
...
@@ -15,11 +15,13 @@
#pragma once
#include <initializer_list>
#include <list>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/ops/parameter.hpp"
...
...
@@ -34,7 +36,8 @@ namespace ngraph
public
:
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
shared_ptr
<
ValueType
>&
result_type
,
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
);
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
,
const
std
::
string
&
name
=
""
);
std
::
shared_ptr
<
Node
>
get_result
()
{
return
m_result
;
}
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>
get_parameters
()
const
...
...
@@ -42,11 +45,31 @@ namespace ngraph
return
m_parameters
;
}
const
std
::
shared_ptr
<
ValueType
>
get_result_type
()
const
{
return
m_result_type
;
}
std
::
string
get_name
()
const
{
return
m_name
;
}
std
::
string
get_name
()
const
;
void
set_name
(
const
std
::
string
&
name
);
std
::
list
<
Node
*>&
get_ops
();
const
std
::
list
<
Node
*>&
get_ops
()
const
;
std
::
list
<
Node
*>&
get_ordered_ops
();
const
std
::
list
<
Node
*>&
get_ordered_ops
()
const
;
void
set_ordered_ops
(
const
std
::
list
<
Node
*>&
);
void
set_ordered_ops_valid
()
{
m_ordered_ops_valid
=
true
;
}
void
clear_ordered_ops_valid
()
{
m_ordered_ops_valid
=
false
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Function
&
);
protected
:
std
::
shared_ptr
<
Node
>
m_result
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
op
::
Parameter
>>
m_parameters
;
std
::
string
m_name
;
std
::
shared_ptr
<
ValueType
>
m_result_type
;
bool
m_ordered_ops_valid
;
std
::
list
<
Node
*>
m_ordered_ops
;
std
::
list
<
Node
*>
m_ops
;
private
:
Function
(
const
Function
&
)
=
delete
;
Function
(
const
Function
&&
)
=
delete
;
static
size_t
m_next_instance_id
;
size_t
m_instance_id
;
};
}
This diff is collapsed.
Click to expand it.
src/ngraph/node.cpp
View file @
118e0679
...
...
@@ -104,6 +104,32 @@ std::string Node::get_node_id() const
return
ss
.
str
();
}
std
::
string
Node
::
get_name
()
const
{
string
rc
;
if
(
m_name
.
empty
())
{
rc
=
description
()
+
"_"
+
to_string
(
m_instance_id
);
}
else
{
rc
=
m_name
;
}
return
rc
;
}
void
Node
::
set_name
(
const
string
&
name
)
{
if
(
m_name
.
empty
())
{
m_name
=
name
;
}
else
{
throw
ngraph_error
(
"Node name may be set exactly once"
);
}
}
namespace
ngraph
{
ostream
&
operator
<<
(
ostream
&
out
,
const
Node
&
node
)
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/node.hpp
View file @
118e0679
...
...
@@ -55,6 +55,8 @@ namespace ngraph
public
:
/// The class name, must not contain spaces
virtual
std
::
string
description
()
const
=
0
;
std
::
string
get_name
()
const
;
void
set_name
(
const
std
::
string
&
name
);
/// Propagate types and check arguments for consistency
virtual
void
propagate_types
()
=
0
;
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/pass/assign_tensors.cpp
View file @
118e0679
...
...
@@ -25,15 +25,15 @@
using
namespace
std
;
using
namespace
ngraph
;
bool
pass
::
AssignTensors
::
run_on_call_
list
(
std
::
list
<
Node
*>&
node_list
)
bool
pass
::
AssignTensors
::
run_on_call_
graph
(
list
<
Node
*>&
nodes
)
{
for
(
Node
*
node
:
node
_list
)
for
(
Node
*
node
:
node
s
)
{
try
{
// We need to set the nodes is_output state prior to call assign_tensors
// so that the output state can be passes to the constructed tensors.
if
(
node
==
get_state
().
get_function
(
)
->
get_result
().
get
())
if
(
node
==
get_state
().
get_function
s
().
at
(
0
)
->
get_result
().
get
())
{
node
->
set_is_output
();
}
...
...
@@ -50,21 +50,3 @@ bool pass::AssignTensors::run_on_call_list(std::list<Node*>& node_list)
}
return
false
;
}
void
pass
::
AssignTensors
::
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
registered_passes
)
const
{
bool
found_propagate_types
=
false
;
for
(
auto
pass
:
registered_passes
)
{
if
(
dynamic_pointer_cast
<
PropagateTypes
>
(
pass
))
{
found_propagate_types
=
true
;
}
}
if
(
!
found_propagate_types
)
{
throw
runtime_error
(
"Dependency 'PropagateTypes' not found for pass 'AssignTensors'"
);
}
}
This diff is collapsed.
Click to expand it.
src/ngraph/pass/assign_tensors.hpp
View file @
118e0679
...
...
@@ -14,7 +14,7 @@
#pragma once
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -25,12 +25,10 @@ namespace ngraph
class
Node
;
}
class
ngraph
::
pass
::
AssignTensors
:
public
Call
Base
class
ngraph
::
pass
::
AssignTensors
:
public
Call
GraphPass
{
public
:
virtual
bool
run_on_call_list
(
std
::
list
<
Node
*>&
)
override
;
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
)
const
override
;
virtual
bool
run_on_call_graph
(
std
::
list
<
Node
*>&
nodes
)
override
;
private
:
};
This diff is collapsed.
Click to expand it.
src/ngraph/
visualize.h
pp
→
src/ngraph/
pass/collect_functions.c
pp
View file @
118e0679
...
...
@@ -12,34 +12,39 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <functional>
#include <memory>
#include <set>
#include <sstream>
namespace
ngraph
#include "ngraph/pass/collect_functions.hpp"
#include "ngraph/function.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/function_call.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
::
pass
;
bool
CollectFunctions
::
run_on_function
(
ngraph
::
Function
*
func
)
{
class
Visualize
;
class
Node
;
using
node_ptr
=
std
::
shared_ptr
<
Node
>
;
set
<
Function
*>
functions
;
deque
<
Function
*>
stack
;
stack
.
push_back
(
func
);
while
(
stack
.
empty
()
==
false
)
{
Function
*
f
=
stack
.
front
();
stack
.
pop_front
();
functions
.
insert
(
f
);
traverse_nodes
(
f
->
get_result
(),
[
&
](
Node
*
node
)
{
op
::
FunctionCall
*
fc
=
dynamic_cast
<
op
::
FunctionCall
*>
(
node
);
if
(
fc
)
{
stack
.
push_back
(
fc
->
get_function
().
get
());
}
});
}
get_state
().
set_functions
(
functions
);
return
false
;
}
class
ngraph
::
Visualize
{
public
:
Visualize
(
const
std
::
string
&
name
=
"ngraph"
);
void
add
(
node_ptr
);
void
save_dot
(
const
std
::
string
&
path
)
const
;
private
:
std
::
string
add_attributes
(
const
Node
*
node
);
std
::
string
get_attributes
(
const
Node
*
node
);
std
::
stringstream
m_ss
;
std
::
string
m_name
;
std
::
set
<
const
Node
*>
m_nodes_with_attributes
;
};
This diff is collapsed.
Click to expand it.
src/ngraph/pass/c
all_pass.c
pp
→
src/ngraph/pass/c
ollect_functions.h
pp
View file @
118e0679
...
...
@@ -12,4 +12,22 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/pass/call_pass.hpp"
#pragma once
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
namespace
pass
{
class
CollectFunctions
;
}
}
class
ngraph
::
pass
::
CollectFunctions
:
public
FunctionPass
{
public
:
bool
run_on_function
(
ngraph
::
Function
*
)
override
;
private
:
};
This diff is collapsed.
Click to expand it.
src/ngraph/pass/dump_sorted.cpp
View file @
118e0679
...
...
@@ -27,14 +27,16 @@ pass::DumpSorted::DumpSorted(const string& output_file)
{
}
bool
pass
::
DumpSorted
::
run_on_
call_list
(
list
<
Node
*>&
node
s
)
bool
pass
::
DumpSorted
::
run_on_
module
(
vector
<
Function
*>&
function
s
)
{
ofstream
out
{
m_output_file
};
if
(
out
)
{
for
(
const
Node
*
node
:
node
s
)
for
(
Function
*
f
:
function
s
)
{
out
<<
node
->
get_node_id
()
<<
"("
;
for
(
const
Node
*
node
:
f
->
get_ordered_ops
())
{
out
<<
node
->
get_name
()
<<
"("
;
vector
<
string
>
inputs
;
for
(
const
Input
&
input
:
node
->
get_inputs
())
{
...
...
@@ -65,6 +67,7 @@ bool pass::DumpSorted::run_on_call_list(list<Node*>& nodes)
}
}
}
}
return
false
;
}
This diff is collapsed.
Click to expand it.
src/ngraph/pass/dump_sorted.hpp
View file @
118e0679
...
...
@@ -16,7 +16,7 @@
#include <string>
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -27,12 +27,12 @@ namespace ngraph
class
Node
;
}
class
ngraph
::
pass
::
DumpSorted
:
public
CallBase
class
ngraph
::
pass
::
DumpSorted
:
public
ModulePass
{
public
:
DumpSorted
(
const
std
::
string
&
output_file
);
virtual
bool
run_on_
call_list
(
std
::
list
<
Node
*>&
)
override
;
virtual
bool
run_on_
module
(
std
::
vector
<
Function
*>&
)
override
;
private
:
const
std
::
string
m_output_file
;
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/pass/liveness.cpp
View file @
118e0679
...
...
@@ -27,7 +27,7 @@ using namespace std;
using
namespace
ngraph
;
using
namespace
ngraph
::
descriptor
;
bool
pass
::
Liveness
::
run_on_call_
list
(
list
<
Node
*>&
ops
)
bool
pass
::
Liveness
::
run_on_call_
graph
(
list
<
Node
*>&
ops
)
{
unordered_set
<
Tensor
*>
currently_live
;
...
...
@@ -123,24 +123,6 @@ bool pass::Liveness::run_on_call_list(list<Node*>& ops)
return
false
;
}
void
pass
::
Liveness
::
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
registered_passes
)
const
{
bool
found_propagate_types
=
false
;
for
(
auto
pass
:
registered_passes
)
{
if
(
dynamic_pointer_cast
<
AssignTensors
>
(
pass
))
{
found_propagate_types
=
true
;
}
}
if
(
!
found_propagate_types
)
{
throw
runtime_error
(
"Dependency 'PropagateTypes' not found for pass 'AssignTensors'"
);
}
}
bool
pass
::
Liveness
::
is_temporary
(
const
Tensor
&
tensor
)
{
return
tensor
.
is_persistent
()
==
false
&&
tensor
.
is_input
()
==
false
&&
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/pass/liveness.hpp
View file @
118e0679
...
...
@@ -15,7 +15,7 @@
#pragma once
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -26,12 +26,10 @@ namespace ngraph
class
Node
;
}
class
ngraph
::
pass
::
Liveness
:
public
Call
Base
class
ngraph
::
pass
::
Liveness
:
public
Call
GraphPass
{
public
:
virtual
bool
run_on_call_list
(
std
::
list
<
Node
*>&
)
override
;
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
)
const
override
;
virtual
bool
run_on_call_graph
(
std
::
list
<
Node
*>&
)
override
;
private
:
bool
is_temporary
(
const
descriptor
::
Tensor
&
);
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/pass/manager.cpp
View file @
118e0679
...
...
@@ -19,40 +19,11 @@
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp"
using
namespace
std
;
using
namespace
ngraph
;
Function
*
ngraph
::
pass
::
ManagerState
::
get_function
()
{
return
m_function
;
}
void
ngraph
::
pass
::
ManagerState
::
set_function
(
Function
*
func
)
{
m_function
=
func
;
}
size_t
ngraph
::
pass
::
ManagerState
::
get_temporary_pool_size
()
{
return
m_temporary_pool_size
;
}
void
ngraph
::
pass
::
ManagerState
::
set_temporary_pool_size
(
size_t
size
)
{
m_temporary_pool_size
=
size
;
}
std
::
list
<
Node
*>&
ngraph
::
pass
::
ManagerState
::
get_call_graph
()
{
return
m_call_graph
;
}
const
std
::
list
<
Node
*>&
ngraph
::
pass
::
ManagerState
::
get_call_graph
()
const
{
return
m_call_graph
;
}
ngraph
::
pass
::
Manager
::
Manager
()
{
}
...
...
@@ -65,26 +36,6 @@ void ngraph::pass::Manager::initialize_default_passes()
{
}
void
ngraph
::
pass
::
Manager
::
register_pass_ptr
(
std
::
shared_ptr
<
TreeBase
>
p
)
{
if
(
p
==
nullptr
)
{
throw
invalid_argument
(
"null pass registered"
);
}
p
->
check_dependencies
(
m_tree_passes
);
m_tree_passes
.
push_back
(
p
);
}
void
ngraph
::
pass
::
Manager
::
register_pass_ptr
(
std
::
shared_ptr
<
CallBase
>
p
)
{
if
(
p
==
nullptr
)
{
throw
invalid_argument
(
"null pass registered"
);
}
p
->
check_dependencies
(
m_call_passes
);
m_call_passes
.
push_back
(
p
);
}
void
ngraph
::
pass
::
Manager
::
run_passes
(
shared_ptr
<
Function
>
func
)
{
run_passes
(
func
.
get
());
...
...
@@ -92,23 +43,79 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
void
ngraph
::
pass
::
Manager
::
run_passes
(
Function
*
func
)
{
m_state
.
set_function
(
func
);
for
(
shared_ptr
<
TreeBase
>
p
:
m_tree_passes
)
vector
<
Function
*>
fs
=
{
func
};
get_state
().
set_functions
(
fs
);
for
(
shared_ptr
<
PassBase
>
pass
:
m_pass_list
)
{
pass
->
set_state
(
get_state
());
auto
module_pass
=
dynamic_pointer_cast
<
ModulePass
>
(
pass
);
auto
function_pass
=
dynamic_pointer_cast
<
FunctionPass
>
(
pass
);
auto
node_pass
=
dynamic_pointer_cast
<
NodePass
>
(
pass
);
auto
call_graph_pass
=
dynamic_pointer_cast
<
CallGraphPass
>
(
pass
);
if
(
module_pass
)
{
p
->
set_state
(
get_state
());
p
->
run_on_tree
(
func
->
get_result
());
module_pass
->
run_on_module
(
fs
);
}
for
(
shared_ptr
<
CallBase
>&
p
:
m_call_passes
)
else
if
(
function_pass
)
{
for
(
Function
*
f
:
fs
)
{
p
->
set_state
(
get_state
());
p
->
run_on_call_list
(
get_state
().
get_call_graph
());
function_pass
->
run_on_function
(
f
);
}
}
const
std
::
list
<
ngraph
::
Node
*>&
ngraph
::
pass
::
Manager
::
get_call_graph
()
const
{
return
m_state
.
get_call_graph
();
}
else
if
(
node_pass
)
{
for
(
Function
*
f
:
fs
)
{
for
(
Node
*
n
:
f
->
get_ops
())
{
node_pass
->
run_on_node
(
n
);
}
}
}
else
if
(
call_graph_pass
)
{
for
(
Function
*
f
:
fs
)
{
call_graph_pass
->
run_on_call_graph
(
f
->
get_ordered_ops
());
}
}
}
// for (shared_ptr<ModulePass>& p : m_module_passes)
// {
// p->set_state(get_state());
// p->run_on_module(fs);
// }
// for (Function* f : fs)
// {
// for (shared_ptr<FunctionPass> p : m_function_passes)
// {
// p->set_state(get_state());
// p->run_on_function(f);
// }
// }
// for (Function* f : fs)
// {
// NGRAPH_INFO;
// for (shared_ptr<NodePass> p : m_node_passes)
// {
// for (Node* node : f->get_ops())
// {
// NGRAPH_INFO;
// p->set_state(get_state());
// p->run_on_node(node);
// }
// }
// }
// for (shared_ptr<CallGraphPass>& p : m_call_graph_passes)
// {
// p->set_state(get_state());
// p->run_on_call_graph(func->get_ordered_ops());
// }
}
ngraph
::
pass
::
ManagerState
&
ngraph
::
pass
::
Manager
::
get_state
()
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/pass/manager.hpp
View file @
118e0679
...
...
@@ -18,8 +18,8 @@
#include <memory>
#include <vector>
#include "ngraph/pass/
call_pass
.hpp"
#include "ngraph/pass/
tree_
pass.hpp"
#include "ngraph/pass/
manager_state
.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -33,24 +33,6 @@ namespace ngraph
class
Function
;
}
class
ngraph
::
pass
::
ManagerState
{
public
:
Function
*
get_function
();
void
set_function
(
Function
*
);
size_t
get_temporary_pool_size
();
void
set_temporary_pool_size
(
size_t
);
std
::
list
<
Node
*>&
get_call_graph
();
const
std
::
list
<
Node
*>&
get_call_graph
()
const
;
private
:
Function
*
m_function
=
nullptr
;
size_t
m_temporary_pool_size
=
0
;
std
::
list
<
Node
*>
m_call_graph
;
};
class
ngraph
::
pass
::
Manager
{
public
:
...
...
@@ -62,29 +44,18 @@ public:
template
<
typename
T
,
class
...
Args
>
void
register_pass
(
Args
...
args
)
{
static_assert
(
std
::
is_base_of
<
pass
::
Base
,
T
>::
value
,
"pass not derived from pass base"
);
if
(
std
::
is_base_of
<
TreeBase
,
T
>::
value
)
{
register_pass_ptr
(
std
::
make_shared
<
T
>
(
args
...));
}
else
if
(
std
::
is_base_of
<
CallBase
,
T
>::
value
)
{
register_pass_ptr
(
std
::
make_shared
<
T
>
(
args
...));
}
static_assert
(
std
::
is_base_of
<
pass
::
PassBase
,
T
>::
value
,
"pass not derived from pass base"
);
auto
pass
=
std
::
make_shared
<
T
>
(
args
...);
auto
pass_base
=
std
::
static_pointer_cast
<
PassBase
>
(
pass
);
m_pass_list
.
push_back
(
pass_base
);
}
void
run_passes
(
Function
*
);
void
run_passes
(
std
::
shared_ptr
<
Function
>
);
const
std
::
list
<
Node
*>&
get_call_graph
()
const
;
ManagerState
&
get_state
();
private
:
void
register_pass_ptr
(
std
::
shared_ptr
<
TreeBase
>
);
void
register_pass_ptr
(
std
::
shared_ptr
<
CallBase
>
);
std
::
vector
<
std
::
shared_ptr
<
TreeBase
>>
m_tree_passes
;
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>
m_call_passes
;
std
::
vector
<
std
::
shared_ptr
<
PassBase
>>
m_pass_list
;
ManagerState
m_state
;
};
This diff is collapsed.
Click to expand it.
src/ngraph/pass/
tree_pass.h
pp
→
src/ngraph/pass/
manager_state.c
pp
View file @
118e0679
...
...
@@ -12,31 +12,28 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <list>
#include <iostream>
#include <memory>
#include <vector>
#include "ngraph/pass/pass.hpp"
#include "ngraph/function.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/manager_state.hpp"
namespace
ngraph
{
namespace
pass
{
class
TreeBase
;
}
using
namespace
std
;
using
namespace
ngraph
;
class
Node
;
vector
<
Function
*>&
ngraph
::
pass
::
ManagerState
::
get_functions
()
{
return
m_function_list
;
}
class
ngraph
::
pass
::
TreeBase
:
public
Base
size_t
ngraph
::
pass
::
ManagerState
::
get_temporary_pool_size
()
{
public
:
virtual
~
TreeBase
()
{}
// return true if changes were made to the tree
virtual
bool
run_on_tree
(
std
::
shared_ptr
<
Node
>
)
=
0
;
return
m_temporary_pool_size
;
}
// derived class throws exception if its dependencies have not been met
virtual
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
TreeBase
>>&
)
const
{}
};
void
ngraph
::
pass
::
ManagerState
::
set_temporary_pool_size
(
size_t
size
)
{
m_temporary_pool_size
=
size
;
}
This diff is collapsed.
Click to expand it.
src/ngraph/pass/
call_pass
.hpp
→
src/ngraph/pass/
manager_state
.hpp
View file @
118e0679
...
...
@@ -14,29 +14,36 @@
#pragma once
#include <list>
#include <memory>
#include <vector>
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
namespace
pass
{
class
CallBas
e
;
class
ManagerStat
e
;
}
class
Node
;
class
Function
;
}
class
ngraph
::
pass
::
CallBase
:
public
Bas
e
class
ngraph
::
pass
::
ManagerStat
e
{
public
:
virtual
~
CallBase
()
{}
virtual
bool
run_on_call_list
(
std
::
list
<
Node
*>&
)
=
0
;
std
::
vector
<
Function
*>&
get_functions
();
template
<
typename
T
>
void
set_functions
(
const
T
&
collection
)
{
m_function_list
.
clear
();
m_function_list
.
insert
(
m_function_list
.
begin
(),
collection
.
begin
(),
collection
.
end
());
}
size_t
get_temporary_pool_size
();
void
set_temporary_pool_size
(
size_t
);
// derived class throws exception if its dependencies have not been met
virtual
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
)
const
{}
private
:
size_t
m_temporary_pool_size
=
0
;
std
::
vector
<
Function
*>
m_function_list
;
};
This diff is collapsed.
Click to expand it.
src/ngraph/pass/memory_layout.cpp
View file @
118e0679
...
...
@@ -27,7 +27,7 @@ using namespace std;
using
namespace
ngraph
;
using
namespace
ngraph
::
descriptor
;
bool
pass
::
MemoryLayout
::
run_on_call_
list
(
std
::
list
<
Node
*>&
node_list
)
bool
pass
::
MemoryLayout
::
run_on_call_
graph
(
std
::
list
<
Node
*>&
node_list
)
{
MemoryManager
mm
;
for
(
const
Node
*
node
:
node_list
)
...
...
@@ -47,24 +47,6 @@ bool pass::MemoryLayout::run_on_call_list(std::list<Node*>& node_list)
return
false
;
}
void
pass
::
MemoryLayout
::
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
registered_passes
)
const
{
bool
found_propagate_types
=
false
;
for
(
auto
pass
:
registered_passes
)
{
if
(
dynamic_pointer_cast
<
Liveness
>
(
pass
))
{
found_propagate_types
=
true
;
}
}
if
(
!
found_propagate_types
)
{
throw
runtime_error
(
"Dependency 'PropagateTypes' not found for pass 'AssignTensors'"
);
}
}
pass
::
MemoryManager
::
node
::
node
(
size_t
size
,
block_state
state
)
:
m_size
{
size
}
,
m_state
{
state
}
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/pass/memory_layout.hpp
View file @
118e0679
...
...
@@ -18,7 +18,7 @@
#include <list>
#include <sstream>
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -31,12 +31,10 @@ namespace ngraph
class
Node
;
}
class
ngraph
::
pass
::
MemoryLayout
:
public
Call
Base
class
ngraph
::
pass
::
MemoryLayout
:
public
Call
GraphPass
{
public
:
virtual
bool
run_on_call_list
(
std
::
list
<
Node
*>&
)
override
;
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
)
const
override
;
virtual
bool
run_on_call_graph
(
std
::
list
<
Node
*>&
)
override
;
private
:
};
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/pass/memory_visualize.cpp
View file @
118e0679
...
...
@@ -19,6 +19,7 @@
#include "memory_visualize.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/util.hpp"
...
...
@@ -31,11 +32,13 @@ pass::MemoryVisualize::MemoryVisualize(const string& filename)
{
}
bool
pass
::
MemoryVisualize
::
run_on_
call_list
(
list
<
Node
*>&
_node
s
)
bool
pass
::
MemoryVisualize
::
run_on_
module
(
vector
<
Function
*>&
function
s
)
{
const
list
<
Node
*>
nodes
=
_nodes
;
ofstream
file
(
m_filename
);
{
for
(
const
Function
*
f
:
functions
)
{
const
list
<
Node
*>
nodes
=
f
->
get_ordered_ops
();
file
<<
"<!DOCTYPE html>
\n
<html>
\n
"
;
file
<<
"<head>
\n
"
;
file
<<
" <style>
\n
"
;
...
...
@@ -89,13 +92,10 @@ bool pass::MemoryVisualize::run_on_call_list(list<Node*>& _nodes)
// file << "<hr>\n";
file
<<
"</body>
\n
</html>
\n
"
;
}
}
return
false
;
}
void
pass
::
MemoryVisualize
::
check_dependencies
(
const
vector
<
shared_ptr
<
CallBase
>>&
deps
)
const
{
}
const
Node
*
pass
::
MemoryVisualize
::
find_largest_op
(
const
list
<
Node
*>&
nodes
)
{
const
Node
*
largest_op
=
nullptr
;
...
...
@@ -207,7 +207,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
size_t
x2
=
((
usage
/
memory_footprint
)
*
scale
)
+
offset
;
file
<<
"<text x=
\"
"
<<
0
<<
"
\"
y=
\"
"
<<
y
+
text_offset
<<
"
\"
fill=
\"
"
<<
"black"
<<
"
\"
>"
<<
node
->
get_n
ode_id
()
<<
"</text>
\n
"
;
<<
"
\"
>"
<<
node
->
get_n
ame
()
<<
"</text>
\n
"
;
file
<<
"<line x1=
\"
"
<<
x1
<<
"
\"
y1=
\"
"
<<
y
<<
"
\"
x2=
\"
"
<<
x2
<<
"
\"
y2=
\"
"
<<
y
<<
"
\"
"
;
file
<<
" style=
\"
stroke:forestgreen;stroke-width:"
<<
stroke_width
<<
"
\"
/>
\n
"
;
...
...
@@ -231,7 +231,7 @@ void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<Node*>&
{
int
weight
=
compute_op_weight
(
exop
);
file
<<
" <tr>"
;
file
<<
"<td>"
<<
exop
->
get_n
ode_id
()
<<
"</td>"
;
file
<<
"<td>"
<<
exop
->
get_n
ame
()
<<
"</td>"
;
file
<<
"<td align=
\"
right
\"
>"
<<
weight
<<
"</td>"
;
file
<<
"</tr>
\n
"
;
}
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/pass/memory_visualize.hpp
View file @
118e0679
...
...
@@ -18,7 +18,7 @@
#include <limits>
#include <list>
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -29,13 +29,11 @@ namespace ngraph
class
Node
;
}
class
ngraph
::
pass
::
MemoryVisualize
:
public
CallBase
class
ngraph
::
pass
::
MemoryVisualize
:
public
ModulePass
{
public
:
MemoryVisualize
(
const
std
::
string
&
filename
);
virtual
bool
run_on_call_list
(
std
::
list
<
Node
*>&
)
override
;
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
)
const
override
;
virtual
bool
run_on_module
(
std
::
vector
<
Function
*>&
)
override
;
private
:
const
Node
*
find_largest_op
(
const
std
::
list
<
Node
*>&
nodes
);
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/pass/pass.cpp
View file @
118e0679
...
...
@@ -15,12 +15,12 @@
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/manager.hpp"
ngraph
::
pass
::
ManagerState
&
ngraph
::
pass
::
Base
::
get_state
()
ngraph
::
pass
::
ManagerState
&
ngraph
::
pass
::
Pass
Base
::
get_state
()
{
return
*
m_state
;
}
void
ngraph
::
pass
::
Base
::
set_state
(
ManagerState
&
state
)
void
ngraph
::
pass
::
Pass
Base
::
set_state
(
ManagerState
&
state
)
{
m_state
=
&
state
;
}
This diff is collapsed.
Click to expand it.
src/ngraph/pass/pass.hpp
View file @
118e0679
...
...
@@ -14,21 +14,33 @@
#pragma once
#include <list>
#include <memory>
#include <vector>
#include "ngraph/node.hpp"
namespace
ngraph
{
namespace
pass
{
class
Base
;
class
PassBase
;
class
ModulePass
;
class
FunctionPass
;
class
NodePass
;
class
CallGraphPass
;
class
Manager
;
class
ManagerState
;
}
class
Function
;
}
class
ngraph
::
pass
::
Base
class
ngraph
::
pass
::
Pass
Base
{
friend
class
Manager
;
public
:
virtual
~
PassBase
()
{}
protected
:
ManagerState
&
get_state
();
void
set_state
(
ManagerState
&
);
...
...
@@ -36,3 +48,31 @@ protected:
private
:
ManagerState
*
m_state
;
};
class
ngraph
::
pass
::
ModulePass
:
public
PassBase
{
public
:
virtual
~
ModulePass
()
{}
virtual
bool
run_on_module
(
std
::
vector
<
ngraph
::
Function
*>&
)
=
0
;
};
class
ngraph
::
pass
::
FunctionPass
:
public
PassBase
{
public
:
virtual
~
FunctionPass
()
{}
virtual
bool
run_on_function
(
ngraph
::
Function
*
)
=
0
;
};
class
ngraph
::
pass
::
NodePass
:
public
PassBase
{
public
:
virtual
~
NodePass
()
{}
virtual
bool
run_on_node
(
ngraph
::
Node
*
)
=
0
;
};
class
ngraph
::
pass
::
CallGraphPass
:
public
PassBase
{
public
:
virtual
~
CallGraphPass
()
{}
virtual
bool
run_on_call_graph
(
std
::
list
<
ngraph
::
Node
*>&
)
=
0
;
};
This diff is collapsed.
Click to expand it.
src/ngraph/pass/propagate_types.cpp
View file @
118e0679
...
...
@@ -20,9 +20,9 @@
using
namespace
std
;
using
namespace
ngraph
;
bool
pass
::
PropagateTypes
::
run_on_call_
list
(
std
::
list
<
Node
*>&
node_list
)
bool
pass
::
PropagateTypes
::
run_on_call_
graph
(
list
<
Node
*>&
nodes
)
{
for
(
Node
*
node
:
node
_list
)
for
(
Node
*
node
:
node
s
)
{
try
{
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/pass/propagate_types.hpp
View file @
118e0679
...
...
@@ -14,7 +14,7 @@
#pragma once
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -25,10 +25,10 @@ namespace ngraph
class
Node
;
}
class
ngraph
::
pass
::
PropagateTypes
:
public
Call
Base
class
ngraph
::
pass
::
PropagateTypes
:
public
Call
GraphPass
{
public
:
virtual
bool
run_on_call_
list
(
std
::
list
<
Node
*>&
)
override
;
virtual
bool
run_on_call_
graph
(
std
::
list
<
Node
*>&
)
override
;
private
:
};
This diff is collapsed.
Click to expand it.
src/ngraph/pass/topological_sort.cpp
View file @
118e0679
...
...
@@ -15,6 +15,7 @@
#include <deque>
#include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/manager.hpp"
...
...
@@ -24,14 +25,13 @@
using
namespace
ngraph
;
using
namespace
std
;
bool
ngraph
::
pass
::
TopologicalSort
::
run_on_
tree
(
std
::
shared_ptr
<
Node
>
p
)
bool
ngraph
::
pass
::
TopologicalSort
::
run_on_
function
(
ngraph
::
Function
*
func
)
{
list
<
Node
*>&
sorted_list
=
get_state
().
get_call_graph
();
sorted_list
.
clear
();
list
<
Node
*>
result_list
;
deque
<
Node
*>
independent_nodes
;
unordered_map
<
Node
*
,
size_t
>
node_depencency_count
;
traverse_nodes
(
p
,
[
&
](
Node
*
node
)
{
traverse_nodes
(
func
->
get_result
()
,
[
&
](
Node
*
node
)
{
node_depencency_count
[
node
]
=
node
->
get_arguments
().
size
();
if
(
node
->
get_arguments
().
size
()
==
0
)
{
...
...
@@ -42,7 +42,7 @@ bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p)
while
(
independent_nodes
.
size
()
>
0
)
{
auto
independent_node
=
independent_nodes
.
front
();
sorted
_list
.
push_back
(
independent_node
);
result
_list
.
push_back
(
independent_node
);
independent_nodes
.
pop_front
();
for
(
auto
user
:
independent_node
->
users
())
...
...
@@ -56,5 +56,7 @@ bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p)
}
}
func
->
set_ordered_ops
(
result_list
);
return
false
;
}
This diff is collapsed.
Click to expand it.
src/ngraph/pass/topological_sort.hpp
View file @
118e0679
...
...
@@ -17,7 +17,7 @@
#include <list>
#include <memory>
#include "ngraph/pass/
tree_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -28,9 +28,9 @@ namespace ngraph
class
Node
;
}
class
ngraph
::
pass
::
TopologicalSort
:
public
TreeBase
class
ngraph
::
pass
::
TopologicalSort
:
public
FunctionPass
{
public
:
TopologicalSort
()
{}
bool
run_on_
tree
(
std
::
shared_ptr
<
Node
>
)
override
;
bool
run_on_
function
(
ngraph
::
Function
*
)
override
;
};
This diff is collapsed.
Click to expand it.
src/ngraph/pass/tree_pass.cpp
deleted
100644 → 0
View file @
c2ff1508
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/pass/tree_pass.hpp"
This diff is collapsed.
Click to expand it.
src/ngraph/pass/visualize_tree.cpp
View file @
118e0679
...
...
@@ -14,25 +14,30 @@
#include <fstream>
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp"
using
namespace
ngraph
;
using
namespace
std
;
bool
pass
::
VisualizeTree
::
run_on_
tree
(
std
::
shared_ptr
<
Node
>
base_node
)
bool
pass
::
VisualizeTree
::
run_on_
module
(
vector
<
ngraph
::
Function
*>&
functions
)
{
for
(
Function
*
f
:
functions
)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes
(
base_node
,
[
&
](
Node
*
node
)
{
traverse_nodes
(
f
->
get_result
()
,
[
&
](
Node
*
node
)
{
for
(
auto
arg
:
node
->
get_arguments
())
{
m_ss
<<
add_attributes
(
arg
.
get
());
m_ss
<<
add_attributes
(
node
);
m_ss
<<
" "
<<
arg
->
get_node_id
()
<<
" -> "
<<
node
->
get_node_id
();
m_ss
<<
" "
<<
arg
->
get_name
()
<<
" -> "
<<
node
->
get_name
();
m_ss
<<
";
\n
"
;
}
});
}
render
();
...
...
@@ -60,11 +65,11 @@ std::string pass::VisualizeTree::get_attributes(const Node* node)
stringstream
ss
;
if
(
node
->
is_parameter
())
{
ss
<<
" "
<<
node
->
get_n
ode_id
()
<<
" [shape=box color=blue]
\n
"
;
ss
<<
" "
<<
node
->
get_n
ame
()
<<
" [shape=box color=blue]
\n
"
;
}
else
{
ss
<<
" "
<<
node
->
get_n
ode_id
()
<<
" [shape=ellipse color=black]
\n
"
;
ss
<<
" "
<<
node
->
get_n
ame
()
<<
" [shape=ellipse color=black]
\n
"
;
}
return
ss
.
str
();
}
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/pass/visualize_tree.hpp
View file @
118e0679
...
...
@@ -18,7 +18,7 @@
#include <sstream>
#include <string>
#include "ngraph/pass/
tree_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
...
...
@@ -29,11 +29,11 @@ namespace ngraph
class
Node
;
}
class
ngraph
::
pass
::
VisualizeTree
:
public
TreeBase
class
ngraph
::
pass
::
VisualizeTree
:
public
ModulePass
{
public
:
VisualizeTree
(
const
std
::
string
&
file_name
);
bool
run_on_
tree
(
std
::
shared_ptr
<
Node
>
)
override
;
bool
run_on_
module
(
std
::
vector
<
ngraph
::
Function
*>&
)
override
;
private
:
std
::
string
add_attributes
(
const
Node
*
node
);
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/runtime/external_function.cpp
View file @
118e0679
...
...
@@ -659,7 +659,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Turn this into a pass
// Assign layouts
// For now, just make everyone row-major.
for
(
const
Node
*
node
:
pass_manager
.
get_call_graph
())
for
(
const
Node
*
node
:
m_function
->
get_ordered_ops
())
{
for
(
const
descriptor
::
Output
&
output
:
node
->
get_outputs
())
{
...
...
@@ -696,7 +696,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
m_n_outputs
=
tensor_index
.
size
()
-
m_n_inputs
;
// All remaining tensor views
for
(
const
Node
*
node
:
pass_manager
.
get_call_graph
())
for
(
const
Node
*
node
:
m_function
->
get_ordered_ops
())
{
for
(
const
descriptor
::
Output
&
output
:
node
->
get_outputs
())
{
...
...
@@ -712,7 +712,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Now we build the eigen-VM instructions
auto
op_map
=
get_op_map
();
for
(
const
Node
*
node
:
pass_manager
.
get_call_graph
())
for
(
const
Node
*
node
:
m_function
->
get_ordered_ops
())
{
auto
handler_it
=
op_map
.
find
(
type_index
(
typeid
(
*
node
)));
if
(
handler_it
==
op_map
.
end
())
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/visualize.cpp
deleted
100644 → 0
View file @
c2ff1508
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cstdio>
#include <fstream>
#include <list>
#include "ngraph/node.hpp"
#include "ngraph/util.hpp"
#include "ngraph/visualize.hpp"
using
namespace
ngraph
;
using
namespace
std
;
Visualize
::
Visualize
(
const
string
&
name
)
:
m_name
{
name
}
{
}
void
Visualize
::
add
(
node_ptr
p
)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes
(
p
,
[
&
](
Node
*
node
)
{
for
(
auto
arg
:
node
->
get_arguments
())
{
m_ss
<<
add_attributes
(
arg
.
get
());
m_ss
<<
add_attributes
(
node
);
m_ss
<<
" "
<<
arg
->
get_node_id
()
<<
" -> "
<<
node
->
get_node_id
();
m_ss
<<
";
\n
"
;
}
});
}
std
::
string
Visualize
::
add_attributes
(
const
Node
*
node
)
{
string
rc
;
if
(
!
contains
(
m_nodes_with_attributes
,
node
))
{
m_nodes_with_attributes
.
insert
(
node
);
rc
=
get_attributes
(
node
);
}
return
rc
;
}
std
::
string
Visualize
::
get_attributes
(
const
Node
*
node
)
{
stringstream
ss
;
if
(
node
->
is_parameter
())
{
ss
<<
" "
<<
node
->
get_node_id
()
<<
" [shape=box color=blue]
\n
"
;
}
else
{
ss
<<
" "
<<
node
->
get_node_id
()
<<
" [shape=ellipse color=black]
\n
"
;
}
return
ss
.
str
();
}
void
Visualize
::
save_dot
(
const
string
&
path
)
const
{
#ifdef GRAPHVIZ_FOUND
auto
tmp_file
=
path
+
".tmp"
;
ofstream
out
(
tmp_file
);
if
(
out
)
{
out
<<
"digraph "
<<
m_name
<<
"
\n
{
\n
"
;
out
<<
m_ss
.
str
();
out
<<
"}
\n
"
;
out
.
close
();
stringstream
ss
;
ss
<<
"dot -Tpng "
<<
tmp_file
<<
" -o "
<<
path
;
auto
cmd
=
ss
.
str
();
auto
stream
=
popen
(
cmd
.
c_str
(),
"r"
);
pclose
(
stream
);
remove
(
tmp_file
.
c_str
());
}
#else
#endif
}
This diff is collapsed.
Click to expand it.
test/pass_liveness.cpp
View file @
118e0679
...
...
@@ -51,7 +51,7 @@ TEST(pass, liveness)
shared_ptr
<
Function
>
func
=
make_test_graph
();
pass_manager
.
run_passes
(
func
.
get
());
auto
sorted
=
pass_manager
.
get_call_graph
();
auto
sorted
=
func
->
get_ordered_ops
();
// for (const Node* node : sorted)
// {
...
...
This diff is collapsed.
Click to expand it.
test/pass_manager.cpp
View file @
118e0679
...
...
@@ -40,15 +40,28 @@ TEST(pass_manager, add)
auto
graph
=
make_test_graph
();
size_t
node_count
=
get_node_count
(
graph
->
get_result
());
pass_manager
.
run_passes
(
graph
.
get
());
auto
sorted
=
pass_manager
.
get_call_graph
();
auto
sorted
=
graph
->
get_ordered_ops
();
EXPECT_EQ
(
node_count
,
sorted
.
size
());
EXPECT_TRUE
(
validate_list
(
sorted
));
}
TEST
(
pass_manager
,
dependency
)
TEST
(
pass_manager
,
module_add_function
)
{
pass
::
Manager
pass_manager
;
// First create "f(A,B,C) = (A+B)*C".
auto
shape
=
Shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
C
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_f
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
f
=
make_shared
<
Function
>
((
A
+
B
)
*
C
,
rt_f
,
op
::
Parameters
{
A
,
B
,
C
});
pass_manager
.
register_pass
<
pass
::
TopologicalSort
>
();
EXPECT_THROW
(
pass_manager
.
register_pass
<
pass
::
AssignTensors
>
(),
runtime_error
);
// Now make "g(X,Y,Z) = f(X,Y,Z) + f(X,Y,Z)"
auto
X
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Y
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Z
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_g
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
g
=
make_shared
<
Function
>
(
make_shared
<
op
::
FunctionCall
>
(
f
,
Nodes
{
X
,
Y
,
Z
})
+
make_shared
<
op
::
FunctionCall
>
(
f
,
Nodes
{
X
,
Y
,
Z
}),
rt_g
,
op
::
Parameters
{
X
,
Y
,
Z
});
}
This diff is collapsed.
Click to expand it.
test/pass_memory_layout.cpp
View file @
118e0679
...
...
@@ -218,7 +218,7 @@ TEST(memory_layout, basic)
auto
graph
=
make_test_graph
();
pass_manager
.
run_passes
(
graph
);
auto
sorted
=
pass_manager
.
get_call_graph
();
auto
sorted
=
graph
->
get_ordered_ops
();
size_t
temporary_pool_size
=
pass_manager
.
get_state
().
get_temporary_pool_size
();
EXPECT_EQ
(
12
,
temporary_pool_size
);
}
This diff is collapsed.
Click to expand it.
test/topological_sort.cpp
View file @
118e0679
...
...
@@ -21,10 +21,10 @@
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/collect_functions.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/util.hpp"
#include "ngraph/visualize.hpp"
#include "test_tools.hpp"
using
namespace
std
;
...
...
@@ -69,7 +69,7 @@ TEST(topological_sort, basic)
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
TopologicalSort
>
();
pass_manager
.
run_passes
(
f0
);
auto
sorted_list
=
pass_manager
.
get_call_graph
();
auto
sorted_list
=
f0
->
get_ordered_ops
();
size_t
node_count
=
get_node_count
(
r0
);
...
...
@@ -121,7 +121,7 @@ TEST(benchmark, topological_sort)
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
TopologicalSort
>
();
pass_manager
.
run_passes
(
f0
);
auto
sorted_list
=
pass_manager
.
get_call_graph
();
auto
sorted_list
=
f0
->
get_ordered_ops
();
timer
.
stop
();
NGRAPH_INFO
<<
"topological sort took "
<<
timer
.
get_milliseconds
()
<<
"ms"
;
...
...
@@ -135,3 +135,51 @@ TEST(benchmark, topological_sort)
timer
.
stop
();
NGRAPH_INFO
<<
"delete nodes took "
<<
timer
.
get_milliseconds
()
<<
"ms"
;
}
TEST
(
topological_sort
,
collect_functions
)
{
// First create "f(A,B,C) = (A+B)*C".
auto
shape
=
Shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
C
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_f
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
f
=
make_shared
<
Function
>
((
A
+
B
)
*
C
,
rt_f
,
op
::
Parameters
{
A
,
B
,
C
},
"f"
);
// Now make "g(X,Y,Z) = f(X,Y,Z) + f(X,Y,Z)"
auto
X
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Y
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Z
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_g
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
g
=
make_shared
<
Function
>
(
make_shared
<
op
::
FunctionCall
>
(
f
,
Nodes
{
X
,
Y
,
Z
})
+
make_shared
<
op
::
FunctionCall
>
(
f
,
Nodes
{
X
,
Y
,
Z
}),
rt_g
,
op
::
Parameters
{
X
,
Y
,
Z
},
"g"
);
// Now make "h(X,Y,Z) = g(X,Y,Z) + g(X,Y,Z)"
auto
X1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Y1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Z1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_h
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
h
=
make_shared
<
Function
>
(
make_shared
<
op
::
FunctionCall
>
(
g
,
Nodes
{
X1
,
Y1
,
Z1
})
+
make_shared
<
op
::
FunctionCall
>
(
g
,
Nodes
{
X1
,
Y1
,
Z1
}),
rt_h
,
op
::
Parameters
{
X1
,
Y1
,
Z1
},
"h"
);
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
CollectFunctions
>
();
pass_manager
.
run_passes
(
h
);
set
<
string
>
expected
=
{
"f"
,
"g"
,
"h"
};
auto
functions
=
pass_manager
.
get_state
().
get_functions
();
vector
<
string
>
fnames
;
for
(
Function
*
func
:
functions
)
{
fnames
.
push_back
(
func
->
get_name
());
}
EXPECT_EQ
(
expected
.
size
(),
functions
.
size
());
EXPECT_TRUE
(
contains
(
fnames
,
"f"
));
EXPECT_TRUE
(
contains
(
fnames
,
"g"
));
EXPECT_TRUE
(
contains
(
fnames
,
"h"
));
}
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment