Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
125b1a85
Commit
125b1a85
authored
Nov 11, 2017
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Get tuples out of autodiff
parent
7ff8e1f2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
32 deletions
+6
-32
adjoints.cpp
src/ngraph/autodiff/adjoints.cpp
+1
-32
output.hpp
src/ngraph/descriptor/output.hpp
+5
-0
No files found.
src/ngraph/autodiff/adjoints.cpp
View file @
125b1a85
...
@@ -30,9 +30,6 @@
...
@@ -30,9 +30,6 @@
using
namespace
ngraph
;
using
namespace
ngraph
;
/// @brief Make a zero matching a value type.
std
::
shared_ptr
<
Node
>
make_zero
(
const
std
::
shared_ptr
<
const
ValueType
>&
value_type
);
std
::
shared_ptr
<
Node
>
make_zero
(
const
std
::
shared_ptr
<
const
TensorViewType
>&
tensor_view_type
)
std
::
shared_ptr
<
Node
>
make_zero
(
const
std
::
shared_ptr
<
const
TensorViewType
>&
tensor_view_type
)
{
{
std
::
shared_ptr
<
Node
>
zero
=
std
::
shared_ptr
<
Node
>
zero
=
...
@@ -50,34 +47,6 @@ std::shared_ptr<Node> make_zero(const std::shared_ptr<const TensorViewType>& ten
...
@@ -50,34 +47,6 @@ std::shared_ptr<Node> make_zero(const std::shared_ptr<const TensorViewType>& ten
return
zero
;
return
zero
;
}
}
std
::
shared_ptr
<
Node
>
make_zero
(
const
std
::
shared_ptr
<
const
TupleType
>&
tuple_type
)
{
std
::
vector
<
std
::
shared_ptr
<
Node
>>
elements
;
for
(
auto
&
value_type
:
tuple_type
->
get_element_types
())
{
elements
.
push_back
(
make_zero
(
value_type
));
}
return
std
::
make_shared
<
op
::
Tuple
>
(
elements
);
}
std
::
shared_ptr
<
Node
>
make_zero
(
const
std
::
shared_ptr
<
const
ValueType
>&
value_type
)
{
std
::
shared_ptr
<
const
TensorViewType
>
tensor_view_type
=
std
::
dynamic_pointer_cast
<
const
TensorViewType
>
(
value_type
);
if
(
nullptr
!=
tensor_view_type
)
{
return
(
make_zero
(
tensor_view_type
));
}
std
::
shared_ptr
<
const
TupleType
>
tuple_type
=
std
::
dynamic_pointer_cast
<
const
TupleType
>
(
value_type
);
if
(
nullptr
!=
tuple_type
)
{
return
make_zero
(
tuple_type
);
}
// Should be impossible
throw
ngraph_error
(
"Unknown value type"
);
}
autodiff
::
Adjoints
::
Adjoints
(
const
std
::
shared_ptr
<
Node
>&
y
,
const
std
::
shared_ptr
<
Node
>&
c
)
autodiff
::
Adjoints
::
Adjoints
(
const
std
::
shared_ptr
<
Node
>&
y
,
const
std
::
shared_ptr
<
Node
>&
c
)
{
{
// Pass 1 determines which nodes contribute to y as well as setting up a reverse
// Pass 1 determines which nodes contribute to y as well as setting up a reverse
...
@@ -143,7 +112,7 @@ std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x)
...
@@ -143,7 +112,7 @@ std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x)
auto
adjoint_it
=
m_adjoint_map
.
find
(
x
.
get
());
auto
adjoint_it
=
m_adjoint_map
.
find
(
x
.
get
());
if
(
m_adjoint_map
.
end
()
==
adjoint_it
)
if
(
m_adjoint_map
.
end
()
==
adjoint_it
)
{
{
auto
result
=
make_zero
(
x
->
get_
value
_type
());
auto
result
=
make_zero
(
x
->
get_
outputs
().
at
(
0
).
get_tensor_view
_type
());
adjoint_it
=
m_adjoint_map
.
insert
({
x
.
get
(),
result
}).
first
;
adjoint_it
=
m_adjoint_map
.
insert
({
x
.
get
(),
result
}).
first
;
}
}
return
adjoint_it
->
second
;
return
adjoint_it
->
second
;
...
...
src/ngraph/descriptor/output.hpp
View file @
125b1a85
...
@@ -47,6 +47,11 @@ namespace ngraph
...
@@ -47,6 +47,11 @@ namespace ngraph
const
std
::
set
<
Input
*>&
get_inputs
()
const
{
return
m_inputs
;
}
const
std
::
set
<
Input
*>&
get_inputs
()
const
{
return
m_inputs
;
}
const
Tensor
&
get_tensor
()
const
;
const
Tensor
&
get_tensor
()
const
;
Tensor
&
get_tensor
();
Tensor
&
get_tensor
();
/// @return the tensor view type for the connected output
std
::
shared_ptr
<
const
TensorViewType
>
get_tensor_view_type
()
const
{
return
get_tensor_view
()
->
get_tensor_view_type
();
}
protected
:
protected
:
Node
*
m_node
;
Node
*
m_node
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment