Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
bc3c70df
Commit
bc3c70df
authored
Oct 11, 2017
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Basic autodiff for + and *
parent
4bec2307
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
463 additions
and
20 deletions
+463
-20
CMakeLists.txt
src/ngraph/CMakeLists.txt
+3
-0
adjoints.cpp
src/ngraph/autodiff/adjoints.cpp
+170
-0
adjoints.hpp
src/ngraph/autodiff/adjoints.hpp
+54
-0
function.cpp
src/ngraph/function.cpp
+1
-1
function.hpp
src/ngraph/function.hpp
+3
-3
node.cpp
src/ngraph/node.cpp
+43
-0
node.hpp
src/ngraph/node.hpp
+17
-10
add.cpp
src/ngraph/ops/add.cpp
+25
-0
add.hpp
src/ngraph/ops/add.hpp
+3
-0
multiply.cpp
src/ngraph/ops/multiply.cpp
+25
-0
multiply.hpp
src/ngraph/ops/multiply.hpp
+9
-6
parameter.hpp
src/ngraph/ops/parameter.hpp
+5
-0
CMakeLists.txt
test/CMakeLists.txt
+1
-0
autodiff.cpp
test/autodiff.cpp
+104
-0
No files found.
src/ngraph/CMakeLists.txt
View file @
bc3c70df
...
...
@@ -12,6 +12,7 @@
# limitations under the License.
set
(
SRC
autodiff/adjoints.cpp
descriptor/input.cpp
descriptor/layout/dense_tensor_view_layout.cpp
descriptor/layout/tensor_view_layout.cpp
...
...
@@ -23,6 +24,7 @@ set (SRC
function.cpp
log.cpp
node.cpp
ops/add.cpp
ops/binary_elementwise_arithmetic.cpp
ops/binary_elementwise_builtin.cpp
ops/binary_elementwise_comparison.cpp
...
...
@@ -33,6 +35,7 @@ set (SRC
ops/dot.cpp
ops/function_call.cpp
ops/get_tuple_element.cpp
ops/multiply.cpp
ops/op.cpp
ops/parameter.cpp
ops/reduce.cpp
...
...
src/ngraph/autodiff/adjoints.cpp
0 → 100644
View file @
bc3c70df
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cassert>
#include <list>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/types/type.hpp"
using
namespace
ngraph
;
/// @brief Make a zero matching a value type.
std
::
shared_ptr
<
Node
>
make_zero
(
const
std
::
shared_ptr
<
const
ValueType
>&
value_type
);
std
::
shared_ptr
<
Node
>
make_zero
(
const
std
::
shared_ptr
<
const
TensorViewType
>&
tensor_view_type
)
{
std
::
shared_ptr
<
Node
>
zero
=
std
::
make_shared
<
op
::
Float32ScalarConstant
>
(
0.0
);
std
::
shared_ptr
<
const
TensorViewType
>
zero_type
=
std
::
dynamic_pointer_cast
<
const
TensorViewType
>
(
zero
->
get_value_type
());
if
(
zero_type
->
get_element_type
()
!=
tensor_view_type
->
get_element_type
())
{
zero
=
std
::
make_shared
<
op
::
Convert
>
(
zero
,
tensor_view_type
->
get_element_type
());
}
const
Shape
&
shape
=
tensor_view_type
->
get_shape
();
if
(
shape
.
size
()
>
0
)
{
AxisSet
axes
;
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
i
++
)
{
axes
.
insert
(
i
);
}
zero
=
std
::
make_shared
<
op
::
Broadcast
>
(
zero
,
shape
,
axes
);
}
return
zero
;
}
std
::
shared_ptr
<
Node
>
make_zero
(
const
std
::
shared_ptr
<
const
TupleType
>&
tuple_type
)
{
std
::
vector
<
std
::
shared_ptr
<
Node
>>
elements
;
for
(
auto
&
value_type
:
tuple_type
->
get_element_types
())
{
elements
.
push_back
(
make_zero
(
value_type
));
}
return
std
::
make_shared
<
op
::
Tuple
>
(
elements
);
}
std
::
shared_ptr
<
Node
>
make_zero
(
const
std
::
shared_ptr
<
const
ValueType
>&
value_type
)
{
std
::
shared_ptr
<
const
TensorViewType
>
tensor_view_type
=
std
::
dynamic_pointer_cast
<
const
TensorViewType
>
(
value_type
);
if
(
nullptr
!=
tensor_view_type
)
{
return
(
make_zero
(
tensor_view_type
));
}
std
::
shared_ptr
<
const
TupleType
>
tuple_type
=
std
::
dynamic_pointer_cast
<
const
TupleType
>
(
value_type
);
if
(
nullptr
!=
tuple_type
)
{
return
make_zero
(
tuple_type
);
}
// Should be impossible
throw
ngraph_error
(
"Unknown value type"
);
}
autodiff
::
Adjoints
::
Adjoints
(
const
std
::
shared_ptr
<
Node
>&
y
,
const
std
::
shared_ptr
<
Node
>&
c
)
{
// Pass 1 determines which nodes contribute to y as well as setting up a reverse
// topological sort.
// Number of nodes that use the a node's value
std
::
unordered_map
<
std
::
shared_ptr
<
Node
>
,
size_t
>
parent_counts
;
// Nodes that have been processed
std
::
unordered_set
<
std
::
shared_ptr
<
Node
>>
visited_nodes
;
// Nodes we should check
std
::
list
<
std
::
shared_ptr
<
Node
>>
nodes_to_check
;
nodes_to_check
.
push_front
(
y
);
while
(
nodes_to_check
.
size
()
>
0
)
{
auto
node
=
nodes_to_check
.
front
();
nodes_to_check
.
pop_front
();
if
(
visited_nodes
.
count
(
node
)
!=
0
)
{
continue
;
}
for
(
auto
arg
:
node
->
get_arguments
())
{
auto
count_it
=
parent_counts
.
find
(
arg
);
if
(
count_it
==
parent_counts
.
end
())
{
parent_counts
[
arg
]
=
1
;
nodes_to_check
.
push_front
(
arg
);
}
else
{
parent_counts
[
arg
]
++
;
}
}
visited_nodes
.
insert
(
node
);
}
// Second pass visits the nodes so that all users of a node's value are visited
// before a node is visited.
m_adjoint_map
[
y
.
get
()]
=
c
;
nodes_to_check
.
push_front
(
y
);
while
(
nodes_to_check
.
size
()
>
0
)
{
auto
node
=
nodes_to_check
.
front
();
nodes_to_check
.
pop_front
();
// Look for nodes that will be available when this node is done
for
(
auto
arg
:
node
->
get_arguments
())
{
auto
count_it
=
parent_counts
.
find
(
arg
);
count_it
->
second
--
;
if
(
0
==
count_it
->
second
)
{
nodes_to_check
.
push_front
(
arg
);
}
}
node
->
generate_adjoints
(
*
this
,
m_adjoint_map
.
at
(
node
.
get
()));
}
}
std
::
shared_ptr
<
Node
>
autodiff
::
Adjoints
::
get
(
const
std
::
shared_ptr
<
Node
>&
x
)
{
auto
adjoint_it
=
m_adjoint_map
.
find
(
x
.
get
());
if
(
m_adjoint_map
.
end
()
==
adjoint_it
)
{
auto
result
=
make_zero
(
x
->
get_value_type
());
adjoint_it
=
m_adjoint_map
.
insert
(
std
::
make_tuple
(
x
.
get
(),
result
)).
first
;
}
return
adjoint_it
->
second
;
}
void
autodiff
::
Adjoints
::
add_delta
(
const
std
::
shared_ptr
<
Node
>&
x
,
const
std
::
shared_ptr
<
Node
>&
delta
)
{
assert
(
*
x
->
get_value_type
()
==
*
delta
->
get_value_type
());
auto
adjoint_it
=
m_adjoint_map
.
find
(
x
.
get
());
if
(
m_adjoint_map
.
end
()
==
adjoint_it
)
{
m_adjoint_map
.
insert
(
std
::
make_tuple
(
x
.
get
(),
delta
));
}
else
{
m_adjoint_map
.
insert
(
std
::
make_tuple
(
x
.
get
(),
std
::
make_shared
<
op
::
Add
>
(
adjoint_it
->
second
,
delta
)));
}
}
src/ngraph/autodiff/adjoints.hpp
0 → 100644
View file @
bc3c70df
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <unordered_map>
namespace
ngraph
{
class
Node
;
namespace
autodiff
{
class
Adjoints
{
public
:
/// @brief (dy/dx)(c) for all x used to compute y
///
/// @param y The dependent value
/// @param c An expression for where to evaluate the derivatives
Adjoints
(
const
std
::
shared_ptr
<
Node
>&
y
,
const
std
::
shared_ptr
<
Node
>&
c
);
Adjoints
(
const
Adjoints
&
adjoints
)
=
default
;
Adjoints
&
operator
=
(
const
Adjoints
&
adjoints
)
=
default
;
Adjoints
()
=
default
;
/// @brief (dy/dx)(c)
///
/// @param x The node whose adjoint is desired.
std
::
shared_ptr
<
Node
>
get
(
const
std
::
shared_ptr
<
Node
>&
x
);
/// @brief Add a backprop contribution to x's adjoint
///
/// @param x The adjoint node
/// @param delta A backprop contribution
void
add_delta
(
const
std
::
shared_ptr
<
Node
>&
x
,
const
std
::
shared_ptr
<
Node
>&
delta
);
protected
:
std
::
unordered_map
<
Node
*
,
std
::
shared_ptr
<
Node
>>
m_adjoint_map
;
};
}
}
src/ngraph/function.cpp
View file @
bc3c70df
...
...
@@ -23,7 +23,7 @@ using namespace ngraph;
atomic
<
size_t
>
Function
::
m_next_instance_id
(
0
);
Function
::
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
shared_ptr
<
ValueType
>&
result_type
,
const
std
::
shared_ptr
<
const
ValueType
>&
result_type
,
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
,
const
std
::
string
&
name
)
:
m_result
(
result
)
...
...
src/ngraph/function.hpp
View file @
bc3c70df
...
...
@@ -35,7 +35,7 @@ namespace ngraph
{
public
:
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
shared_ptr
<
ValueType
>&
result_type
,
const
std
::
shared_ptr
<
const
ValueType
>&
result_type
,
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
parameters
,
const
std
::
string
&
name
=
""
);
...
...
@@ -44,7 +44,7 @@ namespace ngraph
{
return
m_parameters
;
}
const
std
::
shared_ptr
<
ValueType
>
get_result_type
()
const
{
return
m_result_type
;
}
const
std
::
shared_ptr
<
const
ValueType
>
get_result_type
()
const
{
return
m_result_type
;
}
std
::
string
get_name
()
const
;
void
set_name
(
const
std
::
string
&
name
);
std
::
list
<
Node
*>&
get_ops
();
...
...
@@ -60,7 +60,7 @@ namespace ngraph
std
::
shared_ptr
<
Node
>
m_result
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
op
::
Parameter
>>
m_parameters
;
std
::
string
m_name
;
std
::
shared_ptr
<
ValueType
>
m_result_type
;
std
::
shared_ptr
<
const
ValueType
>
m_result_type
;
bool
m_ordered_ops_valid
;
std
::
list
<
Node
*>
m_ordered_ops
;
std
::
list
<
Node
*>
m_ops
;
...
...
src/ngraph/node.cpp
View file @
bc3c70df
...
...
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/ngraph.hpp"
using
namespace
std
;
...
...
@@ -32,6 +33,16 @@ Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType>
}
}
Node
::
Node
()
:
Node
({},
nullptr
)
{
}
Node
::
Node
(
std
::
shared_ptr
<
ValueType
>
value_type
)
:
Node
({},
value_type
)
{
}
Node
::~
Node
()
{
}
...
...
@@ -51,6 +62,24 @@ void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type)
}
}
std
::
shared_ptr
<
const
ValueType
>
Node
::
get_value_type
()
{
if
(
nullptr
==
m_value_type
)
{
propagate_types
();
}
return
m_value_type
;
}
const
std
::
shared_ptr
<
const
ValueType
>
Node
::
get_value_type
()
const
{
if
(
nullptr
==
m_value_type
)
{
const_cast
<
Node
*>
(
this
)
->
propagate_types
();
}
return
m_value_type
;
}
void
Node
::
assign_tensors
()
{
vector
<
std
::
shared_ptr
<
const
TensorViewType
>>
tensor_view_types
;
...
...
@@ -130,6 +159,20 @@ void Node::set_name(const string& name)
}
}
std
::
shared_ptr
<
Node
>
Node
::
backwards_derivative
(
const
std
::
shared_ptr
<
Node
>&
x
,
const
std
::
shared_ptr
<
Node
>&
c
)
{
auto
adjoints_it
=
m_adjoint_map
.
find
(
c
.
get
());
if
(
adjoints_it
==
m_adjoint_map
.
end
())
{
adjoints_it
=
m_adjoint_map
.
insert
(
std
::
make_tuple
(
c
.
get
(),
autodiff
::
Adjoints
(
shared_from_this
(),
c
)))
.
first
;
}
return
adjoints_it
->
second
.
get
(
x
);
}
namespace
ngraph
{
ostream
&
operator
<<
(
ostream
&
out
,
const
Node
&
node
)
...
...
src/ngraph/node.hpp
View file @
bc3c70df
...
...
@@ -15,13 +15,16 @@
#pragma once
#include <atomic>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <iostream>
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/common.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
...
...
@@ -35,20 +38,20 @@ namespace ngraph
/// view or a (possibly empty) tuple of values.
class
Node
:
public
std
::
enable_shared_from_this
<
Node
>
{
friend
class
autodiff
::
Adjoints
;
protected
:
Node
(
const
Nodes
&
arguments
,
std
::
shared_ptr
<
ValueType
>
value_type
=
nullptr
);
Node
()
:
Node
({},
nullptr
)
{
}
Node
()
;
Node
(
std
::
shared_ptr
<
ValueType
>
value_type
);
virtual
~
Node
();
Node
(
std
::
shared_ptr
<
ValueType
>
value_type
)
:
Node
({},
value_type
)
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
std
::
shared_ptr
<
Node
>&
delta
)
{
}
virtual
~
Node
();
public
:
/// The class name, must not contain spaces
virtual
std
::
string
description
()
const
=
0
;
...
...
@@ -76,8 +79,8 @@ namespace ngraph
return
typeid
(
*
this
)
==
typeid
(
*
n
);
}
std
::
shared_ptr
<
const
ValueType
>
get_value_type
()
{
return
m_value_type
;
}
const
std
::
shared_ptr
<
const
ValueType
>
get_value_type
()
const
{
return
m_value_type
;
}
std
::
shared_ptr
<
const
ValueType
>
get_value_type
()
;
const
std
::
shared_ptr
<
const
ValueType
>
get_value_type
()
const
;
void
set_value_type
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
{
m_value_type
=
std
::
make_shared
<
TensorViewType
>
(
element_type
,
shape
);
...
...
@@ -109,6 +112,9 @@ namespace ngraph
std
::
unordered_set
<
descriptor
::
Tensor
*>
liveness_new_list
;
std
::
unordered_set
<
descriptor
::
Tensor
*>
liveness_free_list
;
std
::
shared_ptr
<
Node
>
backwards_derivative
(
const
std
::
shared_ptr
<
Node
>&
x
,
const
std
::
shared_ptr
<
Node
>&
c
);
protected
:
Nodes
m_arguments
;
std
::
shared_ptr
<
const
ValueType
>
m_value_type
;
...
...
@@ -119,5 +125,6 @@ namespace ngraph
std
::
deque
<
descriptor
::
Input
>
m_inputs
;
std
::
deque
<
descriptor
::
Output
>
m_outputs
;
bool
m_is_output
;
std
::
unordered_map
<
Node
*
,
autodiff
::
Adjoints
>
m_adjoint_map
;
};
}
src/ngraph/ops/add.cpp
0 → 100644
View file @
bc3c70df
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ops/add.hpp"
void
ngraph
::
op
::
Add
::
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
std
::
shared_ptr
<
Node
>&
delta
)
{
auto
x
=
m_arguments
[
0
];
auto
y
=
m_arguments
[
1
];
adjoints
.
add_delta
(
x
,
delta
);
adjoints
.
add_delta
(
y
,
delta
);
}
src/ngraph/ops/add.hpp
View file @
bc3c70df
...
...
@@ -28,6 +28,9 @@ namespace ngraph
{
}
virtual
std
::
string
description
()
const
override
{
return
"Add"
;
}
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
std
::
shared_ptr
<
Node
>&
delta
)
override
;
};
}
...
...
src/ngraph/ops/multiply.cpp
0 → 100644
View file @
bc3c70df
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ops/multiply.hpp"
void
ngraph
::
op
::
Multiply
::
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
std
::
shared_ptr
<
Node
>&
delta
)
{
auto
x
=
m_arguments
[
0
];
auto
y
=
m_arguments
[
1
];
adjoints
.
add_delta
(
x
,
delta
*
y
);
adjoints
.
add_delta
(
y
,
x
*
delta
);
}
src/ngraph/ops/multiply.hpp
View file @
bc3c70df
...
...
@@ -29,12 +29,15 @@ namespace ngraph
}
virtual
std
::
string
description
()
const
override
{
return
"Multiply"
;
}
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
std
::
shared_ptr
<
Node
>&
delta
)
override
;
};
}
};
}
inline
std
::
shared_ptr
<
ngraph
::
Node
>
operator
*
(
const
std
::
shared_ptr
<
ngraph
::
Node
>
arg0
,
const
std
::
shared_ptr
<
ngraph
::
Node
>
arg1
)
{
return
std
::
make_shared
<
ngraph
::
op
::
Multiply
>
(
arg0
,
arg1
);
}
inline
std
::
shared_ptr
<
ngraph
::
Node
>
operator
*
(
const
std
::
shared_ptr
<
ngraph
::
Node
>
arg0
,
const
std
::
shared_ptr
<
ngraph
::
Node
>
arg1
)
{
return
std
::
make_shared
<
ngraph
::
op
::
Multiply
>
(
arg0
,
arg1
);
}
src/ngraph/ops/parameter.hpp
View file @
bc3c70df
...
...
@@ -36,6 +36,11 @@ namespace ngraph
// It is an error to try to associate a parameter with more than one function.
void
assign_function
(
Function
*
function
,
size_t
index
);
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
std
::
shared_ptr
<
Node
>&
delta
)
override
{
}
public
:
Parameter
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
=
nullptr
);
Parameter
(
const
ngraph
::
element
::
Type
&
element_type
,
const
Shape
&
shape
);
...
...
test/CMakeLists.txt
View file @
bc3c70df
...
...
@@ -22,6 +22,7 @@ include_directories(
)
set
(
SRC
autodiff.cpp
build_graph.cpp
eigen.cpp
element_type.cpp
...
...
test/autodiff.cpp
0 → 100644
View file @
bc3c70df
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using
namespace
std
;
using
namespace
ngraph
;
TEST
(
backwards
,
parameter
)
{
auto
shape
=
Shape
{
2
,
3
};
auto
X0
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Y
=
X0
;
auto
C
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
DYDX0
=
Y
->
backwards_derivative
(
X0
,
C
);
ASSERT_EQ
(
DYDX0
,
C
);
}
TEST
(
backwards
,
add
)
{
auto
shape
=
Shape
{
2
,
3
};
auto
X0
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
X1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Y
=
X0
+
X1
;
auto
C
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
DYDX0
=
Y
->
backwards_derivative
(
X0
,
C
);
auto
DYDX1
=
Y
->
backwards_derivative
(
X1
,
C
);
ASSERT_EQ
(
DYDX0
,
C
);
ASSERT_EQ
(
DYDX1
,
C
);
}
// Returns (dy/(dXs))(C, Xs)
shared_ptr
<
Function
>
derivative
(
const
std
::
shared_ptr
<
Node
>&
Y
,
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>
Xs
)
{
auto
Y_tv_type
=
dynamic_pointer_cast
<
const
TensorViewType
>
(
Y
->
get_value_type
());
auto
C
=
make_shared
<
op
::
Parameter
>
(
Y_tv_type
->
get_element_type
(),
Y_tv_type
->
get_shape
());
std
::
vector
<
std
::
shared_ptr
<
Node
>>
dYdXs
(
Xs
.
size
());
transform
(
Xs
.
begin
(),
Xs
.
end
(),
dYdXs
.
begin
(),
[
C
,
Y
](
const
std
::
shared_ptr
<
Node
>&
X
)
{
return
Y
->
backwards_derivative
(
X
,
C
);
});
auto
result
=
make_shared
<
op
::
Tuple
>
(
dYdXs
);
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>
args
;
args
.
push_back
(
C
);
args
.
insert
(
args
.
end
(),
Xs
.
begin
(),
Xs
.
end
());
return
make_shared
<
Function
>
(
result
,
result
->
get_value_type
(),
args
);
}
TEST
(
backwards
,
multiply
)
{
auto
shape
=
Shape
{
2
,
3
};
auto
X0
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
X1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Y
=
X0
*
X1
;
auto
f
=
derivative
(
Y
,
{
X0
,
X1
});
auto
manager
=
runtime
::
Manager
::
get
(
"NGVM"
);
auto
external
=
manager
->
compile
(
f
);
auto
backend
=
manager
->
allocate_backend
();
auto
cf
=
backend
->
make_call_frame
(
external
);
auto
x0
=
backend
->
make_parameterized_tensor_view
<
element
::
Float32
>
(
shape
);
*
x0
=
vector
<
float
>
{
1
,
3
,
5
,
7
,
9
,
11
};
auto
x1
=
backend
->
make_parameterized_tensor_view
<
element
::
Float32
>
(
shape
);
*
x1
=
vector
<
float
>
{
0
,
2
,
4
,
6
,
8
,
10
};
auto
c
=
backend
->
make_parameterized_tensor_view
<
element
::
Float32
>
(
shape
);
*
c
=
vector
<
float
>
{
0
,
0
,
0
,
0
,
0
,
0
};
auto
dx0
=
backend
->
make_parameterized_tensor_view
<
element
::
Float32
>
(
shape
);
auto
dx1
=
backend
->
make_parameterized_tensor_view
<
element
::
Float32
>
(
shape
);
auto
dx
=
backend
->
make_tuple
({
dx0
,
dx1
});
size_t
n
=
x0
->
get_vector
().
size
();
vector
<
float
>
dx0_correct
(
n
);
vector
<
float
>
dx1_correct
(
n
);
for
(
size_t
i
=
0
;
i
<
n
;
i
++
)
{
c
->
get_vector
().
assign
(
n
,
0
);
c
->
get_vector
()[
i
]
=
1
;
(
*
cf
)({
c
,
x0
,
x1
},
{
dx
});
dx0_correct
.
assign
(
n
,
0
);
dx1_correct
.
assign
(
n
,
0
);
dx0_correct
[
i
]
=
x1
->
get_vector
()[
i
];
dx1_correct
[
i
]
=
x0
->
get_vector
()[
i
];
ASSERT_EQ
(
dx0
->
get_vector
(),
dx0_correct
);
ASSERT_EQ
(
dx1
->
get_vector
(),
dx1_correct
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment