Commit bc3c70df authored by Scott Cyphers's avatar Scott Cyphers

Basic autodiff for + and *

parent 4bec2307
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# limitations under the License. # limitations under the License.
set (SRC set (SRC
autodiff/adjoints.cpp
descriptor/input.cpp descriptor/input.cpp
descriptor/layout/dense_tensor_view_layout.cpp descriptor/layout/dense_tensor_view_layout.cpp
descriptor/layout/tensor_view_layout.cpp descriptor/layout/tensor_view_layout.cpp
...@@ -23,6 +24,7 @@ set (SRC ...@@ -23,6 +24,7 @@ set (SRC
function.cpp function.cpp
log.cpp log.cpp
node.cpp node.cpp
ops/add.cpp
ops/binary_elementwise_arithmetic.cpp ops/binary_elementwise_arithmetic.cpp
ops/binary_elementwise_builtin.cpp ops/binary_elementwise_builtin.cpp
ops/binary_elementwise_comparison.cpp ops/binary_elementwise_comparison.cpp
...@@ -33,6 +35,7 @@ set (SRC ...@@ -33,6 +35,7 @@ set (SRC
ops/dot.cpp ops/dot.cpp
ops/function_call.cpp ops/function_call.cpp
ops/get_tuple_element.cpp ops/get_tuple_element.cpp
ops/multiply.cpp
ops/op.cpp ops/op.cpp
ops/parameter.cpp ops/parameter.cpp
ops/reduce.cpp ops/reduce.cpp
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cassert>
#include <list>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/types/type.hpp"
using namespace ngraph;
/// @brief Make a zero matching a value type.
std::shared_ptr<Node> make_zero(const std::shared_ptr<const ValueType>& value_type);
std::shared_ptr<Node> make_zero(const std::shared_ptr<const TensorViewType>& tensor_view_type)
{
std::shared_ptr<Node> zero = std::make_shared<op::Float32ScalarConstant>(0.0);
std::shared_ptr<const TensorViewType> zero_type =
std::dynamic_pointer_cast<const TensorViewType>(zero->get_value_type());
if (zero_type->get_element_type() != tensor_view_type->get_element_type())
{
zero = std::make_shared<op::Convert>(zero, tensor_view_type->get_element_type());
}
const Shape& shape = tensor_view_type->get_shape();
if (shape.size() > 0)
{
AxisSet axes;
for (size_t i = 0; i < shape.size(); i++)
{
axes.insert(i);
}
zero = std::make_shared<op::Broadcast>(zero, shape, axes);
}
return zero;
}
std::shared_ptr<Node> make_zero(const std::shared_ptr<const TupleType>& tuple_type)
{
std::vector<std::shared_ptr<Node>> elements;
for (auto& value_type : tuple_type->get_element_types())
{
elements.push_back(make_zero(value_type));
}
return std::make_shared<op::Tuple>(elements);
}
std::shared_ptr<Node> make_zero(const std::shared_ptr<const ValueType>& value_type)
{
std::shared_ptr<const TensorViewType> tensor_view_type =
std::dynamic_pointer_cast<const TensorViewType>(value_type);
if (nullptr != tensor_view_type)
{
return (make_zero(tensor_view_type));
}
std::shared_ptr<const TupleType> tuple_type =
std::dynamic_pointer_cast<const TupleType>(value_type);
if (nullptr != tuple_type)
{
return make_zero(tuple_type);
}
// Should be impossible
throw ngraph_error("Unknown value type");
}
autodiff::Adjoints::Adjoints(const std::shared_ptr<Node>& y, const std::shared_ptr<Node>& c)
{
// Pass 1 determines which nodes contribute to y as well as setting up a reverse
// topological sort.
// Number of nodes that use the a node's value
std::unordered_map<std::shared_ptr<Node>, size_t> parent_counts;
// Nodes that have been processed
std::unordered_set<std::shared_ptr<Node>> visited_nodes;
// Nodes we should check
std::list<std::shared_ptr<Node>> nodes_to_check;
nodes_to_check.push_front(y);
while (nodes_to_check.size() > 0)
{
auto node = nodes_to_check.front();
nodes_to_check.pop_front();
if (visited_nodes.count(node) != 0)
{
continue;
}
for (auto arg : node->get_arguments())
{
auto count_it = parent_counts.find(arg);
if (count_it == parent_counts.end())
{
parent_counts[arg] = 1;
nodes_to_check.push_front(arg);
}
else
{
parent_counts[arg]++;
}
}
visited_nodes.insert(node);
}
// Second pass visits the nodes so that all users of a node's value are visited
// before a node is visited.
m_adjoint_map[y.get()] = c;
nodes_to_check.push_front(y);
while (nodes_to_check.size() > 0)
{
auto node = nodes_to_check.front();
nodes_to_check.pop_front();
// Look for nodes that will be available when this node is done
for (auto arg : node->get_arguments())
{
auto count_it = parent_counts.find(arg);
count_it->second--;
if (0 == count_it->second)
{
nodes_to_check.push_front(arg);
}
}
node->generate_adjoints(*this, m_adjoint_map.at(node.get()));
}
}
std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x)
{
auto adjoint_it = m_adjoint_map.find(x.get());
if (m_adjoint_map.end() == adjoint_it)
{
auto result = make_zero(x->get_value_type());
adjoint_it = m_adjoint_map.insert(std::make_tuple(x.get(), result)).first;
}
return adjoint_it->second;
}
void autodiff::Adjoints::add_delta(const std::shared_ptr<Node>& x,
const std::shared_ptr<Node>& delta)
{
assert(*x->get_value_type() == *delta->get_value_type());
auto adjoint_it = m_adjoint_map.find(x.get());
if (m_adjoint_map.end() == adjoint_it)
{
m_adjoint_map.insert(std::make_tuple(x.get(), delta));
}
else
{
m_adjoint_map.insert(
std::make_tuple(x.get(), std::make_shared<op::Add>(adjoint_it->second, delta)));
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <unordered_map>
namespace ngraph
{
class Node;
namespace autodiff
{
class Adjoints
{
public:
/// @brief (dy/dx)(c) for all x used to compute y
///
/// @param y The dependent value
/// @param c An expression for where to evaluate the derivatives
Adjoints(const std::shared_ptr<Node>& y, const std::shared_ptr<Node>& c);
Adjoints(const Adjoints& adjoints) = default;
Adjoints& operator=(const Adjoints& adjoints) = default;
Adjoints() = default;
/// @brief (dy/dx)(c)
///
/// @param x The node whose adjoint is desired.
std::shared_ptr<Node> get(const std::shared_ptr<Node>& x);
/// @brief Add a backprop contribution to x's adjoint
///
/// @param x The adjoint node
/// @param delta A backprop contribution
void add_delta(const std::shared_ptr<Node>& x, const std::shared_ptr<Node>& delta);
protected:
std::unordered_map<Node*, std::shared_ptr<Node>> m_adjoint_map;
};
}
}
...@@ -23,7 +23,7 @@ using namespace ngraph; ...@@ -23,7 +23,7 @@ using namespace ngraph;
atomic<size_t> Function::m_next_instance_id(0); atomic<size_t> Function::m_next_instance_id(0);
Function::Function(const std::shared_ptr<Node>& result, Function::Function(const std::shared_ptr<Node>& result,
const std::shared_ptr<ValueType>& result_type, const std::shared_ptr<const ValueType>& result_type,
const std::vector<std::shared_ptr<op::Parameter>>& parameters, const std::vector<std::shared_ptr<op::Parameter>>& parameters,
const std::string& name) const std::string& name)
: m_result(result) : m_result(result)
......
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
{ {
public: public:
Function(const std::shared_ptr<Node>& result, Function(const std::shared_ptr<Node>& result,
const std::shared_ptr<ValueType>& result_type, const std::shared_ptr<const ValueType>& result_type,
const std::vector<std::shared_ptr<op::Parameter>>& parameters, const std::vector<std::shared_ptr<op::Parameter>>& parameters,
const std::string& name = ""); const std::string& name = "");
...@@ -44,7 +44,7 @@ namespace ngraph ...@@ -44,7 +44,7 @@ namespace ngraph
{ {
return m_parameters; return m_parameters;
} }
const std::shared_ptr<ValueType> get_result_type() const { return m_result_type; } const std::shared_ptr<const ValueType> get_result_type() const { return m_result_type; }
std::string get_name() const; std::string get_name() const;
void set_name(const std::string& name); void set_name(const std::string& name);
std::list<Node*>& get_ops(); std::list<Node*>& get_ops();
...@@ -60,7 +60,7 @@ namespace ngraph ...@@ -60,7 +60,7 @@ namespace ngraph
std::shared_ptr<Node> m_result; std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::op::Parameter>> m_parameters; std::vector<std::shared_ptr<ngraph::op::Parameter>> m_parameters;
std::string m_name; std::string m_name;
std::shared_ptr<ValueType> m_result_type; std::shared_ptr<const ValueType> m_result_type;
bool m_ordered_ops_valid; bool m_ordered_ops_valid;
std::list<Node*> m_ordered_ops; std::list<Node*> m_ordered_ops;
std::list<Node*> m_ops; std::list<Node*> m_ops;
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
...@@ -32,6 +33,16 @@ Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType> ...@@ -32,6 +33,16 @@ Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType>
} }
} }
Node::Node()
: Node({}, nullptr)
{
}
Node::Node(std::shared_ptr<ValueType> value_type)
: Node({}, value_type)
{
}
Node::~Node() Node::~Node()
{ {
} }
...@@ -51,6 +62,24 @@ void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type) ...@@ -51,6 +62,24 @@ void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type)
} }
} }
std::shared_ptr<const ValueType> Node::get_value_type()
{
if (nullptr == m_value_type)
{
propagate_types();
}
return m_value_type;
}
const std::shared_ptr<const ValueType> Node::get_value_type() const
{
if (nullptr == m_value_type)
{
const_cast<Node*>(this)->propagate_types();
}
return m_value_type;
}
void Node::assign_tensors() void Node::assign_tensors()
{ {
vector<std::shared_ptr<const TensorViewType>> tensor_view_types; vector<std::shared_ptr<const TensorViewType>> tensor_view_types;
...@@ -130,6 +159,20 @@ void Node::set_name(const string& name) ...@@ -130,6 +159,20 @@ void Node::set_name(const string& name)
} }
} }
std::shared_ptr<Node> Node::backwards_derivative(const std::shared_ptr<Node>& x,
const std::shared_ptr<Node>& c)
{
auto adjoints_it = m_adjoint_map.find(c.get());
if (adjoints_it == m_adjoint_map.end())
{
adjoints_it =
m_adjoint_map
.insert(std::make_tuple(c.get(), autodiff::Adjoints(shared_from_this(), c)))
.first;
}
return adjoints_it->second.get(x);
}
namespace ngraph namespace ngraph
{ {
ostream& operator<<(ostream& out, const Node& node) ostream& operator<<(ostream& out, const Node& node)
......
...@@ -15,13 +15,16 @@ ...@@ -15,13 +15,16 @@
#pragma once #pragma once
#include <atomic> #include <atomic>
#include <memory>
#include <set> #include <set>
#include <string> #include <string>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include <iostream> #include <iostream>
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/common.hpp" #include "ngraph/common.hpp"
#include "ngraph/descriptor/input.hpp" #include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
...@@ -35,20 +38,20 @@ namespace ngraph ...@@ -35,20 +38,20 @@ namespace ngraph
/// view or a (possibly empty) tuple of values. /// view or a (possibly empty) tuple of values.
class Node : public std::enable_shared_from_this<Node> class Node : public std::enable_shared_from_this<Node>
{ {
friend class autodiff::Adjoints;
protected: protected:
Node(const Nodes& arguments, std::shared_ptr<ValueType> value_type = nullptr); Node(const Nodes& arguments, std::shared_ptr<ValueType> value_type = nullptr);
Node() Node();
: Node({}, nullptr) Node(std::shared_ptr<ValueType> value_type);
{
} virtual ~Node();
Node(std::shared_ptr<ValueType> value_type) virtual void generate_adjoints(autodiff::Adjoints& adjoints,
: Node({}, value_type) const std::shared_ptr<Node>& delta)
{ {
} }
virtual ~Node();
public: public:
/// The class name, must not contain spaces /// The class name, must not contain spaces
virtual std::string description() const = 0; virtual std::string description() const = 0;
...@@ -76,8 +79,8 @@ namespace ngraph ...@@ -76,8 +79,8 @@ namespace ngraph
return typeid(*this) == typeid(*n); return typeid(*this) == typeid(*n);
} }
std::shared_ptr<const ValueType> get_value_type() { return m_value_type; } std::shared_ptr<const ValueType> get_value_type();
const std::shared_ptr<const ValueType> get_value_type() const { return m_value_type; } const std::shared_ptr<const ValueType> get_value_type() const;
void set_value_type(const element::Type& element_type, const Shape& shape) void set_value_type(const element::Type& element_type, const Shape& shape)
{ {
m_value_type = std::make_shared<TensorViewType>(element_type, shape); m_value_type = std::make_shared<TensorViewType>(element_type, shape);
...@@ -109,6 +112,9 @@ namespace ngraph ...@@ -109,6 +112,9 @@ namespace ngraph
std::unordered_set<descriptor::Tensor*> liveness_new_list; std::unordered_set<descriptor::Tensor*> liveness_new_list;
std::unordered_set<descriptor::Tensor*> liveness_free_list; std::unordered_set<descriptor::Tensor*> liveness_free_list;
std::shared_ptr<Node> backwards_derivative(const std::shared_ptr<Node>& x,
const std::shared_ptr<Node>& c);
protected: protected:
Nodes m_arguments; Nodes m_arguments;
std::shared_ptr<const ValueType> m_value_type; std::shared_ptr<const ValueType> m_value_type;
...@@ -119,5 +125,6 @@ namespace ngraph ...@@ -119,5 +125,6 @@ namespace ngraph
std::deque<descriptor::Input> m_inputs; std::deque<descriptor::Input> m_inputs;
std::deque<descriptor::Output> m_outputs; std::deque<descriptor::Output> m_outputs;
bool m_is_output; bool m_is_output;
std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map;
}; };
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ops/add.hpp"
void ngraph::op::Add::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
{
auto x = m_arguments[0];
auto y = m_arguments[1];
adjoints.add_delta(x, delta);
adjoints.add_delta(y, delta);
}
...@@ -28,6 +28,9 @@ namespace ngraph ...@@ -28,6 +28,9 @@ namespace ngraph
{ {
} }
virtual std::string description() const override { return "Add"; } virtual std::string description() const override { return "Add"; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
}; };
} }
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ops/multiply.hpp"
void ngraph::op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
{
auto x = m_arguments[0];
auto y = m_arguments[1];
adjoints.add_delta(x, delta * y);
adjoints.add_delta(y, x * delta);
}
...@@ -29,12 +29,15 @@ namespace ngraph ...@@ -29,12 +29,15 @@ namespace ngraph
} }
virtual std::string description() const override { return "Multiply"; } virtual std::string description() const override { return "Multiply"; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
}; };
} };
}
inline std::shared_ptr<ngraph::Node> operator*(const std::shared_ptr<ngraph::Node> arg0, inline std::shared_ptr<ngraph::Node> operator*(const std::shared_ptr<ngraph::Node> arg0,
const std::shared_ptr<ngraph::Node> arg1) const std::shared_ptr<ngraph::Node> arg1)
{ {
return std::make_shared<ngraph::op::Multiply>(arg0, arg1); return std::make_shared<ngraph::op::Multiply>(arg0, arg1);
}
} }
...@@ -36,6 +36,11 @@ namespace ngraph ...@@ -36,6 +36,11 @@ namespace ngraph
// It is an error to try to associate a parameter with more than one function. // It is an error to try to associate a parameter with more than one function.
void assign_function(Function* function, size_t index); void assign_function(Function* function, size_t index);
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override
{
}
public: public:
Parameter(const std::shared_ptr<ValueType>& value_type = nullptr); Parameter(const std::shared_ptr<ValueType>& value_type = nullptr);
Parameter(const ngraph::element::Type& element_type, const Shape& shape); Parameter(const ngraph::element::Type& element_type, const Shape& shape);
......
...@@ -22,6 +22,7 @@ include_directories( ...@@ -22,6 +22,7 @@ include_directories(
) )
set (SRC set (SRC
autodiff.cpp
build_graph.cpp build_graph.cpp
eigen.cpp eigen.cpp
element_type.cpp element_type.cpp
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
TEST(backwards, parameter)
{
auto shape = Shape{2, 3};
auto X0 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Y = X0;
auto C = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto DYDX0 = Y->backwards_derivative(X0, C);
ASSERT_EQ(DYDX0, C);
}
TEST(backwards, add)
{
auto shape = Shape{2, 3};
auto X0 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto X1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Y = X0 + X1;
auto C = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto DYDX0 = Y->backwards_derivative(X0, C);
auto DYDX1 = Y->backwards_derivative(X1, C);
ASSERT_EQ(DYDX0, C);
ASSERT_EQ(DYDX1, C);
}
// Returns (dy/(dXs))(C, Xs)
shared_ptr<Function> derivative(const std::shared_ptr<Node>& Y,
const std::vector<std::shared_ptr<op::Parameter>> Xs)
{
auto Y_tv_type = dynamic_pointer_cast<const TensorViewType>(Y->get_value_type());
auto C = make_shared<op::Parameter>(Y_tv_type->get_element_type(), Y_tv_type->get_shape());
std::vector<std::shared_ptr<Node>> dYdXs(Xs.size());
transform(Xs.begin(), Xs.end(), dYdXs.begin(), [C, Y](const std::shared_ptr<Node>& X) {
return Y->backwards_derivative(X, C);
});
auto result = make_shared<op::Tuple>(dYdXs);
std::vector<std::shared_ptr<op::Parameter>> args;
args.push_back(C);
args.insert(args.end(), Xs.begin(), Xs.end());
return make_shared<Function>(result, result->get_value_type(), args);
}
TEST(backwards, multiply)
{
auto shape = Shape{2, 3};
auto X0 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto X1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Y = X0 * X1;
auto f = derivative(Y, {X0, X1});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
auto x0 = backend->make_parameterized_tensor_view<element::Float32>(shape);
*x0 = vector<float>{1, 3, 5, 7, 9, 11};
auto x1 = backend->make_parameterized_tensor_view<element::Float32>(shape);
*x1 = vector<float>{0, 2, 4, 6, 8, 10};
auto c = backend->make_parameterized_tensor_view<element::Float32>(shape);
*c = vector<float>{0, 0, 0, 0, 0, 0};
auto dx0 = backend->make_parameterized_tensor_view<element::Float32>(shape);
auto dx1 = backend->make_parameterized_tensor_view<element::Float32>(shape);
auto dx = backend->make_tuple({dx0, dx1});
size_t n = x0->get_vector().size();
vector<float> dx0_correct(n);
vector<float> dx1_correct(n);
for (size_t i = 0; i < n; i++)
{
c->get_vector().assign(n, 0);
c->get_vector()[i] = 1;
(*cf)({c, x0, x1}, {dx});
dx0_correct.assign(n, 0);
dx1_correct.assign(n, 0);
dx0_correct[i] = x1->get_vector()[i];
dx1_correct[i] = x0->get_vector()[i];
ASSERT_EQ(dx0->get_vector(), dx0_correct);
ASSERT_EQ(dx1->get_vector(), dx1_correct);
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment