Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
fac27c37
Commit
fac27c37
authored
7 years ago
by
Scott Cyphers
Committed by
GitHub
7 years ago
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #77 from NervanaSystems/bob/tsort2
topological sort working with unit test
parents
a89c33b4
a7c1841e
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
121 additions
and
22 deletions
+121
-22
node.cpp
src/ngraph/node.cpp
+1
-1
node.hpp
src/ngraph/node.hpp
+6
-6
topological_sort.cpp
src/ngraph/topological_sort.cpp
+46
-1
topological_sort.hpp
src/ngraph/topological_sort.hpp
+9
-2
type.hpp
src/ngraph/type.hpp
+2
-1
topological_sort.cpp
test/topological_sort.cpp
+57
-11
No files found.
src/ngraph/node.cpp
View file @
fac27c37
...
...
@@ -25,7 +25,7 @@ ngraph::Node::Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type)
// Add this node as a user of each argument.
for
(
auto
node
:
m_arguments
)
{
node
->
m_users
.
insert
(
node
.
get
()
);
node
->
m_users
.
insert
(
this
);
}
}
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/node.hpp
View file @
fac27c37
...
...
@@ -32,7 +32,7 @@ namespace ngraph
** zero or more nodes as arguments and one value, which is either a tensor
** view or a (possibly empty) tuple of values.
**/
class
Node
:
public
TypedValueMixin
class
Node
:
public
TypedValueMixin
,
public
std
::
enable_shared_from_this
<
Node
>
{
public
:
using
ptr
=
std
::
shared_ptr
<
Node
>
;
...
...
@@ -74,11 +74,11 @@ namespace ngraph
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Node
&
);
protected
:
Nodes
m_arguments
;
std
::
multiset
<
Node
*>
m_users
;
std
::
string
m_name
;
size_t
m_instance_id
;
static
size_t
m_next_instance_id
;
Nodes
m_arguments
;
std
::
multiset
<
Node
*>
m_users
;
std
::
string
m_name
;
size_t
m_instance_id
;
static
size_t
m_next_instance_id
;
};
using
node_ptr
=
std
::
shared_ptr
<
Node
>
;
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/topological_sort.cpp
View file @
fac27c37
...
...
@@ -12,8 +12,53 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "node.hpp"
#include "topological_sort.hpp"
#include "util.hpp"
void
ngraph
::
TopologicalSort
::
process
(
node_ptr
node
)
using
namespace
ngraph
;
using
namespace
std
;
void
ngraph
::
TopologicalSort
::
promote_node
(
Node
*
n
)
{
for
(
auto
dn
=
m_dependent_nodes
.
begin
();
dn
!=
m_dependent_nodes
.
end
();
dn
++
)
{
if
(
dn
->
first
>
0
)
// Skip zero as they should never be promoted
{
auto
it
=
find
(
dn
->
second
.
begin
(),
dn
->
second
.
end
(),
n
);
if
(
it
!=
dn
->
second
.
end
())
{
// found the node
dn
->
second
.
erase
(
it
);
m_dependent_nodes
[
dn
->
first
-
1
].
push_back
(
n
);
}
}
}
}
void
ngraph
::
TopologicalSort
::
process
(
node_ptr
p
)
{
traverse_nodes
(
p
,
[
&
](
node_ptr
node
)
{
list
<
Node
*>&
node_list
=
m_dependent_nodes
[
node
->
arguments
().
size
()];
node_list
.
push_back
(
node
.
get
());
});
list
<
Node
*>&
independent_nodes
=
m_dependent_nodes
[
0
];
while
(
independent_nodes
.
size
()
>
0
)
{
auto
independent_node
=
independent_nodes
.
front
();
m_sorted_list
.
push_back
(
independent_node
);
independent_nodes
.
pop_front
();
for
(
auto
user
:
independent_node
->
users
())
{
promote_node
(
user
);
}
}
}
const
std
::
vector
<
Node
*>&
ngraph
::
TopologicalSort
::
get_sorted_list
()
const
{
return
m_sorted_list
;
}
This diff is collapsed.
Click to expand it.
src/ngraph/topological_sort.hpp
View file @
fac27c37
...
...
@@ -15,6 +15,8 @@
#pragma once
#include <memory>
#include <map>
#include <list>
namespace
ngraph
{
...
...
@@ -26,9 +28,14 @@ namespace ngraph
class
ngraph
::
TopologicalSort
{
public
:
TopologicalSort
()
;
TopologicalSort
()
{}
static
void
process
(
node_ptr
);
void
process
(
node_ptr
);
const
std
::
vector
<
Node
*>&
get_sorted_list
()
const
;
private
:
void
promote_node
(
Node
*
n
);
std
::
map
<
size_t
,
std
::
list
<
Node
*>>
m_dependent_nodes
;
std
::
vector
<
Node
*>
m_sorted_list
;
};
This diff is collapsed.
Click to expand it.
src/ngraph/type.hpp
View file @
fac27c37
...
...
@@ -40,7 +40,7 @@ namespace ngraph
virtual
~
ValueType
()
{}
virtual
bool
operator
==
(
const
ValueType
::
ptr
&
that
)
const
=
0
;
bool
operator
!=
(
const
ValueType
::
ptr
&
that
)
const
{
return
!
(
*
this
==
that
);
}
bool
operator
!=
(
const
ValueType
::
ptr
&
that
)
const
{
return
!
(
*
this
==
that
);
}
};
/**
...
...
@@ -140,6 +140,7 @@ namespace ngraph
** The type associated with this value.
**/
const
ValueType
::
ptr
type
()
const
{
return
m_type
;
}
protected
:
ValueType
::
ptr
m_type
;
};
...
...
This diff is collapsed.
Click to expand it.
test/topological_sort.cpp
View file @
fac27c37
...
...
@@ -26,28 +26,74 @@
using
namespace
std
;
using
namespace
ngraph
;
TEST
(
top_sort
,
basic
)
static
bool
validate_list
(
const
vector
<
Node
*>&
nodes
)
{
auto
arg0
=
op
::
parameter
(
element
::
Float
::
type
,
{
1
});
ASSERT_NE
(
nullptr
,
arg0
);
auto
arg1
=
op
::
parameter
(
element
::
Float
::
type
,
{
1
});
ASSERT_NE
(
nullptr
,
arg1
);
auto
t0
=
op
::
add
(
arg0
,
arg1
);
bool
rc
=
true
;
for
(
auto
it
=
nodes
.
rbegin
();
it
!=
nodes
.
rend
();
it
++
)
{
Node
*
node
=
*
it
;
auto
node_tmp
=
*
it
;
auto
dependencies_tmp
=
node_tmp
->
arguments
();
vector
<
Node
*>
dependencies
;
for
(
shared_ptr
<
Node
>
n
:
dependencies_tmp
)
{
dependencies
.
push_back
(
n
.
get
());
}
auto
tmp
=
it
+
1
;
for
(;
tmp
!=
nodes
.
rend
();
tmp
++
)
{
auto
dep_tmp
=
*
tmp
;
auto
found
=
find
(
dependencies
.
begin
(),
dependencies
.
end
(),
dep_tmp
);
if
(
found
!=
dependencies
.
end
())
{
dependencies
.
erase
(
found
);
}
}
if
(
dependencies
.
size
()
>
0
)
{
rc
=
false
;
}
}
return
rc
;
}
TEST
(
topological_sort
,
basic
)
{
vector
<
shared_ptr
<
Parameter
>>
args
;
for
(
int
i
=
0
;
i
<
10
;
i
++
)
{
auto
arg
=
op
::
parameter
(
element
::
Float
::
type
,
{
1
});
ASSERT_NE
(
nullptr
,
arg
);
args
.
push_back
(
arg
);
}
auto
t0
=
op
::
add
(
args
[
0
],
args
[
1
]);
ASSERT_NE
(
nullptr
,
t0
);
auto
t1
=
op
::
add
(
arg0
,
arg1
);
auto
t1
=
op
::
dot
(
t0
,
args
[
2
]
);
ASSERT_NE
(
nullptr
,
t1
);
Node
::
ptr
r0
=
op
::
add
(
t0
,
t1
);
auto
t2
=
op
::
multiply
(
t0
,
args
[
3
]);
ASSERT_NE
(
nullptr
,
t2
);
auto
t3
=
op
::
add
(
t1
,
args
[
4
]);
ASSERT_NE
(
nullptr
,
t2
);
auto
t4
=
op
::
add
(
t2
,
args
[
5
]);
ASSERT_NE
(
nullptr
,
t3
);
Node
::
ptr
r0
=
op
::
add
(
t3
,
t4
);
ASSERT_NE
(
nullptr
,
r0
);
auto
f0
=
op
::
function
(
r0
,
{
arg0
,
arg1
}
);
auto
f0
=
op
::
function
(
r0
,
args
);
ASSERT_NE
(
nullptr
,
f0
);
ASSERT_EQ
(
2
,
r0
->
arguments
().
size
());
auto
op_r0
=
static_pointer_cast
<
Op
>
(
r0
);
cout
<<
"op_r0 name "
<<
*
r0
<<
endl
;
Visualize
vz
;
vz
.
add
(
r0
);
vz
.
save_dot
(
"test.png"
);
TopologicalSort
::
process
(
r0
);
TopologicalSort
ts
;
ts
.
process
(
r0
);
auto
sorted_list
=
ts
.
get_sorted_list
();
EXPECT_TRUE
(
validate_list
(
sorted_list
));
}
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment