Commit 3b49dd1a
Authored Jun 22, 2018 by Matthew Brookhart
Committed by Scott Cyphers, Jun 22, 2018
refactor cache_prop to reuse bprop inputs (#1134)
Parent: b9a77a9d
Showing 4 changed files with 19 additions and 19 deletions:

  src/ngraph/util.cpp                          +15  -12
  src/ngraph/util.hpp                           +1   -3
  test/cpu_fusion.cpp                           +1   -1
  test/util/autodiff/backprop_derivative.hpp    +2   -3
src/ngraph/util.cpp
@@ -185,8 +185,7 @@ size_t ngraph::round_up(size_t size, size_t alignment)
 }
 
 ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
-                                       std::shared_ptr<ngraph::Function> bprop,
-                                       std::vector<std::shared_ptr<Node>> adjoints)
+                                       std::shared_ptr<ngraph::Function> bprop)
 {
     using namespace ngraph;

@@ -208,17 +207,21 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
     // shape and element type as the nodes in fprop
     FpropCache fprop_cache;
     fprop_cache.node_param_map = std::make_shared<NodeMap>();
-    ngraph::traverse_nodes(fprop, [&fprop_cache, &in_bprop](std::shared_ptr<Node> node) {
-        if (in_bprop.count(node) != 0)
-        {
-            fprop_cache.node_param_map->add(
-                node, std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape()));
-        }
-    });
+    auto bprop_inputs = bprop->get_parameters();
+    ngraph::traverse_nodes(
+        fprop, [&fprop_cache, &in_bprop, &bprop_inputs](std::shared_ptr<Node> node) {
+            if (in_bprop.count(node) != 0 &&
+                std::find(bprop_inputs.begin(), bprop_inputs.end(), node) == bprop_inputs.end())
+            {
+                fprop_cache.node_param_map->add(
+                    node,
+                    std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape()));
+            }
+        });
 
-    // Find all of the nodes that are intermediate values of fprop and used in
-    // bprop
-    // and store those nodes that aren't needed in bprop
+    // Find all of the nodes that are intermediate values of fprop and used in
+    // bprop and store those nodes that aren't needed in bprop
     std::vector<std::shared_ptr<Node>> unused_nodes;
     for (auto kv : fprop_cache.node_param_map->get_node_map())
     {

@@ -262,7 +265,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
     // get clone bprop parameters
     op::ParameterVector bprop_input_params;
-    for (auto param : adjoints)
+    for (auto param : bprop_inputs)
     {
         bprop_input_params.push_back(
             std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(param)));
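The substantive change above is in the traverse_nodes callback: a forward-graph node now gets a fresh cache Parameter only when it is used in bprop and is not already one of bprop's own input parameters (bprop->get_parameters()). The snippet below is not nGraph code, just a self-contained sketch of that filtering rule using standard-library stand-ins (the Node struct, used_in_bprop set, and bprop_inputs vector are hypothetical placeholders for the real graph types):

#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

// Stand-in for ngraph::Node; only a name is needed for the sketch.
struct Node
{
    std::string name;
};

int main()
{
    auto x = std::make_shared<Node>(Node{"x"}); // fprop input that is also a bprop input
    auto t = std::make_shared<Node>(Node{"t"}); // fprop intermediate reused by bprop
    auto c = std::make_shared<Node>(Node{"c"}); // adjoint, already a bprop parameter

    // Nodes that appear somewhere in the bprop graph.
    std::unordered_set<std::shared_ptr<Node>> used_in_bprop{x, t, c};

    // Parameters the bprop function already declares (analogue of bprop->get_parameters()).
    std::vector<std::shared_ptr<Node>> bprop_inputs{x, c};

    // Analogue of the new condition: cache a node only if bprop uses it
    // and it is not already one of bprop's inputs.
    std::vector<std::shared_ptr<Node>> needs_cache_param;
    for (auto node : {x, t, c})
    {
        if (used_in_bprop.count(node) != 0 &&
            std::find(bprop_inputs.begin(), bprop_inputs.end(), node) == bprop_inputs.end())
        {
            needs_cache_param.push_back(node);
        }
    }

    for (auto node : needs_cache_param)
    {
        std::cout << node->name << " gets a cache Parameter\n"; // prints only "t"
    }
}

Because the adjoints are already parameters of bprop, the explicit adjoints argument can be dropped and the later loop iterates over bprop_inputs instead, which is exactly what the third hunk above does.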
src/ngraph/util.hpp
@@ -259,7 +259,5 @@ namespace ngraph
      * The last argument is the adjoints coming into the bprop function, the output
      * bprop function will have these nodes as the first N input parameters
      **/
-    FpropCache cache_fprop(std::shared_ptr<Function> fprop,
-                           std::shared_ptr<Function> bprop,
-                           std::vector<std::shared_ptr<Node>> adjoints);
+    FpropCache cache_fprop(std::shared_ptr<Function> fprop, std::shared_ptr<Function> bprop);
 } // end namespace ngraph
test/cpu_fusion.cpp
@@ -1533,7 +1533,7 @@ TEST(cpu_fusion, maxpool_with_indices_in_mxnet)
     auto maybe_bf = bfa.first;
     auto adjoints = bfa.second;
     optimize_graph(f, maybe_bf);
-    auto fprop_cache = ngraph::cache_fprop(f, maybe_bf, adjoints);
+    auto fprop_cache = ngraph::cache_fprop(f, maybe_bf);
 
     auto mpwi_bprop = fprop_cache.bprop->get_results().at(0)->get_argument(0);
     ASSERT_TRUE(std::dynamic_pointer_cast<op::Parameter>(mpwi_bprop->get_argument(0)));
test/util/autodiff/backprop_derivative.hpp
@@ -166,15 +166,14 @@ namespace ngraph
         // create fprop cache
         // creates modified forward function -> (y, cached) = f(x)
         // creates modified backward function -> df/dX* = f'(c, cached)
-        auto fprop_cache = cache_fprop(f, df, {c_param});
+        auto fprop_cache = cache_fprop(f, df);
 
         // (y, cached) arguments
         std::vector<std::shared_ptr<runtime::TensorView>> mod_f_output_args;
         mod_f_output_args.push_back(backend->create_tensor<T>(y_shape));
 
         // (c, cached) arguments
-        std::vector<std::shared_ptr<runtime::TensorView>> mod_df_input_args;
-        mod_df_input_args.push_back(c_arg);
+        std::vector<std::shared_ptr<runtime::TensorView>> mod_df_input_args = df_input_args;
 
         // add cached nodes to both modified f output and modified f' input arguments
         for (auto node : fprop_cache.fprop_output_nodes)
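The comments in this last hunk summarize the fprop-cache scheme both tests rely on: the forward function is rewritten to also emit its cached intermediates, (y, cached) = f(x), and the backward function consumes them, df/dX* = f'(c, cached). The program below is only a conceptual toy, not the nGraph API: a forward pass over plain doubles returns its result together with one cached intermediate, and the backward pass reuses that intermediate instead of recomputing it (the choice f(x) = x^3 with cached = x^2 is arbitrary, purely for illustration):

#include <iostream>
#include <utility>

// Forward pass: returns the output y together with a cached intermediate.
// (y, cached) = f(x), with f(x) = x^3 and cached = x^2.
std::pair<double, double> fprop(double x)
{
    double t = x * x; // intermediate that bprop would otherwise recompute
    return {t * x, t};
}

// Backward pass: df/dx = f'(c, cached) = c * 3 * x^2, reusing the cached x^2.
double bprop(double c, double cached)
{
    return c * 3.0 * cached;
}

int main()
{
    double x = 2.0;
    double c = 1.0; // adjoint coming into bprop

    auto [y, cached] = fprop(x);
    double dx = bprop(c, cached);

    std::cout << "y = " << y << ", df/dx = " << dx << "\n"; // y = 8, df/dx = 12
}

In the real code the same idea is expressed at the graph level: as the fprop_output_nodes and node_param_map members used above suggest, cache_fprop exposes the cached intermediates as extra fprop outputs and corresponding bprop parameters.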