Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
118e0679
Commit
118e0679
authored
Oct 06, 2017
by
Jai Menon
Committed by
GitHub
Oct 06, 2017
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into jmenon/codegen
parents
c2ff1508
c3dfdf5a
Show whitespace changes
Inline
Side-by-side
Showing
36 changed files
with
474 additions
and
403 deletions
+474
-403
CMakeLists.txt
src/ngraph/CMakeLists.txt
+2
-3
function.cpp
src/ngraph/function.cpp
+76
-2
function.hpp
src/ngraph/function.hpp
+25
-2
node.cpp
src/ngraph/node.cpp
+26
-0
node.hpp
src/ngraph/node.hpp
+2
-0
assign_tensors.cpp
src/ngraph/pass/assign_tensors.cpp
+3
-21
assign_tensors.hpp
src/ngraph/pass/assign_tensors.hpp
+3
-5
collect_functions.cpp
src/ngraph/pass/collect_functions.cpp
+34
-29
collect_functions.hpp
src/ngraph/pass/collect_functions.hpp
+19
-1
dump_sorted.cpp
src/ngraph/pass/dump_sorted.cpp
+6
-3
dump_sorted.hpp
src/ngraph/pass/dump_sorted.hpp
+3
-3
liveness.cpp
src/ngraph/pass/liveness.cpp
+1
-19
liveness.hpp
src/ngraph/pass/liveness.hpp
+3
-5
manager.cpp
src/ngraph/pass/manager.cpp
+70
-63
manager.hpp
src/ngraph/pass/manager.hpp
+7
-36
manager_state.cpp
src/ngraph/pass/manager_state.cpp
+17
-20
manager_state.hpp
src/ngraph/pass/manager_state.hpp
+16
-9
memory_layout.cpp
src/ngraph/pass/memory_layout.cpp
+1
-19
memory_layout.hpp
src/ngraph/pass/memory_layout.hpp
+3
-5
memory_visualize.cpp
src/ngraph/pass/memory_visualize.cpp
+8
-8
memory_visualize.hpp
src/ngraph/pass/memory_visualize.hpp
+3
-5
pass.cpp
src/ngraph/pass/pass.cpp
+2
-2
pass.hpp
src/ngraph/pass/pass.hpp
+42
-2
propagate_types.cpp
src/ngraph/pass/propagate_types.cpp
+2
-2
propagate_types.hpp
src/ngraph/pass/propagate_types.hpp
+3
-3
topological_sort.cpp
src/ngraph/pass/topological_sort.cpp
+7
-5
topological_sort.hpp
src/ngraph/pass/topological_sort.hpp
+3
-3
tree_pass.cpp
src/ngraph/pass/tree_pass.cpp
+0
-15
visualize_tree.cpp
src/ngraph/pass/visualize_tree.cpp
+10
-5
visualize_tree.hpp
src/ngraph/pass/visualize_tree.hpp
+3
-3
external_function.cpp
src/ngraph/runtime/external_function.cpp
+3
-3
visualize.cpp
src/ngraph/visualize.cpp
+0
-92
pass_liveness.cpp
test/pass_liveness.cpp
+1
-1
pass_manager.cpp
test/pass_manager.cpp
+18
-5
pass_memory_layout.cpp
test/pass_memory_layout.cpp
+1
-1
topological_sort.cpp
test/topological_sort.cpp
+51
-3
No files found.
src/ngraph/CMakeLists.txt
View file @
118e0679
...
@@ -39,16 +39,16 @@ set (SRC
...
@@ -39,16 +39,16 @@ set (SRC
ops/unary_elementwise_arithmetic.cpp
ops/unary_elementwise_arithmetic.cpp
ops/unary_elementwise_builtin.cpp
ops/unary_elementwise_builtin.cpp
pass/assign_tensors.cpp
pass/assign_tensors.cpp
pass/c
all_pas
s.cpp
pass/c
ollect_function
s.cpp
pass/dump_sorted.cpp
pass/dump_sorted.cpp
pass/liveness.cpp
pass/liveness.cpp
pass/manager.cpp
pass/manager.cpp
pass/manager_state.cpp
pass/memory_layout.cpp
pass/memory_layout.cpp
pass/memory_visualize.cpp
pass/memory_visualize.cpp
pass/pass.cpp
pass/pass.cpp
pass/propagate_types.cpp
pass/propagate_types.cpp
pass/topological_sort.cpp
pass/topological_sort.cpp
pass/tree_pass.cpp
pass/visualize_tree.cpp
pass/visualize_tree.cpp
runtime/call_frame.cpp
runtime/call_frame.cpp
runtime/external_function.cpp
runtime/external_function.cpp
...
@@ -58,7 +58,6 @@ set (SRC
...
@@ -58,7 +58,6 @@ set (SRC
types/element_type.cpp
types/element_type.cpp
types/type.cpp
types/type.cpp
util.cpp
util.cpp
visualize.cpp
)
)
# find_program (GRAPHVIZ dot)
# find_program (GRAPHVIZ dot)
...
...
src/ngraph/function.cpp
View file @
118e0679
...
@@ -15,21 +15,95 @@
...
@@ -15,21 +15,95 @@
#include <memory>
#include <memory>
#include "ngraph/function.hpp"
#include "ngraph/function.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
size_t
Function
::
m_next_instance_id
=
0
;
Function
::
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
Function
::
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
shared_ptr
<
ValueType
>&
result_type
,
const
std
::
shared_ptr
<
ValueType
>&
result_type
,
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
)
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
,
const
std
::
string
&
name
)
:
m_result
(
result
)
:
m_result
(
result
)
,
m_parameters
(
parameters
)
,
m_parameters
(
parameters
)
,
m_name
(
"Function"
)
,
m_name
(
name
)
,
m_result_type
(
result_type
)
,
m_result_type
(
result_type
)
,
m_ordered_ops_valid
(
false
)
,
m_instance_id
(
m_next_instance_id
++
)
{
{
size_t
i
=
0
;
size_t
i
=
0
;
for
(
auto
parameter
:
parameters
)
for
(
auto
parameter
:
parameters
)
{
{
parameter
->
assign_function
(
this
,
i
++
);
parameter
->
assign_function
(
this
,
i
++
);
}
}
traverse_nodes
(
result
,
[
&
](
Node
*
node
)
{
m_ops
.
push_back
(
node
);
});
}
void
Function
::
set_ordered_ops
(
const
std
::
list
<
Node
*>&
ordered_ops
)
{
m_ordered_ops
=
ordered_ops
;
m_ordered_ops_valid
=
true
;
}
std
::
list
<
Node
*>&
Function
::
get_ops
()
{
return
m_ops
;
}
const
std
::
list
<
Node
*>&
Function
::
get_ops
()
const
{
return
m_ops
;
}
std
::
list
<
Node
*>&
Function
::
get_ordered_ops
()
{
if
(
!
m_ordered_ops_valid
)
{
throw
ngraph_error
(
"Access to ordered ops invalid"
);
}
return
m_ordered_ops
;
}
const
std
::
list
<
Node
*>&
Function
::
get_ordered_ops
()
const
{
if
(
!
m_ordered_ops_valid
)
{
throw
ngraph_error
(
"Access to ordered ops invalid"
);
}
return
m_ordered_ops
;
}
std
::
string
Function
::
get_name
()
const
{
string
rc
;
if
(
m_name
.
empty
())
{
rc
=
"Function_"
+
to_string
(
m_instance_id
);
}
else
{
rc
=
m_name
;
}
return
rc
;
}
void
Function
::
set_name
(
const
string
&
name
)
{
if
(
m_name
.
empty
())
{
m_name
=
name
;
}
else
{
throw
ngraph_error
(
"Function name may be set exactly once"
);
}
}
std
::
ostream
&
ngraph
::
operator
<<
(
std
::
ostream
&
out
,
const
Function
&
f
)
{
out
<<
"Function("
<<
f
.
get_name
()
<<
")"
;
return
out
;
}
}
src/ngraph/function.hpp
View file @
118e0679
...
@@ -15,11 +15,13 @@
...
@@ -15,11 +15,13 @@
#pragma once
#pragma once
#include <initializer_list>
#include <initializer_list>
#include <list>
#include <memory>
#include <memory>
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/parameter.hpp"
...
@@ -34,7 +36,8 @@ namespace ngraph
...
@@ -34,7 +36,8 @@ namespace ngraph
public
:
public
:
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
shared_ptr
<
ValueType
>&
result_type
,
const
std
::
shared_ptr
<
ValueType
>&
result_type
,
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
);
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
,
const
std
::
string
&
name
=
""
);
std
::
shared_ptr
<
Node
>
get_result
()
{
return
m_result
;
}
std
::
shared_ptr
<
Node
>
get_result
()
{
return
m_result
;
}
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>
get_parameters
()
const
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>
get_parameters
()
const
...
@@ -42,11 +45,31 @@ namespace ngraph
...
@@ -42,11 +45,31 @@ namespace ngraph
return
m_parameters
;
return
m_parameters
;
}
}
const
std
::
shared_ptr
<
ValueType
>
get_result_type
()
const
{
return
m_result_type
;
}
const
std
::
shared_ptr
<
ValueType
>
get_result_type
()
const
{
return
m_result_type
;
}
std
::
string
get_name
()
const
{
return
m_name
;
}
std
::
string
get_name
()
const
;
void
set_name
(
const
std
::
string
&
name
);
std
::
list
<
Node
*>&
get_ops
();
const
std
::
list
<
Node
*>&
get_ops
()
const
;
std
::
list
<
Node
*>&
get_ordered_ops
();
const
std
::
list
<
Node
*>&
get_ordered_ops
()
const
;
void
set_ordered_ops
(
const
std
::
list
<
Node
*>&
);
void
set_ordered_ops_valid
()
{
m_ordered_ops_valid
=
true
;
}
void
clear_ordered_ops_valid
()
{
m_ordered_ops_valid
=
false
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Function
&
);
protected
:
protected
:
std
::
shared_ptr
<
Node
>
m_result
;
std
::
shared_ptr
<
Node
>
m_result
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
op
::
Parameter
>>
m_parameters
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
op
::
Parameter
>>
m_parameters
;
std
::
string
m_name
;
std
::
string
m_name
;
std
::
shared_ptr
<
ValueType
>
m_result_type
;
std
::
shared_ptr
<
ValueType
>
m_result_type
;
bool
m_ordered_ops_valid
;
std
::
list
<
Node
*>
m_ordered_ops
;
std
::
list
<
Node
*>
m_ops
;
private
:
Function
(
const
Function
&
)
=
delete
;
Function
(
const
Function
&&
)
=
delete
;
static
size_t
m_next_instance_id
;
size_t
m_instance_id
;
};
};
}
}
src/ngraph/node.cpp
View file @
118e0679
...
@@ -104,6 +104,32 @@ std::string Node::get_node_id() const
...
@@ -104,6 +104,32 @@ std::string Node::get_node_id() const
return
ss
.
str
();
return
ss
.
str
();
}
}
std
::
string
Node
::
get_name
()
const
{
string
rc
;
if
(
m_name
.
empty
())
{
rc
=
description
()
+
"_"
+
to_string
(
m_instance_id
);
}
else
{
rc
=
m_name
;
}
return
rc
;
}
void
Node
::
set_name
(
const
string
&
name
)
{
if
(
m_name
.
empty
())
{
m_name
=
name
;
}
else
{
throw
ngraph_error
(
"Node name may be set exactly once"
);
}
}
namespace
ngraph
namespace
ngraph
{
{
ostream
&
operator
<<
(
ostream
&
out
,
const
Node
&
node
)
ostream
&
operator
<<
(
ostream
&
out
,
const
Node
&
node
)
...
...
src/ngraph/node.hpp
View file @
118e0679
...
@@ -55,6 +55,8 @@ namespace ngraph
...
@@ -55,6 +55,8 @@ namespace ngraph
public
:
public
:
/// The class name, must not contain spaces
/// The class name, must not contain spaces
virtual
std
::
string
description
()
const
=
0
;
virtual
std
::
string
description
()
const
=
0
;
std
::
string
get_name
()
const
;
void
set_name
(
const
std
::
string
&
name
);
/// Propagate types and check arguments for consistency
/// Propagate types and check arguments for consistency
virtual
void
propagate_types
()
=
0
;
virtual
void
propagate_types
()
=
0
;
...
...
src/ngraph/pass/assign_tensors.cpp
View file @
118e0679
...
@@ -25,15 +25,15 @@
...
@@ -25,15 +25,15 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
bool
pass
::
AssignTensors
::
run_on_call_
list
(
std
::
list
<
Node
*>&
node_list
)
bool
pass
::
AssignTensors
::
run_on_call_
graph
(
list
<
Node
*>&
nodes
)
{
{
for
(
Node
*
node
:
node
_list
)
for
(
Node
*
node
:
node
s
)
{
{
try
try
{
{
// We need to set the nodes is_output state prior to call assign_tensors
// We need to set the nodes is_output state prior to call assign_tensors
// so that the output state can be passes to the constructed tensors.
// so that the output state can be passes to the constructed tensors.
if
(
node
==
get_state
().
get_function
(
)
->
get_result
().
get
())
if
(
node
==
get_state
().
get_function
s
().
at
(
0
)
->
get_result
().
get
())
{
{
node
->
set_is_output
();
node
->
set_is_output
();
}
}
...
@@ -50,21 +50,3 @@ bool pass::AssignTensors::run_on_call_list(std::list<Node*>& node_list)
...
@@ -50,21 +50,3 @@ bool pass::AssignTensors::run_on_call_list(std::list<Node*>& node_list)
}
}
return
false
;
return
false
;
}
}
void
pass
::
AssignTensors
::
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
registered_passes
)
const
{
bool
found_propagate_types
=
false
;
for
(
auto
pass
:
registered_passes
)
{
if
(
dynamic_pointer_cast
<
PropagateTypes
>
(
pass
))
{
found_propagate_types
=
true
;
}
}
if
(
!
found_propagate_types
)
{
throw
runtime_error
(
"Dependency 'PropagateTypes' not found for pass 'AssignTensors'"
);
}
}
src/ngraph/pass/assign_tensors.hpp
View file @
118e0679
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
#pragma once
#pragma once
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -25,12 +25,10 @@ namespace ngraph
...
@@ -25,12 +25,10 @@ namespace ngraph
class
Node
;
class
Node
;
}
}
class
ngraph
::
pass
::
AssignTensors
:
public
Call
Base
class
ngraph
::
pass
::
AssignTensors
:
public
Call
GraphPass
{
{
public
:
public
:
virtual
bool
run_on_call_list
(
std
::
list
<
Node
*>&
)
override
;
virtual
bool
run_on_call_graph
(
std
::
list
<
Node
*>&
nodes
)
override
;
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
)
const
override
;
private
:
private
:
};
};
src/ngraph/
visualize.h
pp
→
src/ngraph/
pass/collect_functions.c
pp
View file @
118e0679
...
@@ -12,34 +12,39 @@
...
@@ -12,34 +12,39 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/pass/collect_functions.hpp"
#include "ngraph/function.hpp"
#include <functional>
#include "ngraph/log.hpp"
#include <memory>
#include "ngraph/node.hpp"
#include <set>
#include "ngraph/ops/function_call.hpp"
#include <sstream>
#include "ngraph/ops/op.hpp"
#include "ngraph/util.hpp"
namespace
ngraph
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
::
pass
;
bool
CollectFunctions
::
run_on_function
(
ngraph
::
Function
*
func
)
{
{
class
Visualize
;
set
<
Function
*>
functions
;
class
Node
;
deque
<
Function
*>
stack
;
using
node_ptr
=
std
::
shared_ptr
<
Node
>
;
stack
.
push_back
(
func
);
while
(
stack
.
empty
()
==
false
)
{
Function
*
f
=
stack
.
front
();
stack
.
pop_front
();
functions
.
insert
(
f
);
traverse_nodes
(
f
->
get_result
(),
[
&
](
Node
*
node
)
{
op
::
FunctionCall
*
fc
=
dynamic_cast
<
op
::
FunctionCall
*>
(
node
);
if
(
fc
)
{
stack
.
push_back
(
fc
->
get_function
().
get
());
}
});
}
get_state
().
set_functions
(
functions
);
return
false
;
}
}
class
ngraph
::
Visualize
{
public
:
Visualize
(
const
std
::
string
&
name
=
"ngraph"
);
void
add
(
node_ptr
);
void
save_dot
(
const
std
::
string
&
path
)
const
;
private
:
std
::
string
add_attributes
(
const
Node
*
node
);
std
::
string
get_attributes
(
const
Node
*
node
);
std
::
stringstream
m_ss
;
std
::
string
m_name
;
std
::
set
<
const
Node
*>
m_nodes_with_attributes
;
};
src/ngraph/pass/c
all_pass.c
pp
→
src/ngraph/pass/c
ollect_functions.h
pp
View file @
118e0679
...
@@ -12,4 +12,22 @@
...
@@ -12,4 +12,22 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#include "ngraph/pass/call_pass.hpp"
#pragma once
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
namespace
pass
{
class
CollectFunctions
;
}
}
class
ngraph
::
pass
::
CollectFunctions
:
public
FunctionPass
{
public
:
bool
run_on_function
(
ngraph
::
Function
*
)
override
;
private
:
};
src/ngraph/pass/dump_sorted.cpp
View file @
118e0679
...
@@ -27,14 +27,16 @@ pass::DumpSorted::DumpSorted(const string& output_file)
...
@@ -27,14 +27,16 @@ pass::DumpSorted::DumpSorted(const string& output_file)
{
{
}
}
bool
pass
::
DumpSorted
::
run_on_
call_list
(
list
<
Node
*>&
node
s
)
bool
pass
::
DumpSorted
::
run_on_
module
(
vector
<
Function
*>&
function
s
)
{
{
ofstream
out
{
m_output_file
};
ofstream
out
{
m_output_file
};
if
(
out
)
if
(
out
)
{
{
for
(
const
Node
*
node
:
node
s
)
for
(
Function
*
f
:
function
s
)
{
{
out
<<
node
->
get_node_id
()
<<
"("
;
for
(
const
Node
*
node
:
f
->
get_ordered_ops
())
{
out
<<
node
->
get_name
()
<<
"("
;
vector
<
string
>
inputs
;
vector
<
string
>
inputs
;
for
(
const
Input
&
input
:
node
->
get_inputs
())
for
(
const
Input
&
input
:
node
->
get_inputs
())
{
{
...
@@ -65,6 +67,7 @@ bool pass::DumpSorted::run_on_call_list(list<Node*>& nodes)
...
@@ -65,6 +67,7 @@ bool pass::DumpSorted::run_on_call_list(list<Node*>& nodes)
}
}
}
}
}
}
}
return
false
;
return
false
;
}
}
src/ngraph/pass/dump_sorted.hpp
View file @
118e0679
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
#include <string>
#include <string>
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -27,12 +27,12 @@ namespace ngraph
...
@@ -27,12 +27,12 @@ namespace ngraph
class
Node
;
class
Node
;
}
}
class
ngraph
::
pass
::
DumpSorted
:
public
CallBase
class
ngraph
::
pass
::
DumpSorted
:
public
ModulePass
{
{
public
:
public
:
DumpSorted
(
const
std
::
string
&
output_file
);
DumpSorted
(
const
std
::
string
&
output_file
);
virtual
bool
run_on_
call_list
(
std
::
list
<
Node
*>&
)
override
;
virtual
bool
run_on_
module
(
std
::
vector
<
Function
*>&
)
override
;
private
:
private
:
const
std
::
string
m_output_file
;
const
std
::
string
m_output_file
;
...
...
src/ngraph/pass/liveness.cpp
View file @
118e0679
...
@@ -27,7 +27,7 @@ using namespace std;
...
@@ -27,7 +27,7 @@ using namespace std;
using
namespace
ngraph
;
using
namespace
ngraph
;
using
namespace
ngraph
::
descriptor
;
using
namespace
ngraph
::
descriptor
;
bool
pass
::
Liveness
::
run_on_call_
list
(
list
<
Node
*>&
ops
)
bool
pass
::
Liveness
::
run_on_call_
graph
(
list
<
Node
*>&
ops
)
{
{
unordered_set
<
Tensor
*>
currently_live
;
unordered_set
<
Tensor
*>
currently_live
;
...
@@ -123,24 +123,6 @@ bool pass::Liveness::run_on_call_list(list<Node*>& ops)
...
@@ -123,24 +123,6 @@ bool pass::Liveness::run_on_call_list(list<Node*>& ops)
return
false
;
return
false
;
}
}
void
pass
::
Liveness
::
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
registered_passes
)
const
{
bool
found_propagate_types
=
false
;
for
(
auto
pass
:
registered_passes
)
{
if
(
dynamic_pointer_cast
<
AssignTensors
>
(
pass
))
{
found_propagate_types
=
true
;
}
}
if
(
!
found_propagate_types
)
{
throw
runtime_error
(
"Dependency 'PropagateTypes' not found for pass 'AssignTensors'"
);
}
}
bool
pass
::
Liveness
::
is_temporary
(
const
Tensor
&
tensor
)
bool
pass
::
Liveness
::
is_temporary
(
const
Tensor
&
tensor
)
{
{
return
tensor
.
is_persistent
()
==
false
&&
tensor
.
is_input
()
==
false
&&
return
tensor
.
is_persistent
()
==
false
&&
tensor
.
is_input
()
==
false
&&
...
...
src/ngraph/pass/liveness.hpp
View file @
118e0679
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
#pragma once
#pragma once
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -26,12 +26,10 @@ namespace ngraph
...
@@ -26,12 +26,10 @@ namespace ngraph
class
Node
;
class
Node
;
}
}
class
ngraph
::
pass
::
Liveness
:
public
Call
Base
class
ngraph
::
pass
::
Liveness
:
public
Call
GraphPass
{
{
public
:
public
:
virtual
bool
run_on_call_list
(
std
::
list
<
Node
*>&
)
override
;
virtual
bool
run_on_call_graph
(
std
::
list
<
Node
*>&
)
override
;
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
)
const
override
;
private
:
private
:
bool
is_temporary
(
const
descriptor
::
Tensor
&
);
bool
is_temporary
(
const
descriptor
::
Tensor
&
);
...
...
src/ngraph/pass/manager.cpp
View file @
118e0679
...
@@ -19,40 +19,11 @@
...
@@ -19,40 +19,11 @@
#include "ngraph/log.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp"
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
Function
*
ngraph
::
pass
::
ManagerState
::
get_function
()
{
return
m_function
;
}
void
ngraph
::
pass
::
ManagerState
::
set_function
(
Function
*
func
)
{
m_function
=
func
;
}
size_t
ngraph
::
pass
::
ManagerState
::
get_temporary_pool_size
()
{
return
m_temporary_pool_size
;
}
void
ngraph
::
pass
::
ManagerState
::
set_temporary_pool_size
(
size_t
size
)
{
m_temporary_pool_size
=
size
;
}
std
::
list
<
Node
*>&
ngraph
::
pass
::
ManagerState
::
get_call_graph
()
{
return
m_call_graph
;
}
const
std
::
list
<
Node
*>&
ngraph
::
pass
::
ManagerState
::
get_call_graph
()
const
{
return
m_call_graph
;
}
ngraph
::
pass
::
Manager
::
Manager
()
ngraph
::
pass
::
Manager
::
Manager
()
{
{
}
}
...
@@ -65,26 +36,6 @@ void ngraph::pass::Manager::initialize_default_passes()
...
@@ -65,26 +36,6 @@ void ngraph::pass::Manager::initialize_default_passes()
{
{
}
}
void
ngraph
::
pass
::
Manager
::
register_pass_ptr
(
std
::
shared_ptr
<
TreeBase
>
p
)
{
if
(
p
==
nullptr
)
{
throw
invalid_argument
(
"null pass registered"
);
}
p
->
check_dependencies
(
m_tree_passes
);
m_tree_passes
.
push_back
(
p
);
}
void
ngraph
::
pass
::
Manager
::
register_pass_ptr
(
std
::
shared_ptr
<
CallBase
>
p
)
{
if
(
p
==
nullptr
)
{
throw
invalid_argument
(
"null pass registered"
);
}
p
->
check_dependencies
(
m_call_passes
);
m_call_passes
.
push_back
(
p
);
}
void
ngraph
::
pass
::
Manager
::
run_passes
(
shared_ptr
<
Function
>
func
)
void
ngraph
::
pass
::
Manager
::
run_passes
(
shared_ptr
<
Function
>
func
)
{
{
run_passes
(
func
.
get
());
run_passes
(
func
.
get
());
...
@@ -92,23 +43,79 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
...
@@ -92,23 +43,79 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
void
ngraph
::
pass
::
Manager
::
run_passes
(
Function
*
func
)
void
ngraph
::
pass
::
Manager
::
run_passes
(
Function
*
func
)
{
{
m_state
.
set_function
(
func
);
vector
<
Function
*>
fs
=
{
func
};
for
(
shared_ptr
<
TreeBase
>
p
:
m_tree_passes
)
get_state
().
set_functions
(
fs
);
for
(
shared_ptr
<
PassBase
>
pass
:
m_pass_list
)
{
pass
->
set_state
(
get_state
());
auto
module_pass
=
dynamic_pointer_cast
<
ModulePass
>
(
pass
);
auto
function_pass
=
dynamic_pointer_cast
<
FunctionPass
>
(
pass
);
auto
node_pass
=
dynamic_pointer_cast
<
NodePass
>
(
pass
);
auto
call_graph_pass
=
dynamic_pointer_cast
<
CallGraphPass
>
(
pass
);
if
(
module_pass
)
{
{
p
->
set_state
(
get_state
());
module_pass
->
run_on_module
(
fs
);
p
->
run_on_tree
(
func
->
get_result
());
}
}
else
if
(
function_pass
)
for
(
shared_ptr
<
CallBase
>&
p
:
m_call_passes
)
{
for
(
Function
*
f
:
fs
)
{
{
p
->
set_state
(
get_state
());
function_pass
->
run_on_function
(
f
);
p
->
run_on_call_list
(
get_state
().
get_call_graph
());
}
}
}
}
else
if
(
node_pass
)
const
std
::
list
<
ngraph
::
Node
*>&
ngraph
::
pass
::
Manager
::
get_call_graph
()
const
{
{
for
(
Function
*
f
:
fs
)
return
m_state
.
get_call_graph
();
{
for
(
Node
*
n
:
f
->
get_ops
())
{
node_pass
->
run_on_node
(
n
);
}
}
}
else
if
(
call_graph_pass
)
{
for
(
Function
*
f
:
fs
)
{
call_graph_pass
->
run_on_call_graph
(
f
->
get_ordered_ops
());
}
}
}
// for (shared_ptr<ModulePass>& p : m_module_passes)
// {
// p->set_state(get_state());
// p->run_on_module(fs);
// }
// for (Function* f : fs)
// {
// for (shared_ptr<FunctionPass> p : m_function_passes)
// {
// p->set_state(get_state());
// p->run_on_function(f);
// }
// }
// for (Function* f : fs)
// {
// NGRAPH_INFO;
// for (shared_ptr<NodePass> p : m_node_passes)
// {
// for (Node* node : f->get_ops())
// {
// NGRAPH_INFO;
// p->set_state(get_state());
// p->run_on_node(node);
// }
// }
// }
// for (shared_ptr<CallGraphPass>& p : m_call_graph_passes)
// {
// p->set_state(get_state());
// p->run_on_call_graph(func->get_ordered_ops());
// }
}
}
ngraph
::
pass
::
ManagerState
&
ngraph
::
pass
::
Manager
::
get_state
()
ngraph
::
pass
::
ManagerState
&
ngraph
::
pass
::
Manager
::
get_state
()
...
...
src/ngraph/pass/manager.hpp
View file @
118e0679
...
@@ -18,8 +18,8 @@
...
@@ -18,8 +18,8 @@
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "ngraph/pass/
call_pass
.hpp"
#include "ngraph/pass/
manager_state
.hpp"
#include "ngraph/pass/
tree_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -33,24 +33,6 @@ namespace ngraph
...
@@ -33,24 +33,6 @@ namespace ngraph
class
Function
;
class
Function
;
}
}
class
ngraph
::
pass
::
ManagerState
{
public
:
Function
*
get_function
();
void
set_function
(
Function
*
);
size_t
get_temporary_pool_size
();
void
set_temporary_pool_size
(
size_t
);
std
::
list
<
Node
*>&
get_call_graph
();
const
std
::
list
<
Node
*>&
get_call_graph
()
const
;
private
:
Function
*
m_function
=
nullptr
;
size_t
m_temporary_pool_size
=
0
;
std
::
list
<
Node
*>
m_call_graph
;
};
class
ngraph
::
pass
::
Manager
class
ngraph
::
pass
::
Manager
{
{
public
:
public
:
...
@@ -62,29 +44,18 @@ public:
...
@@ -62,29 +44,18 @@ public:
template
<
typename
T
,
class
...
Args
>
template
<
typename
T
,
class
...
Args
>
void
register_pass
(
Args
...
args
)
void
register_pass
(
Args
...
args
)
{
{
static_assert
(
std
::
is_base_of
<
pass
::
Base
,
T
>::
value
,
"pass not derived from pass base"
);
static_assert
(
std
::
is_base_of
<
pass
::
PassBase
,
T
>::
value
,
"pass not derived from pass base"
);
if
(
std
::
is_base_of
<
TreeBase
,
T
>::
value
)
auto
pass
=
std
::
make_shared
<
T
>
(
args
...);
{
auto
pass_base
=
std
::
static_pointer_cast
<
PassBase
>
(
pass
);
register_pass_ptr
(
std
::
make_shared
<
T
>
(
args
...));
m_pass_list
.
push_back
(
pass_base
);
}
else
if
(
std
::
is_base_of
<
CallBase
,
T
>::
value
)
{
register_pass_ptr
(
std
::
make_shared
<
T
>
(
args
...));
}
}
}
void
run_passes
(
Function
*
);
void
run_passes
(
Function
*
);
void
run_passes
(
std
::
shared_ptr
<
Function
>
);
void
run_passes
(
std
::
shared_ptr
<
Function
>
);
const
std
::
list
<
Node
*>&
get_call_graph
()
const
;
ManagerState
&
get_state
();
ManagerState
&
get_state
();
private
:
private
:
void
register_pass_ptr
(
std
::
shared_ptr
<
TreeBase
>
);
std
::
vector
<
std
::
shared_ptr
<
PassBase
>>
m_pass_list
;
void
register_pass_ptr
(
std
::
shared_ptr
<
CallBase
>
);
std
::
vector
<
std
::
shared_ptr
<
TreeBase
>>
m_tree_passes
;
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>
m_call_passes
;
ManagerState
m_state
;
ManagerState
m_state
;
};
};
src/ngraph/pass/
tree_pass.h
pp
→
src/ngraph/pass/
manager_state.c
pp
View file @
118e0679
...
@@ -12,31 +12,28 @@
...
@@ -12,31 +12,28 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
#pragma once
#include <iostream>
#include <list>
#include <memory>
#include <memory>
#include <vector>
#include "ngraph/pass/pass.hpp"
#include "ngraph/function.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/manager_state.hpp"
namespace
ngraph
using
namespace
std
;
{
using
namespace
ngraph
;
namespace
pass
{
class
TreeBase
;
}
class
Node
;
vector
<
Function
*>&
ngraph
::
pass
::
ManagerState
::
get_functions
()
{
return
m_function_list
;
}
}
class
ngraph
::
pass
::
TreeBase
:
public
Base
size_t
ngraph
::
pass
::
ManagerState
::
get_temporary_pool_size
()
{
{
public
:
return
m_temporary_pool_size
;
virtual
~
TreeBase
()
{}
}
// return true if changes were made to the tree
virtual
bool
run_on_tree
(
std
::
shared_ptr
<
Node
>
)
=
0
;
// derived class throws exception if its dependencies have not been met
void
ngraph
::
pass
::
ManagerState
::
set_temporary_pool_size
(
size_t
size
)
virtual
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
TreeBase
>>&
)
const
{}
{
};
m_temporary_pool_size
=
size
;
}
src/ngraph/pass/
call_pass
.hpp
→
src/ngraph/pass/
manager_state
.hpp
View file @
118e0679
...
@@ -14,29 +14,36 @@
...
@@ -14,29 +14,36 @@
#pragma once
#pragma once
#include <list>
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "ngraph/pass/pass.hpp"
namespace
ngraph
namespace
ngraph
{
{
namespace
pass
namespace
pass
{
{
class
CallBas
e
;
class
ManagerStat
e
;
}
}
class
Node
;
class
Node
;
class
Function
;
}
}
class
ngraph
::
pass
::
CallBase
:
public
Bas
e
class
ngraph
::
pass
::
ManagerStat
e
{
{
public
:
public
:
virtual
~
CallBase
()
{}
std
::
vector
<
Function
*>&
get_functions
();
virtual
bool
run_on_call_list
(
std
::
list
<
Node
*>&
)
=
0
;
template
<
typename
T
>
void
set_functions
(
const
T
&
collection
)
{
m_function_list
.
clear
();
m_function_list
.
insert
(
m_function_list
.
begin
(),
collection
.
begin
(),
collection
.
end
());
}
size_t
get_temporary_pool_size
();
void
set_temporary_pool_size
(
size_t
);
// derived class throws exception if its dependencies have not been met
virtual
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
)
const
{}
private
:
private
:
size_t
m_temporary_pool_size
=
0
;
std
::
vector
<
Function
*>
m_function_list
;
};
};
src/ngraph/pass/memory_layout.cpp
View file @
118e0679
...
@@ -27,7 +27,7 @@ using namespace std;
...
@@ -27,7 +27,7 @@ using namespace std;
using
namespace
ngraph
;
using
namespace
ngraph
;
using
namespace
ngraph
::
descriptor
;
using
namespace
ngraph
::
descriptor
;
bool
pass
::
MemoryLayout
::
run_on_call_
list
(
std
::
list
<
Node
*>&
node_list
)
bool
pass
::
MemoryLayout
::
run_on_call_
graph
(
std
::
list
<
Node
*>&
node_list
)
{
{
MemoryManager
mm
;
MemoryManager
mm
;
for
(
const
Node
*
node
:
node_list
)
for
(
const
Node
*
node
:
node_list
)
...
@@ -47,24 +47,6 @@ bool pass::MemoryLayout::run_on_call_list(std::list<Node*>& node_list)
...
@@ -47,24 +47,6 @@ bool pass::MemoryLayout::run_on_call_list(std::list<Node*>& node_list)
return
false
;
return
false
;
}
}
void
pass
::
MemoryLayout
::
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
registered_passes
)
const
{
bool
found_propagate_types
=
false
;
for
(
auto
pass
:
registered_passes
)
{
if
(
dynamic_pointer_cast
<
Liveness
>
(
pass
))
{
found_propagate_types
=
true
;
}
}
if
(
!
found_propagate_types
)
{
throw
runtime_error
(
"Dependency 'PropagateTypes' not found for pass 'AssignTensors'"
);
}
}
pass
::
MemoryManager
::
node
::
node
(
size_t
size
,
block_state
state
)
pass
::
MemoryManager
::
node
::
node
(
size_t
size
,
block_state
state
)
:
m_size
{
size
}
:
m_size
{
size
}
,
m_state
{
state
}
,
m_state
{
state
}
...
...
src/ngraph/pass/memory_layout.hpp
View file @
118e0679
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
#include <list>
#include <list>
#include <sstream>
#include <sstream>
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -31,12 +31,10 @@ namespace ngraph
...
@@ -31,12 +31,10 @@ namespace ngraph
class
Node
;
class
Node
;
}
}
class
ngraph
::
pass
::
MemoryLayout
:
public
Call
Base
class
ngraph
::
pass
::
MemoryLayout
:
public
Call
GraphPass
{
{
public
:
public
:
virtual
bool
run_on_call_list
(
std
::
list
<
Node
*>&
)
override
;
virtual
bool
run_on_call_graph
(
std
::
list
<
Node
*>&
)
override
;
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
)
const
override
;
private
:
private
:
};
};
...
...
src/ngraph/pass/memory_visualize.cpp
View file @
118e0679
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include "memory_visualize.hpp"
#include "memory_visualize.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/util.hpp"
#include "ngraph/util.hpp"
...
@@ -31,11 +32,13 @@ pass::MemoryVisualize::MemoryVisualize(const string& filename)
...
@@ -31,11 +32,13 @@ pass::MemoryVisualize::MemoryVisualize(const string& filename)
{
{
}
}
bool
pass
::
MemoryVisualize
::
run_on_
call_list
(
list
<
Node
*>&
_node
s
)
bool
pass
::
MemoryVisualize
::
run_on_
module
(
vector
<
Function
*>&
function
s
)
{
{
const
list
<
Node
*>
nodes
=
_nodes
;
ofstream
file
(
m_filename
);
ofstream
file
(
m_filename
);
{
{
for
(
const
Function
*
f
:
functions
)
{
const
list
<
Node
*>
nodes
=
f
->
get_ordered_ops
();
file
<<
"<!DOCTYPE html>
\n
<html>
\n
"
;
file
<<
"<!DOCTYPE html>
\n
<html>
\n
"
;
file
<<
"<head>
\n
"
;
file
<<
"<head>
\n
"
;
file
<<
" <style>
\n
"
;
file
<<
" <style>
\n
"
;
...
@@ -89,13 +92,10 @@ bool pass::MemoryVisualize::run_on_call_list(list<Node*>& _nodes)
...
@@ -89,13 +92,10 @@ bool pass::MemoryVisualize::run_on_call_list(list<Node*>& _nodes)
// file << "<hr>\n";
// file << "<hr>\n";
file
<<
"</body>
\n
</html>
\n
"
;
file
<<
"</body>
\n
</html>
\n
"
;
}
}
}
return
false
;
return
false
;
}
}
void
pass
::
MemoryVisualize
::
check_dependencies
(
const
vector
<
shared_ptr
<
CallBase
>>&
deps
)
const
{
}
const
Node
*
pass
::
MemoryVisualize
::
find_largest_op
(
const
list
<
Node
*>&
nodes
)
const
Node
*
pass
::
MemoryVisualize
::
find_largest_op
(
const
list
<
Node
*>&
nodes
)
{
{
const
Node
*
largest_op
=
nullptr
;
const
Node
*
largest_op
=
nullptr
;
...
@@ -207,7 +207,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
...
@@ -207,7 +207,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
size_t
x2
=
((
usage
/
memory_footprint
)
*
scale
)
+
offset
;
size_t
x2
=
((
usage
/
memory_footprint
)
*
scale
)
+
offset
;
file
<<
"<text x=
\"
"
<<
0
<<
"
\"
y=
\"
"
<<
y
+
text_offset
<<
"
\"
fill=
\"
"
file
<<
"<text x=
\"
"
<<
0
<<
"
\"
y=
\"
"
<<
y
+
text_offset
<<
"
\"
fill=
\"
"
<<
"black"
<<
"black"
<<
"
\"
>"
<<
node
->
get_n
ode_id
()
<<
"</text>
\n
"
;
<<
"
\"
>"
<<
node
->
get_n
ame
()
<<
"</text>
\n
"
;
file
<<
"<line x1=
\"
"
<<
x1
<<
"
\"
y1=
\"
"
<<
y
<<
"
\"
x2=
\"
"
<<
x2
<<
"
\"
y2=
\"
"
<<
y
file
<<
"<line x1=
\"
"
<<
x1
<<
"
\"
y1=
\"
"
<<
y
<<
"
\"
x2=
\"
"
<<
x2
<<
"
\"
y2=
\"
"
<<
y
<<
"
\"
"
;
<<
"
\"
"
;
file
<<
" style=
\"
stroke:forestgreen;stroke-width:"
<<
stroke_width
<<
"
\"
/>
\n
"
;
file
<<
" style=
\"
stroke:forestgreen;stroke-width:"
<<
stroke_width
<<
"
\"
/>
\n
"
;
...
@@ -231,7 +231,7 @@ void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<Node*>&
...
@@ -231,7 +231,7 @@ void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<Node*>&
{
{
int
weight
=
compute_op_weight
(
exop
);
int
weight
=
compute_op_weight
(
exop
);
file
<<
" <tr>"
;
file
<<
" <tr>"
;
file
<<
"<td>"
<<
exop
->
get_n
ode_id
()
<<
"</td>"
;
file
<<
"<td>"
<<
exop
->
get_n
ame
()
<<
"</td>"
;
file
<<
"<td align=
\"
right
\"
>"
<<
weight
<<
"</td>"
;
file
<<
"<td align=
\"
right
\"
>"
<<
weight
<<
"</td>"
;
file
<<
"</tr>
\n
"
;
file
<<
"</tr>
\n
"
;
}
}
...
...
src/ngraph/pass/memory_visualize.hpp
View file @
118e0679
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
#include <limits>
#include <limits>
#include <list>
#include <list>
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -29,13 +29,11 @@ namespace ngraph
...
@@ -29,13 +29,11 @@ namespace ngraph
class
Node
;
class
Node
;
}
}
class
ngraph
::
pass
::
MemoryVisualize
:
public
CallBase
class
ngraph
::
pass
::
MemoryVisualize
:
public
ModulePass
{
{
public
:
public
:
MemoryVisualize
(
const
std
::
string
&
filename
);
MemoryVisualize
(
const
std
::
string
&
filename
);
virtual
bool
run_on_call_list
(
std
::
list
<
Node
*>&
)
override
;
virtual
bool
run_on_module
(
std
::
vector
<
Function
*>&
)
override
;
void
check_dependencies
(
const
std
::
vector
<
std
::
shared_ptr
<
CallBase
>>&
)
const
override
;
private
:
private
:
const
Node
*
find_largest_op
(
const
std
::
list
<
Node
*>&
nodes
);
const
Node
*
find_largest_op
(
const
std
::
list
<
Node
*>&
nodes
);
...
...
src/ngraph/pass/pass.cpp
View file @
118e0679
...
@@ -15,12 +15,12 @@
...
@@ -15,12 +15,12 @@
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/manager.hpp"
ngraph
::
pass
::
ManagerState
&
ngraph
::
pass
::
Base
::
get_state
()
ngraph
::
pass
::
ManagerState
&
ngraph
::
pass
::
Pass
Base
::
get_state
()
{
{
return
*
m_state
;
return
*
m_state
;
}
}
void
ngraph
::
pass
::
Base
::
set_state
(
ManagerState
&
state
)
void
ngraph
::
pass
::
Pass
Base
::
set_state
(
ManagerState
&
state
)
{
{
m_state
=
&
state
;
m_state
=
&
state
;
}
}
src/ngraph/pass/pass.hpp
View file @
118e0679
...
@@ -14,21 +14,33 @@
...
@@ -14,21 +14,33 @@
#pragma once
#pragma once
#include <list>
#include <memory>
#include <vector>
#include "ngraph/node.hpp"
namespace
ngraph
namespace
ngraph
{
{
namespace
pass
namespace
pass
{
{
class
Base
;
class
PassBase
;
class
ModulePass
;
class
FunctionPass
;
class
NodePass
;
class
CallGraphPass
;
class
Manager
;
class
Manager
;
class
ManagerState
;
class
ManagerState
;
}
}
class
Function
;
}
}
class
ngraph
::
pass
::
Base
class
ngraph
::
pass
::
Pass
Base
{
{
friend
class
Manager
;
friend
class
Manager
;
public
:
public
:
virtual
~
PassBase
()
{}
protected
:
protected
:
ManagerState
&
get_state
();
ManagerState
&
get_state
();
void
set_state
(
ManagerState
&
);
void
set_state
(
ManagerState
&
);
...
@@ -36,3 +48,31 @@ protected:
...
@@ -36,3 +48,31 @@ protected:
private
:
private
:
ManagerState
*
m_state
;
ManagerState
*
m_state
;
};
};
class
ngraph
::
pass
::
ModulePass
:
public
PassBase
{
public
:
virtual
~
ModulePass
()
{}
virtual
bool
run_on_module
(
std
::
vector
<
ngraph
::
Function
*>&
)
=
0
;
};
class
ngraph
::
pass
::
FunctionPass
:
public
PassBase
{
public
:
virtual
~
FunctionPass
()
{}
virtual
bool
run_on_function
(
ngraph
::
Function
*
)
=
0
;
};
class
ngraph
::
pass
::
NodePass
:
public
PassBase
{
public
:
virtual
~
NodePass
()
{}
virtual
bool
run_on_node
(
ngraph
::
Node
*
)
=
0
;
};
class
ngraph
::
pass
::
CallGraphPass
:
public
PassBase
{
public
:
virtual
~
CallGraphPass
()
{}
virtual
bool
run_on_call_graph
(
std
::
list
<
ngraph
::
Node
*>&
)
=
0
;
};
src/ngraph/pass/propagate_types.cpp
View file @
118e0679
...
@@ -20,9 +20,9 @@
...
@@ -20,9 +20,9 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
bool
pass
::
PropagateTypes
::
run_on_call_
list
(
std
::
list
<
Node
*>&
node_list
)
bool
pass
::
PropagateTypes
::
run_on_call_
graph
(
list
<
Node
*>&
nodes
)
{
{
for
(
Node
*
node
:
node
_list
)
for
(
Node
*
node
:
node
s
)
{
{
try
try
{
{
...
...
src/ngraph/pass/propagate_types.hpp
View file @
118e0679
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
#pragma once
#pragma once
#include "ngraph/pass/
call_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -25,10 +25,10 @@ namespace ngraph
...
@@ -25,10 +25,10 @@ namespace ngraph
class
Node
;
class
Node
;
}
}
class
ngraph
::
pass
::
PropagateTypes
:
public
Call
Base
class
ngraph
::
pass
::
PropagateTypes
:
public
Call
GraphPass
{
{
public
:
public
:
virtual
bool
run_on_call_
list
(
std
::
list
<
Node
*>&
)
override
;
virtual
bool
run_on_call_
graph
(
std
::
list
<
Node
*>&
)
override
;
private
:
private
:
};
};
src/ngraph/pass/topological_sort.cpp
View file @
118e0679
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include <deque>
#include <deque>
#include <unordered_map>
#include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/log.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/manager.hpp"
...
@@ -24,14 +25,13 @@
...
@@ -24,14 +25,13 @@
using
namespace
ngraph
;
using
namespace
ngraph
;
using
namespace
std
;
using
namespace
std
;
bool
ngraph
::
pass
::
TopologicalSort
::
run_on_
tree
(
std
::
shared_ptr
<
Node
>
p
)
bool
ngraph
::
pass
::
TopologicalSort
::
run_on_
function
(
ngraph
::
Function
*
func
)
{
{
list
<
Node
*>&
sorted_list
=
get_state
().
get_call_graph
();
list
<
Node
*>
result_list
;
sorted_list
.
clear
();
deque
<
Node
*>
independent_nodes
;
deque
<
Node
*>
independent_nodes
;
unordered_map
<
Node
*
,
size_t
>
node_depencency_count
;
unordered_map
<
Node
*
,
size_t
>
node_depencency_count
;
traverse_nodes
(
p
,
[
&
](
Node
*
node
)
{
traverse_nodes
(
func
->
get_result
()
,
[
&
](
Node
*
node
)
{
node_depencency_count
[
node
]
=
node
->
get_arguments
().
size
();
node_depencency_count
[
node
]
=
node
->
get_arguments
().
size
();
if
(
node
->
get_arguments
().
size
()
==
0
)
if
(
node
->
get_arguments
().
size
()
==
0
)
{
{
...
@@ -42,7 +42,7 @@ bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p)
...
@@ -42,7 +42,7 @@ bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p)
while
(
independent_nodes
.
size
()
>
0
)
while
(
independent_nodes
.
size
()
>
0
)
{
{
auto
independent_node
=
independent_nodes
.
front
();
auto
independent_node
=
independent_nodes
.
front
();
sorted
_list
.
push_back
(
independent_node
);
result
_list
.
push_back
(
independent_node
);
independent_nodes
.
pop_front
();
independent_nodes
.
pop_front
();
for
(
auto
user
:
independent_node
->
users
())
for
(
auto
user
:
independent_node
->
users
())
...
@@ -56,5 +56,7 @@ bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p)
...
@@ -56,5 +56,7 @@ bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p)
}
}
}
}
func
->
set_ordered_ops
(
result_list
);
return
false
;
return
false
;
}
}
src/ngraph/pass/topological_sort.hpp
View file @
118e0679
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
#include <list>
#include <list>
#include <memory>
#include <memory>
#include "ngraph/pass/
tree_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -28,9 +28,9 @@ namespace ngraph
...
@@ -28,9 +28,9 @@ namespace ngraph
class
Node
;
class
Node
;
}
}
class
ngraph
::
pass
::
TopologicalSort
:
public
TreeBase
class
ngraph
::
pass
::
TopologicalSort
:
public
FunctionPass
{
{
public
:
public
:
TopologicalSort
()
{}
TopologicalSort
()
{}
bool
run_on_
tree
(
std
::
shared_ptr
<
Node
>
)
override
;
bool
run_on_
function
(
ngraph
::
Function
*
)
override
;
};
};
src/ngraph/pass/tree_pass.cpp
deleted
100644 → 0
View file @
c2ff1508
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/pass/tree_pass.hpp"
src/ngraph/pass/visualize_tree.cpp
View file @
118e0679
...
@@ -14,25 +14,30 @@
...
@@ -14,25 +14,30 @@
#include <fstream>
#include <fstream>
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp"
#include "ngraph/util.hpp"
using
namespace
ngraph
;
using
namespace
ngraph
;
using
namespace
std
;
using
namespace
std
;
bool
pass
::
VisualizeTree
::
run_on_
tree
(
std
::
shared_ptr
<
Node
>
base_node
)
bool
pass
::
VisualizeTree
::
run_on_
module
(
vector
<
ngraph
::
Function
*>&
functions
)
{
{
for
(
Function
*
f
:
functions
)
{
// map<size_t, list<node_ptr>> dependent_nodes;
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes
(
base_node
,
[
&
](
Node
*
node
)
{
traverse_nodes
(
f
->
get_result
()
,
[
&
](
Node
*
node
)
{
for
(
auto
arg
:
node
->
get_arguments
())
for
(
auto
arg
:
node
->
get_arguments
())
{
{
m_ss
<<
add_attributes
(
arg
.
get
());
m_ss
<<
add_attributes
(
arg
.
get
());
m_ss
<<
add_attributes
(
node
);
m_ss
<<
add_attributes
(
node
);
m_ss
<<
" "
<<
arg
->
get_node_id
()
<<
" -> "
<<
node
->
get_node_id
();
m_ss
<<
" "
<<
arg
->
get_name
()
<<
" -> "
<<
node
->
get_name
();
m_ss
<<
";
\n
"
;
m_ss
<<
";
\n
"
;
}
}
});
});
}
render
();
render
();
...
@@ -60,11 +65,11 @@ std::string pass::VisualizeTree::get_attributes(const Node* node)
...
@@ -60,11 +65,11 @@ std::string pass::VisualizeTree::get_attributes(const Node* node)
stringstream
ss
;
stringstream
ss
;
if
(
node
->
is_parameter
())
if
(
node
->
is_parameter
())
{
{
ss
<<
" "
<<
node
->
get_n
ode_id
()
<<
" [shape=box color=blue]
\n
"
;
ss
<<
" "
<<
node
->
get_n
ame
()
<<
" [shape=box color=blue]
\n
"
;
}
}
else
else
{
{
ss
<<
" "
<<
node
->
get_n
ode_id
()
<<
" [shape=ellipse color=black]
\n
"
;
ss
<<
" "
<<
node
->
get_n
ame
()
<<
" [shape=ellipse color=black]
\n
"
;
}
}
return
ss
.
str
();
return
ss
.
str
();
}
}
...
...
src/ngraph/pass/visualize_tree.hpp
View file @
118e0679
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
#include <sstream>
#include <sstream>
#include <string>
#include <string>
#include "ngraph/pass/
tree_
pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -29,11 +29,11 @@ namespace ngraph
...
@@ -29,11 +29,11 @@ namespace ngraph
class
Node
;
class
Node
;
}
}
class
ngraph
::
pass
::
VisualizeTree
:
public
TreeBase
class
ngraph
::
pass
::
VisualizeTree
:
public
ModulePass
{
{
public
:
public
:
VisualizeTree
(
const
std
::
string
&
file_name
);
VisualizeTree
(
const
std
::
string
&
file_name
);
bool
run_on_
tree
(
std
::
shared_ptr
<
Node
>
)
override
;
bool
run_on_
module
(
std
::
vector
<
ngraph
::
Function
*>&
)
override
;
private
:
private
:
std
::
string
add_attributes
(
const
Node
*
node
);
std
::
string
add_attributes
(
const
Node
*
node
);
...
...
src/ngraph/runtime/external_function.cpp
View file @
118e0679
...
@@ -659,7 +659,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
...
@@ -659,7 +659,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Turn this into a pass
// Turn this into a pass
// Assign layouts
// Assign layouts
// For now, just make everyone row-major.
// For now, just make everyone row-major.
for
(
const
Node
*
node
:
pass_manager
.
get_call_graph
())
for
(
const
Node
*
node
:
m_function
->
get_ordered_ops
())
{
{
for
(
const
descriptor
::
Output
&
output
:
node
->
get_outputs
())
for
(
const
descriptor
::
Output
&
output
:
node
->
get_outputs
())
{
{
...
@@ -696,7 +696,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
...
@@ -696,7 +696,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
m_n_outputs
=
tensor_index
.
size
()
-
m_n_inputs
;
m_n_outputs
=
tensor_index
.
size
()
-
m_n_inputs
;
// All remaining tensor views
// All remaining tensor views
for
(
const
Node
*
node
:
pass_manager
.
get_call_graph
())
for
(
const
Node
*
node
:
m_function
->
get_ordered_ops
())
{
{
for
(
const
descriptor
::
Output
&
output
:
node
->
get_outputs
())
for
(
const
descriptor
::
Output
&
output
:
node
->
get_outputs
())
{
{
...
@@ -712,7 +712,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
...
@@ -712,7 +712,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Now we build the eigen-VM instructions
// Now we build the eigen-VM instructions
auto
op_map
=
get_op_map
();
auto
op_map
=
get_op_map
();
for
(
const
Node
*
node
:
pass_manager
.
get_call_graph
())
for
(
const
Node
*
node
:
m_function
->
get_ordered_ops
())
{
{
auto
handler_it
=
op_map
.
find
(
type_index
(
typeid
(
*
node
)));
auto
handler_it
=
op_map
.
find
(
type_index
(
typeid
(
*
node
)));
if
(
handler_it
==
op_map
.
end
())
if
(
handler_it
==
op_map
.
end
())
...
...
src/ngraph/visualize.cpp
deleted
100644 → 0
View file @
c2ff1508
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cstdio>
#include <fstream>
#include <list>
#include "ngraph/node.hpp"
#include "ngraph/util.hpp"
#include "ngraph/visualize.hpp"
using
namespace
ngraph
;
using
namespace
std
;
Visualize
::
Visualize
(
const
string
&
name
)
:
m_name
{
name
}
{
}
void
Visualize
::
add
(
node_ptr
p
)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes
(
p
,
[
&
](
Node
*
node
)
{
for
(
auto
arg
:
node
->
get_arguments
())
{
m_ss
<<
add_attributes
(
arg
.
get
());
m_ss
<<
add_attributes
(
node
);
m_ss
<<
" "
<<
arg
->
get_node_id
()
<<
" -> "
<<
node
->
get_node_id
();
m_ss
<<
";
\n
"
;
}
});
}
std
::
string
Visualize
::
add_attributes
(
const
Node
*
node
)
{
string
rc
;
if
(
!
contains
(
m_nodes_with_attributes
,
node
))
{
m_nodes_with_attributes
.
insert
(
node
);
rc
=
get_attributes
(
node
);
}
return
rc
;
}
std
::
string
Visualize
::
get_attributes
(
const
Node
*
node
)
{
stringstream
ss
;
if
(
node
->
is_parameter
())
{
ss
<<
" "
<<
node
->
get_node_id
()
<<
" [shape=box color=blue]
\n
"
;
}
else
{
ss
<<
" "
<<
node
->
get_node_id
()
<<
" [shape=ellipse color=black]
\n
"
;
}
return
ss
.
str
();
}
void
Visualize
::
save_dot
(
const
string
&
path
)
const
{
#ifdef GRAPHVIZ_FOUND
auto
tmp_file
=
path
+
".tmp"
;
ofstream
out
(
tmp_file
);
if
(
out
)
{
out
<<
"digraph "
<<
m_name
<<
"
\n
{
\n
"
;
out
<<
m_ss
.
str
();
out
<<
"}
\n
"
;
out
.
close
();
stringstream
ss
;
ss
<<
"dot -Tpng "
<<
tmp_file
<<
" -o "
<<
path
;
auto
cmd
=
ss
.
str
();
auto
stream
=
popen
(
cmd
.
c_str
(),
"r"
);
pclose
(
stream
);
remove
(
tmp_file
.
c_str
());
}
#else
#endif
}
test/pass_liveness.cpp
View file @
118e0679
...
@@ -51,7 +51,7 @@ TEST(pass, liveness)
...
@@ -51,7 +51,7 @@ TEST(pass, liveness)
shared_ptr
<
Function
>
func
=
make_test_graph
();
shared_ptr
<
Function
>
func
=
make_test_graph
();
pass_manager
.
run_passes
(
func
.
get
());
pass_manager
.
run_passes
(
func
.
get
());
auto
sorted
=
pass_manager
.
get_call_graph
();
auto
sorted
=
func
->
get_ordered_ops
();
// for (const Node* node : sorted)
// for (const Node* node : sorted)
// {
// {
...
...
test/pass_manager.cpp
View file @
118e0679
...
@@ -40,15 +40,28 @@ TEST(pass_manager, add)
...
@@ -40,15 +40,28 @@ TEST(pass_manager, add)
auto
graph
=
make_test_graph
();
auto
graph
=
make_test_graph
();
size_t
node_count
=
get_node_count
(
graph
->
get_result
());
size_t
node_count
=
get_node_count
(
graph
->
get_result
());
pass_manager
.
run_passes
(
graph
.
get
());
pass_manager
.
run_passes
(
graph
.
get
());
auto
sorted
=
pass_manager
.
get_call_graph
();
auto
sorted
=
graph
->
get_ordered_ops
();
EXPECT_EQ
(
node_count
,
sorted
.
size
());
EXPECT_EQ
(
node_count
,
sorted
.
size
());
EXPECT_TRUE
(
validate_list
(
sorted
));
EXPECT_TRUE
(
validate_list
(
sorted
));
}
}
TEST
(
pass_manager
,
dependency
)
TEST
(
pass_manager
,
module_add_function
)
{
{
pass
::
Manager
pass_manager
;
// First create "f(A,B,C) = (A+B)*C".
auto
shape
=
Shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
C
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_f
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
f
=
make_shared
<
Function
>
((
A
+
B
)
*
C
,
rt_f
,
op
::
Parameters
{
A
,
B
,
C
});
pass_manager
.
register_pass
<
pass
::
TopologicalSort
>
();
// Now make "g(X,Y,Z) = f(X,Y,Z) + f(X,Y,Z)"
EXPECT_THROW
(
pass_manager
.
register_pass
<
pass
::
AssignTensors
>
(),
runtime_error
);
auto
X
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Y
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Z
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_g
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
g
=
make_shared
<
Function
>
(
make_shared
<
op
::
FunctionCall
>
(
f
,
Nodes
{
X
,
Y
,
Z
})
+
make_shared
<
op
::
FunctionCall
>
(
f
,
Nodes
{
X
,
Y
,
Z
}),
rt_g
,
op
::
Parameters
{
X
,
Y
,
Z
});
}
}
test/pass_memory_layout.cpp
View file @
118e0679
...
@@ -218,7 +218,7 @@ TEST(memory_layout, basic)
...
@@ -218,7 +218,7 @@ TEST(memory_layout, basic)
auto
graph
=
make_test_graph
();
auto
graph
=
make_test_graph
();
pass_manager
.
run_passes
(
graph
);
pass_manager
.
run_passes
(
graph
);
auto
sorted
=
pass_manager
.
get_call_graph
();
auto
sorted
=
graph
->
get_ordered_ops
();
size_t
temporary_pool_size
=
pass_manager
.
get_state
().
get_temporary_pool_size
();
size_t
temporary_pool_size
=
pass_manager
.
get_state
().
get_temporary_pool_size
();
EXPECT_EQ
(
12
,
temporary_pool_size
);
EXPECT_EQ
(
12
,
temporary_pool_size
);
}
}
test/topological_sort.cpp
View file @
118e0679
...
@@ -21,10 +21,10 @@
...
@@ -21,10 +21,10 @@
#include "ngraph/log.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/collect_functions.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/util.hpp"
#include "ngraph/util.hpp"
#include "ngraph/visualize.hpp"
#include "test_tools.hpp"
#include "test_tools.hpp"
using
namespace
std
;
using
namespace
std
;
...
@@ -69,7 +69,7 @@ TEST(topological_sort, basic)
...
@@ -69,7 +69,7 @@ TEST(topological_sort, basic)
pass
::
Manager
pass_manager
;
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
TopologicalSort
>
();
pass_manager
.
register_pass
<
pass
::
TopologicalSort
>
();
pass_manager
.
run_passes
(
f0
);
pass_manager
.
run_passes
(
f0
);
auto
sorted_list
=
pass_manager
.
get_call_graph
();
auto
sorted_list
=
f0
->
get_ordered_ops
();
size_t
node_count
=
get_node_count
(
r0
);
size_t
node_count
=
get_node_count
(
r0
);
...
@@ -121,7 +121,7 @@ TEST(benchmark, topological_sort)
...
@@ -121,7 +121,7 @@ TEST(benchmark, topological_sort)
pass
::
Manager
pass_manager
;
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
TopologicalSort
>
();
pass_manager
.
register_pass
<
pass
::
TopologicalSort
>
();
pass_manager
.
run_passes
(
f0
);
pass_manager
.
run_passes
(
f0
);
auto
sorted_list
=
pass_manager
.
get_call_graph
();
auto
sorted_list
=
f0
->
get_ordered_ops
();
timer
.
stop
();
timer
.
stop
();
NGRAPH_INFO
<<
"topological sort took "
<<
timer
.
get_milliseconds
()
<<
"ms"
;
NGRAPH_INFO
<<
"topological sort took "
<<
timer
.
get_milliseconds
()
<<
"ms"
;
...
@@ -135,3 +135,51 @@ TEST(benchmark, topological_sort)
...
@@ -135,3 +135,51 @@ TEST(benchmark, topological_sort)
timer
.
stop
();
timer
.
stop
();
NGRAPH_INFO
<<
"delete nodes took "
<<
timer
.
get_milliseconds
()
<<
"ms"
;
NGRAPH_INFO
<<
"delete nodes took "
<<
timer
.
get_milliseconds
()
<<
"ms"
;
}
}
TEST
(
topological_sort
,
collect_functions
)
{
// First create "f(A,B,C) = (A+B)*C".
auto
shape
=
Shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
C
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_f
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
f
=
make_shared
<
Function
>
((
A
+
B
)
*
C
,
rt_f
,
op
::
Parameters
{
A
,
B
,
C
},
"f"
);
// Now make "g(X,Y,Z) = f(X,Y,Z) + f(X,Y,Z)"
auto
X
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Y
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Z
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_g
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
g
=
make_shared
<
Function
>
(
make_shared
<
op
::
FunctionCall
>
(
f
,
Nodes
{
X
,
Y
,
Z
})
+
make_shared
<
op
::
FunctionCall
>
(
f
,
Nodes
{
X
,
Y
,
Z
}),
rt_g
,
op
::
Parameters
{
X
,
Y
,
Z
},
"g"
);
// Now make "h(X,Y,Z) = g(X,Y,Z) + g(X,Y,Z)"
auto
X1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Y1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Z1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_h
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
h
=
make_shared
<
Function
>
(
make_shared
<
op
::
FunctionCall
>
(
g
,
Nodes
{
X1
,
Y1
,
Z1
})
+
make_shared
<
op
::
FunctionCall
>
(
g
,
Nodes
{
X1
,
Y1
,
Z1
}),
rt_h
,
op
::
Parameters
{
X1
,
Y1
,
Z1
},
"h"
);
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
CollectFunctions
>
();
pass_manager
.
run_passes
(
h
);
set
<
string
>
expected
=
{
"f"
,
"g"
,
"h"
};
auto
functions
=
pass_manager
.
get_state
().
get_functions
();
vector
<
string
>
fnames
;
for
(
Function
*
func
:
functions
)
{
fnames
.
push_back
(
func
->
get_name
());
}
EXPECT_EQ
(
expected
.
size
(),
functions
.
size
());
EXPECT_TRUE
(
contains
(
fnames
,
"f"
));
EXPECT_TRUE
(
contains
(
fnames
,
"g"
));
EXPECT_TRUE
(
contains
(
fnames
,
"h"
));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment