Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
9b3d5732
Commit
9b3d5732
authored
Jul 09, 2019
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix top sort
parent
de37f9d3
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
103 additions
and
71 deletions
+103
-71
output.cpp
src/ngraph/descriptor/output.cpp
+13
-3
output.hpp
src/ngraph/descriptor/output.hpp
+3
-3
graph_util.cpp
src/ngraph/graph_util.cpp
+12
-12
graph_util.hpp
src/ngraph/graph_util.hpp
+33
-45
node.cpp
src/ngraph/node.cpp
+1
-1
node.hpp
src/ngraph/node.hpp
+1
-1
control_dependencies.cpp
test/control_dependencies.cpp
+4
-6
test_tools.cpp
test/util/test_tools.cpp
+32
-0
test_tools.hpp
test/util/test_tools.hpp
+4
-0
No files found.
src/ngraph/descriptor/output.cpp
View file @
9b3d5732
...
...
@@ -14,8 +14,10 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/descriptor/output.hpp"
#include <algorithm>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/node.hpp"
using
namespace
std
;
...
...
@@ -31,12 +33,20 @@ descriptor::Output::Output(Node* node, size_t index, const shared_ptr<Tensor>& t
// Add an input to the vector of inputs that use this output.
void
descriptor
::
Output
::
add_input
(
Input
*
input
)
{
m_inputs
.
insert
(
input
);
// Keep the inputs in insertion order to keep sorts deterministic
if
(
find
(
m_inputs
.
begin
(),
m_inputs
.
end
(),
input
)
==
m_inputs
.
end
())
{
m_inputs
.
push_back
(
input
);
}
}
void
descriptor
::
Output
::
remove_input
(
Input
*
input
)
{
m_inputs
.
erase
(
input
);
auto
it
=
find
(
m_inputs
.
begin
(),
m_inputs
.
end
(),
input
);
if
(
it
!=
m_inputs
.
end
())
{
m_inputs
.
erase
(
it
);
}
}
shared_ptr
<
Node
>
descriptor
::
Output
::
get_node
()
const
...
...
src/ngraph/descriptor/output.hpp
View file @
9b3d5732
...
...
@@ -17,7 +17,7 @@
#pragma once
#include <memory>
#include <
set
>
#include <
vector
>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/tensor.hpp"
...
...
@@ -48,7 +48,7 @@ namespace ngraph
void
set_tensor_ptr
(
const
std
::
shared_ptr
<
Tensor
>&
tensor
)
{
m_tensor
=
tensor
;
}
void
add_input
(
Input
*
input
);
void
remove_input
(
Input
*
input
);
const
std
::
set
<
Input
*>&
get_inputs
()
const
{
return
m_inputs
;
}
const
std
::
vector
<
Input
*>&
get_inputs
()
const
{
return
m_inputs
;
}
Tensor
&
get_tensor
()
const
;
/// \return the shape of the output
...
...
@@ -64,7 +64,7 @@ namespace ngraph
Node
*
m_node
;
size_t
m_index
;
std
::
shared_ptr
<
Tensor
>
m_tensor
;
std
::
set
<
Input
*>
m_inputs
;
std
::
vector
<
Input
*>
m_inputs
;
private
:
Output
(
const
Output
&
)
=
delete
;
...
...
src/ngraph/graph_util.cpp
View file @
9b3d5732
...
...
@@ -81,27 +81,27 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results,
while
(
stack
.
size
()
>
0
)
{
std
::
shared_ptr
<
Node
>
n
=
stack
.
front
();
stack
.
pop_front
();
if
(
instances_seen
.
count
(
n
)
==
0
)
{
instances_seen
.
insert
(
n
);
f
(
n
);
}
stack
.
pop_front
();
for
(
auto
arg
:
n
->
get_arguments
())
{
if
(
instances_seen
.
count
(
arg
)
==
0
)
for
(
auto
arg
:
n
->
get_arguments
())
{
stack
.
push_front
(
arg
);
if
(
instances_seen
.
count
(
arg
)
==
0
)
{
stack
.
push_front
(
arg
);
}
}
}
if
(
include_control_deps
)
{
for
(
auto
cdep
:
n
->
get_control_dependencies
())
if
(
include_control_deps
)
{
if
(
instances_seen
.
count
(
cdep
)
==
0
)
for
(
auto
cdep
:
n
->
get_control_dependencies
()
)
{
stack
.
push_front
(
cdep
);
if
(
instances_seen
.
count
(
cdep
)
==
0
)
{
stack
.
push_front
(
cdep
);
}
}
}
}
...
...
src/ngraph/graph_util.hpp
View file @
9b3d5732
...
...
@@ -20,6 +20,7 @@
#include <functional>
#include <list>
#include <memory>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
...
...
@@ -81,66 +82,53 @@ namespace ngraph
std
::
list
<
std
::
shared_ptr
<
Node
>>
topological_sort
(
const
T
&
nodes
,
bool
include_control_deps
=
false
)
{
std
::
deque
<
ngraph
::
Node
*>
independent_nodes
;
std
::
unordered_map
<
const
ngraph
::
Node
*
,
size_t
>
node_dependency_count
;
std
::
unordered_map
<
ngraph
::
Node
*
,
std
::
shared_ptr
<
ngraph
::
Node
>>
node_map
;
std
::
unordered_map
<
ngraph
::
Node
*
,
std
::
set
<
Node
*>>
control_deps_users
;
std
::
stack
<
ngraph
::
Node
*>
nodes_to_do
;
std
::
set
<
Node
*>
nodes_done
;
std
::
list
<
std
::
shared_ptr
<
Node
>>
result
;
for
(
auto
node
:
nodes
)
{
//build an equivalent of node->get_users() but for control dependencies
size_t
control_deps_count
=
0
;
if
(
include_control_deps
)
{
for
(
auto
cd
:
node
->
get_control_dependencies
())
{
control_deps_count
++
;
control_deps_users
[
cd
.
get
()].
insert
(
node
.
get
());
}
}
node_map
[
node
.
get
()]
=
node
;
size_t
deps_count
=
node
->
get_input_size
()
+
control_deps_count
;
node_dependency_count
[
node
.
get
()]
=
deps_count
;
if
(
deps_count
==
0
)
{
independent_nodes
.
push_back
(
node
.
get
());
}
nodes_to_do
.
push
(
node
.
get
());
}
std
::
list
<
std
::
shared_ptr
<
ngraph
::
Node
>>
result_list
;
while
(
independent_nodes
.
size
()
>
0
)
while
(
nodes_to_do
.
size
()
>
0
)
{
auto
independent_node
=
independent_nodes
.
front
();
result_list
.
push_back
(
node_map
[
independent_node
]);
independent_nodes
.
pop_front
();
for
(
const
std
::
shared_ptr
<
Node
>&
user
:
independent_node
->
get_users
())
Node
*
node
=
nodes_to_do
.
top
();
if
(
nodes_done
.
count
(
node
)
!=
0
)
{
if
(
--
node_dependency_count
[
user
.
get
()]
==
0
)
nodes_to_do
.
pop
();
continue
;
}
bool
can_add
=
true
;
size_t
arg_count
=
node
->
get_input_size
();
for
(
size_t
i
=
0
;
i
<
arg_count
;
++
i
)
{
Node
*
dep
=
node
->
input
(
arg_count
-
i
-
1
).
get_source_output
().
get_node
();
if
(
nodes_done
.
count
(
dep
)
==
0
)
{
independent_nodes
.
push_back
(
user
.
get
());
can_add
=
false
;
nodes_to_do
.
push
(
dep
);
}
}
if
(
include_control_deps
)
{
auto
cdit
=
control_deps_users
.
find
(
independent_node
);
if
(
cdit
!=
control_deps_users
.
end
())
for
(
auto
cd_user
:
cdit
->
second
)
for
(
auto
depptr
:
node
->
get_control_dependencies
())
{
Node
*
dep
=
depptr
.
get
();
if
(
nodes_done
.
count
(
dep
)
==
0
)
{
node_dependency_count
[
cd_user
]
-=
1
;
size_t
count
=
node_dependency_count
[
cd_user
];
if
(
count
==
0
)
{
independent_nodes
.
push_back
(
cd_user
);
}
can_add
=
false
;
nodes_to_do
.
push
(
dep
);
}
}
}
if
(
can_add
)
{
result
.
push_back
(
node
->
shared_from_this
());
nodes_to_do
.
pop
();
nodes_done
.
insert
(
node
);
}
}
NGRAPH_CHECK
(
nodes
.
size
()
==
result_list
.
size
());
return
result_list
;
return
result
;
}
// For cases, where `nodes` is a subset of the entire graph
...
...
src/ngraph/node.cpp
View file @
9b3d5732
...
...
@@ -344,7 +344,7 @@ shared_ptr<descriptor::Tensor> Node::get_output_tensor_ptr() const
return
m_outputs
.
at
(
0
).
get_tensor_ptr
();
}
const
std
::
set
<
descriptor
::
Input
*>&
Node
::
get_output_inputs
(
size_t
i
)
const
const
std
::
vector
<
descriptor
::
Input
*>&
Node
::
get_output_inputs
(
size_t
i
)
const
{
return
m_outputs
.
at
(
i
).
get_inputs
();
}
...
...
src/ngraph/node.hpp
View file @
9b3d5732
...
...
@@ -257,7 +257,7 @@ namespace ngraph
"output, or update calling code not to assume only one output"
);
/// Returns the set of inputs using output i
const
std
::
set
<
descriptor
::
Input
*>&
get_output_inputs
(
size_t
i
)
const
const
std
::
vector
<
descriptor
::
Input
*>&
get_output_inputs
(
size_t
i
)
const
NGRAPH_DEPRECATED
(
"use node->output(i).get_target_inputs() instead"
);
/// Returns the number of inputs for the op
...
...
test/control_dependencies.cpp
View file @
9b3d5732
...
...
@@ -87,8 +87,7 @@ TEST(control_dependencies, cdep_ops)
make_shared
<
ControlDependencyOp
>
(
NodeVector
{
A
},
std
::
set
<
std
::
shared_ptr
<
Node
>>
{
absn
});
auto
f
=
make_shared
<
Function
>
(
cdop
,
ParameterVector
{
A
,
B
});
auto
nodes
=
f
->
get_ordered_ops
(
true
);
ASSERT_EQ
(
nodes
.
back
()
->
get_argument
(
0
),
cdop
);
test_ordered_ops
(
f
);
}
TEST
(
control_dependencies
,
two_cdep_ops
)
...
...
@@ -102,8 +101,7 @@ TEST(control_dependencies, two_cdep_ops)
std
::
set
<
std
::
shared_ptr
<
Node
>>
{
absn
,
absn_c
});
auto
f
=
make_shared
<
Function
>
(
cdop
,
ParameterVector
{
A
,
B
,
C
});
auto
nodes
=
f
->
get_ordered_ops
(
true
);
ASSERT_EQ
(
nodes
.
back
()
->
get_argument
(
0
),
cdop
);
test_ordered_ops
(
f
);
}
TEST
(
control_dependencies
,
two_cdep_ops_op_on_top
)
...
...
@@ -117,8 +115,7 @@ TEST(control_dependencies, two_cdep_ops_op_on_top)
auto
absn_cdop
=
make_shared
<
op
::
Abs
>
(
cdop
);
auto
f
=
make_shared
<
Function
>
(
absn_cdop
,
ParameterVector
{
A
,
B
});
auto
nodes
=
f
->
get_ordered_ops
(
true
);
ASSERT_EQ
(
nodes
.
back
()
->
get_argument
(
0
),
absn_cdop
);
test_ordered_ops
(
f
);
}
TEST
(
control_dependencies
,
clone_function_cdop
)
...
...
@@ -129,6 +126,7 @@ TEST(control_dependencies, clone_function_cdop)
make_shared
<
ControlDependencyOp
>
(
NodeVector
{
A
},
std
::
set
<
std
::
shared_ptr
<
Node
>>
{
absn
});
auto
f
=
make_shared
<
Function
>
(
cdop
,
ParameterVector
{
A
});
test_ordered_ops
(
f
);
auto
clone
=
ngraph
::
clone_function
(
*
f
.
get
());
auto
matcher
=
std
::
make_shared
<
pattern
::
Matcher
>
(
cdop
);
auto
cdop_clone
=
clone
->
get_results
().
at
(
0
)
->
get_argument
(
0
);
...
...
test/util/test_tools.cpp
View file @
9b3d5732
...
...
@@ -313,3 +313,35 @@ std::shared_ptr<Function> make_function_from_file(const std::string& file_name)
return
func
;
}
#endif
::
testing
::
AssertionResult
test_ordered_ops
(
shared_ptr
<
Function
>
f
)
{
set
<
shared_ptr
<
Node
>>
seen
;
for
(
auto
node
:
f
->
get_ordered_ops
())
{
if
(
seen
.
count
(
node
)
>
0
)
{
return
::
testing
::
AssertionFailure
()
<<
"Duplication in ordered ops"
;
}
size_t
arg_count
=
node
->
get_input_size
();
for
(
size_t
i
=
0
;
i
<
arg_count
;
++
i
)
{
shared_ptr
<
Node
>
dep
=
node
->
input
(
i
).
get_source_output
().
get_node_shared_ptr
();
if
(
seen
.
count
(
dep
)
==
0
)
{
return
::
testing
::
AssertionFailure
()
<<
"Argument "
<<
dep
<<
" does not occur before op"
<<
node
;
}
}
for
(
shared_ptr
<
Node
>
dep
:
node
->
get_control_dependencies
())
{
if
(
seen
.
count
(
dep
)
==
0
)
{
return
::
testing
::
AssertionFailure
()
<<
"Control dependency "
<<
dep
<<
" does not occur before op"
<<
node
;
}
}
seen
.
insert
(
node
);
}
return
::
testing
::
AssertionSuccess
();
}
test/util/test_tools.hpp
View file @
9b3d5732
...
...
@@ -25,6 +25,8 @@
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/descriptor/layout/tensor_layout.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/log.hpp"
...
...
@@ -276,3 +278,5 @@ std::vector<T> read_binary_file(const std::string& path)
inputs_fs
.
read
(
reinterpret_cast
<
char
*>
(
file_content
.
data
()),
size
);
return
file_content
;
}
testing
::
AssertionResult
test_ordered_ops
(
std
::
shared_ptr
<
ngraph
::
Function
>
f
);
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment