Commit bc3c70df authored by Scott Cyphers

Basic autodiff for + and *

parent 4bec2307
......@@ -12,6 +12,7 @@
# limitations under the License.
set (SRC
autodiff/adjoints.cpp
descriptor/input.cpp
descriptor/layout/dense_tensor_view_layout.cpp
descriptor/layout/tensor_view_layout.cpp
......@@ -23,6 +24,7 @@ set (SRC
function.cpp
log.cpp
node.cpp
ops/add.cpp
ops/binary_elementwise_arithmetic.cpp
ops/binary_elementwise_builtin.cpp
ops/binary_elementwise_comparison.cpp
......@@ -33,6 +35,7 @@ set (SRC
ops/dot.cpp
ops/function_call.cpp
ops/get_tuple_element.cpp
ops/multiply.cpp
ops/op.cpp
ops/parameter.cpp
ops/reduce.cpp
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#include <cassert>
#include <list>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/types/type.hpp"
using namespace ngraph;
/// @brief Make a zero matching a value type.
std::shared_ptr<Node> make_zero(const std::shared_ptr<const ValueType>& value_type);
/// @brief Make a zero node matching a tensor view type.
///
/// Starts from a Float32 scalar constant 0, wraps it in a Convert when the
/// requested element type differs from the constant's, and wraps it in a
/// Broadcast over every axis when the requested shape is not a scalar.
///
/// @param tensor_view_type The element type and shape the zero must have.
/// @return The node producing the zero value.
std::shared_ptr<Node> make_zero(const std::shared_ptr<const TensorViewType>& tensor_view_type)
{
    std::shared_ptr<Node> result = std::make_shared<op::Float32ScalarConstant>(0.0);

    // Only insert a conversion when the scalar's element type is not already the target's.
    const auto scalar_type =
        std::dynamic_pointer_cast<const TensorViewType>(result->get_value_type());
    if (scalar_type->get_element_type() != tensor_view_type->get_element_type())
    {
        result = std::make_shared<op::Convert>(result, tensor_view_type->get_element_type());
    }

    const Shape& target_shape = tensor_view_type->get_shape();
    const size_t rank = target_shape.size();
    if (rank != 0)
    {
        // The scalar contributes no axes of its own, so every result axis is a broadcast axis.
        AxisSet broadcast_axes;
        size_t axis = 0;
        while (axis < rank)
        {
            broadcast_axes.insert(axis);
            ++axis;
        }
        result = std::make_shared<op::Broadcast>(result, target_shape, broadcast_axes);
    }
    return result;
}
/// @brief Make a zero node matching a tuple type.
///
/// Builds one zero per element type of the tuple (via the ValueType overload
/// of make_zero) and packs them, in order, into an op::Tuple.
///
/// @param tuple_type The tuple type whose element types are zeroed.
/// @return The op::Tuple node holding the element zeros.
std::shared_ptr<Node> make_zero(const std::shared_ptr<const TupleType>& tuple_type)
{
    std::vector<std::shared_ptr<Node>> zeros;
    for (const auto& element_type : tuple_type->get_element_types())
    {
        std::shared_ptr<Node> element_zero = make_zero(element_type);
        zeros.push_back(element_zero);
    }
    std::shared_ptr<Node> packed = std::make_shared<op::Tuple>(zeros);
    return packed;
}
/// @brief Make a zero node matching an arbitrary value type.
///
/// Dispatches on the dynamic type of value_type: a TensorViewType goes to the
/// tensor-view overload, a TupleType goes to the tuple overload.
///
/// @param value_type The type the zero must match.
/// @return The node producing the zero value.
/// @throws ngraph_error if value_type is neither a TensorViewType nor a TupleType.
std::shared_ptr<Node> make_zero(const std::shared_ptr<const ValueType>& value_type)
{
    if (auto as_tensor_view = std::dynamic_pointer_cast<const TensorViewType>(value_type))
    {
        return make_zero(as_tensor_view);
    }

    if (auto as_tuple = std::dynamic_pointer_cast<const TupleType>(value_type))
    {
        return make_zero(as_tuple);
    }

    // Should be impossible
    throw ngraph_error("Unknown value type");
}
/// @brief Compute adjoints for the nodes reachable from y through arguments.
///
/// Pass 1 walks the graph from y and counts, for every reachable argument,
/// how many times it is used as an argument. Pass 2 seeds y's adjoint with c
/// and then calls generate_adjoints on a node only after every use of it has
/// been processed (its use count has dropped to zero).
///
/// @param y The dependent value
/// @param c The adjoint assigned to y before back-propagation starts
autodiff::Adjoints::Adjoints(const std::shared_ptr<Node>& y, const std::shared_ptr<Node>& c)
{
    // Number of times each node is used as an argument
    std::unordered_map<std::shared_ptr<Node>, size_t> use_counts;
    // Nodes whose arguments have already been counted
    std::unordered_set<std::shared_ptr<Node>> seen;
    // Work list, used as a stack in both passes
    std::list<std::shared_ptr<Node>> pending;

    // Pass 1: count uses of every node that contributes to y.
    pending.push_front(y);
    while (!pending.empty())
    {
        const auto current = pending.front();
        pending.pop_front();
        if (!seen.insert(current).second)
        {
            // Already handled; do not count its arguments a second time.
            continue;
        }
        for (const auto& arg : current->get_arguments())
        {
            const auto inserted = use_counts.emplace(arg, 0);
            if (inserted.second)
            {
                // First time this argument is encountered: schedule it for a visit.
                pending.push_front(arg);
            }
            ++inserted.first->second;
        }
    }

    // Pass 2: visit nodes so that all users of a node's value are visited
    // before the node itself is visited.
    m_adjoint_map[y.get()] = c;
    pending.push_front(y);
    while (!pending.empty())
    {
        const auto current = pending.front();
        pending.pop_front();
        // Release arguments whose last outstanding use is this node.
        for (const auto& arg : current->get_arguments())
        {
            size_t& remaining_uses = use_counts.find(arg)->second;
            --remaining_uses;
            if (remaining_uses == 0)
            {
                pending.push_front(arg);
            }
        }
        current->generate_adjoints(*this, m_adjoint_map.at(current.get()));
    }
}
std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x)
{
auto adjoint_it = m_adjoint_map.find(x.get());
if (m_adjoint_map.end() == adjoint_it)
{
auto result = make_zero(x->get_value_type());
adjoint_it = m_adjoint_map.insert(std::make_tuple(x.get(), result)).first;
}
return adjoint_it->second;
}
/// @brief Add a backprop contribution to x's adjoint.
///
/// The first contribution for x is stored as-is. Each later contribution
/// replaces the stored adjoint with an op::Add of the stored adjoint and delta.
///
/// @param x The node whose adjoint receives the contribution
/// @param delta A backprop contribution; must have the same value type as x
void autodiff::Adjoints::add_delta(const std::shared_ptr<Node>& x,
                                   const std::shared_ptr<Node>& delta)
{
    assert(*x->get_value_type() == *delta->get_value_type());
    auto adjoint_it = m_adjoint_map.find(x.get());
    if (m_adjoint_map.end() == adjoint_it)
    {
        m_adjoint_map.emplace(x.get(), delta);
    }
    else
    {
        // Assign through the iterator: unordered_map::insert does nothing when the
        // key is already present, which would silently drop this contribution.
        adjoint_it->second = std::make_shared<op::Add>(adjoint_it->second, delta);
    }
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <unordered_map>
namespace ngraph
{
// Forward declaration; only pointers and shared_ptrs to Node are used in this header.
class Node;
namespace autodiff
{
/// @brief Per-node adjoints (accumulated backprop contributions) for a dependent value y.
///
/// The adjoints are kept in m_adjoint_map, keyed by raw Node pointer.
class Adjoints
{
public:
/// @brief (dy/dx)(c) for all x used to compute y
///
/// @param y The dependent value
/// @param c The adjoint assigned to y itself; contributions for the nodes
///          feeding y are generated from it
Adjoints(const std::shared_ptr<Node>& y, const std::shared_ptr<Node>& c);
Adjoints(const Adjoints& adjoints) = default;
Adjoints& operator=(const Adjoints& adjoints) = default;
/// @brief Creates an empty set of adjoints.
Adjoints() = default;
/// @brief (dy/dx)(c)
///
/// If no adjoint has been recorded for x, a zero matching x's value type is
/// recorded and returned.
///
/// @param x The node whose adjoint is desired.
/// @return The adjoint node for x.
std::shared_ptr<Node> get(const std::shared_ptr<Node>& x);
/// @brief Add a backprop contribution to x's adjoint
///
/// @param x The node whose adjoint receives the contribution
/// @param delta A backprop contribution
void add_delta(const std::shared_ptr<Node>& x, const std::shared_ptr<Node>& delta);
protected:
// Maps a node (by raw pointer, so the key does not own the node) to its adjoint node.
std::unordered_map<Node*, std::shared_ptr<Node>> m_adjoint_map;
};
}
}
......@@ -23,7 +23,7 @@ using namespace ngraph;
atomic<size_t> Function::m_next_instance_id(0);
Function::Function(const std::shared_ptr<Node>& result,
const std::shared_ptr<ValueType>& result_type,
const std::shared_ptr<const ValueType>& result_type,
const std::vector<std::shared_ptr<op::Parameter>>& parameters,
const std::string& name)
: m_result(result)
......
......@@ -35,7 +35,7 @@ namespace ngraph
{
public:
Function(const std::shared_ptr<Node>& result,
const std::shared_ptr<ValueType>& result_type,
const std::shared_ptr<const ValueType>& result_type,
const std::vector<std::shared_ptr<op::Parameter>>& parameters,
const std::string& name = "");
......@@ -44,7 +44,7 @@ namespace ngraph
{
return m_parameters;
}
const std::shared_ptr<ValueType> get_result_type() const { return m_result_type; }
const std::shared_ptr<const ValueType> get_result_type() const { return m_result_type; }
std::string get_name() const;
void set_name(const std::string& name);
std::list<Node*>& get_ops();
......@@ -60,7 +60,7 @@ namespace ngraph
std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::op::Parameter>> m_parameters;
std::string m_name;
std::shared_ptr<ValueType> m_result_type;
std::shared_ptr<const ValueType> m_result_type;
bool m_ordered_ops_valid;
std::list<Node*> m_ordered_ops;
std::list<Node*> m_ops;
......
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/ngraph.hpp"
using namespace std;
......@@ -32,6 +33,16 @@ Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType>
}
}
// Default constructor: delegates to Node(arguments, value_type) with an empty
// argument list and a null value type.
Node::Node()
: Node({}, nullptr)
{
}
// Constructs a node with the given value type and no arguments by delegating
// to Node(arguments, value_type).
Node::Node(std::shared_ptr<ValueType> value_type)
: Node({}, value_type)
{
}
// Out-of-line destructor with an empty body; members are released by their own destructors.
Node::~Node()
{
}
......@@ -51,6 +62,24 @@ void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type)
}
}
/// @brief Returns this node's value type, running type propagation first when
///        no value type has been set yet.
///
/// NOTE(review): presumably propagate_types() assigns m_value_type; if it does
/// not, the returned pointer is still null. Confirm against propagate_types().
std::shared_ptr<const ValueType> Node::get_value_type()
{
    const bool type_missing = !m_value_type;
    if (type_missing)
    {
        propagate_types();
    }
    return m_value_type;
}
const std::shared_ptr<const ValueType> Node::get_value_type() const
{
if (nullptr == m_value_type)
{
const_cast<Node*>(this)->propagate_types();
}
return m_value_type;
}
void Node::assign_tensors()
{
vector<std::shared_ptr<const TensorViewType>> tensor_view_types;
......@@ -130,6 +159,20 @@ void Node::set_name(const string& name)
}
}
/// @brief Returns the adjoint of x for this node seeded with c.
///
/// The Adjoints computed for a given c are cached in m_adjoint_map under c's
/// raw pointer, so repeated calls with the same c reuse one Adjoints object.
///
/// @param x The node whose adjoint is requested
/// @param c The adjoint assigned to this node
/// @return The adjoint node for x
std::shared_ptr<Node> Node::backwards_derivative(const std::shared_ptr<Node>& x,
                                                 const std::shared_ptr<Node>& c)
{
    const auto cached = m_adjoint_map.find(c.get());
    if (cached != m_adjoint_map.end())
    {
        return cached->second.get(x);
    }

    // First request for this c: run back-propagation once and remember the result.
    const auto inserted =
        m_adjoint_map.emplace(c.get(), autodiff::Adjoints(shared_from_this(), c));
    return inserted.first->second.get(x);
}
namespace ngraph
{
ostream& operator<<(ostream& out, const Node& node)
......
......@@ -15,13 +15,16 @@
#pragma once
#include <atomic>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <iostream>
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/common.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
......@@ -35,20 +38,20 @@ namespace ngraph
/// view or a (possibly empty) tuple of values.
class Node : public std::enable_shared_from_this<Node>
{
friend class autodiff::Adjoints;
protected:
Node(const Nodes& arguments, std::shared_ptr<ValueType> value_type = nullptr);
Node()
: Node({}, nullptr)
{
}
Node();
Node(std::shared_ptr<ValueType> value_type);
virtual ~Node();
Node(std::shared_ptr<ValueType> value_type)
: Node({}, value_type)
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
{
}
virtual ~Node();
public:
/// The class name, must not contain spaces
virtual std::string description() const = 0;
......@@ -76,8 +79,8 @@ namespace ngraph
return typeid(*this) == typeid(*n);
}
std::shared_ptr<const ValueType> get_value_type() { return m_value_type; }
const std::shared_ptr<const ValueType> get_value_type() const { return m_value_type; }
std::shared_ptr<const ValueType> get_value_type();
const std::shared_ptr<const ValueType> get_value_type() const;
void set_value_type(const element::Type& element_type, const Shape& shape)
{
m_value_type = std::make_shared<TensorViewType>(element_type, shape);
......@@ -109,6 +112,9 @@ namespace ngraph
std::unordered_set<descriptor::Tensor*> liveness_new_list;
std::unordered_set<descriptor::Tensor*> liveness_free_list;
std::shared_ptr<Node> backwards_derivative(const std::shared_ptr<Node>& x,
const std::shared_ptr<Node>& c);
protected:
Nodes m_arguments;
std::shared_ptr<const ValueType> m_value_type;
......@@ -119,5 +125,6 @@ namespace ngraph
std::deque<descriptor::Input> m_inputs;
std::deque<descriptor::Output> m_outputs;
bool m_is_output;
std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map;
};
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#include "ngraph/ops/add.hpp"
/// @brief Back-propagates delta through an addition.
///
/// Both operands of a sum receive delta unchanged as their contribution.
///
/// @param adjoints Collector for the per-node contributions
/// @param delta The adjoint of this Add node
void ngraph::op::Add::generate_adjoints(autodiff::Adjoints& adjoints,
                                        const std::shared_ptr<Node>& delta)
{
    const auto& lhs = m_arguments[0];
    const auto& rhs = m_arguments[1];

    adjoints.add_delta(lhs, delta);
    adjoints.add_delta(rhs, delta);
}
......@@ -28,6 +28,9 @@ namespace ngraph
{
}
virtual std::string description() const override { return "Add"; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
};
}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#include "ngraph/ops/multiply.hpp"
/// @brief Back-propagates delta through an elementwise multiplication.
///
/// Each operand's contribution is delta multiplied by the other operand.
///
/// @param adjoints Collector for the per-node contributions
/// @param delta The adjoint of this Multiply node
void ngraph::op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints,
                                             const std::shared_ptr<Node>& delta)
{
    const auto& lhs = m_arguments[0];
    const auto& rhs = m_arguments[1];

    adjoints.add_delta(lhs, delta * rhs);
    adjoints.add_delta(rhs, lhs * delta);
}
......@@ -29,12 +29,15 @@ namespace ngraph
}
virtual std::string description() const override { return "Multiply"; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
};
}
};
}
/// @brief Builds an op::Multiply node whose arguments are arg0 and arg1.
///
/// @param arg0 The first operand
/// @param arg1 The second operand
/// @return The new op::Multiply node
inline std::shared_ptr<ngraph::Node> operator*(const std::shared_ptr<ngraph::Node> arg0,
const std::shared_ptr<ngraph::Node> arg1)
{
return std::make_shared<ngraph::op::Multiply>(arg0, arg1);
}
/// @brief Builds an op::Multiply node whose arguments are arg0 and arg1.
///
/// @param arg0 The first operand
/// @param arg1 The second operand
/// @return The new op::Multiply node
inline std::shared_ptr<ngraph::Node> operator*(const std::shared_ptr<ngraph::Node> arg0,
const std::shared_ptr<ngraph::Node> arg1)
{
return std::make_shared<ngraph::op::Multiply>(arg0, arg1);
}
......@@ -36,6 +36,11 @@ namespace ngraph
// It is an error to try to associate a parameter with more than one function.
void assign_function(Function* function, size_t index);
// Override with an empty body: a Parameter adds no contributions to adjoints.
// NOTE(review): presumably because a parameter has no arguments to propagate to; confirm.
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override
{
}
public:
Parameter(const std::shared_ptr<ValueType>& value_type = nullptr);
Parameter(const ngraph::element::Type& element_type, const Shape& shape);
......
......@@ -22,6 +22,7 @@ include_directories(
)
set (SRC
autodiff.cpp
build_graph.cpp
eigen.cpp
element_type.cpp
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#include <algorithm>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
// For Y = X0, the derivative of Y with respect to X0 seeded with C must be the
// node C itself.
TEST(backwards, parameter)
{
    const Shape shape{2, 3};
    auto param = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto seed = make_shared<op::Parameter>(element::Float32::element_type(), shape);

    auto output = param;
    auto d_output_d_param = output->backwards_derivative(param, seed);

    ASSERT_EQ(d_output_d_param, seed);
}
// For Y = X0 + X1, the derivative with respect to either operand seeded with C
// must be the node C itself.
TEST(backwards, add)
{
    const Shape shape{2, 3};
    auto lhs = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto rhs = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto sum = lhs + rhs;
    auto seed = make_shared<op::Parameter>(element::Float32::element_type(), shape);

    auto d_sum_d_lhs = sum->backwards_derivative(lhs, seed);
    auto d_sum_d_rhs = sum->backwards_derivative(rhs, seed);

    ASSERT_EQ(d_sum_d_lhs, seed);
    ASSERT_EQ(d_sum_d_rhs, seed);
}
// Returns (dy/(dXs))(C, Xs)
//
// Builds a Function whose parameters are a fresh seed C (same element type and
// shape as Y) followed by Xs, and whose result is a tuple holding the backwards
// derivative of Y with respect to each X in Xs, in order.
shared_ptr<Function> derivative(const std::shared_ptr<Node>& Y,
                                const std::vector<std::shared_ptr<op::Parameter>> Xs)
{
    // The seed parameter takes its element type and shape from Y's tensor view type.
    auto y_type = dynamic_pointer_cast<const TensorViewType>(Y->get_value_type());
    auto seed = make_shared<op::Parameter>(y_type->get_element_type(), y_type->get_shape());

    std::vector<std::shared_ptr<Node>> partials;
    partials.reserve(Xs.size());
    for (const std::shared_ptr<Node>& X : Xs)
    {
        partials.push_back(Y->backwards_derivative(X, seed));
    }
    auto packed = make_shared<op::Tuple>(partials);

    // Function parameters: the seed first, then the Xs.
    std::vector<std::shared_ptr<op::Parameter>> params;
    params.push_back(seed);
    params.insert(params.end(), Xs.begin(), Xs.end());
    return make_shared<Function>(packed, packed->get_value_type(), params);
}
// Checks the derivative of Y = X0 * X1 numerically on the "NGVM" backend:
// for each element index i the seed c is set to a one-hot vector at i, and the
// computed dx0 / dx1 must be zero everywhere except at i, where they must equal
// x1[i] and x0[i] respectively.
TEST(backwards, multiply)
{
auto shape = Shape{2, 3};
auto X0 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto X1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Y = X0 * X1;
// f takes (C, X0, X1) and returns the tuple (dY/dX0, dY/dX1).
auto f = derivative(Y, {X0, X1});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Input tensors
auto x0 = backend->make_parameterized_tensor_view<element::Float32>(shape);
*x0 = vector<float>{1, 3, 5, 7, 9, 11};
auto x1 = backend->make_parameterized_tensor_view<element::Float32>(shape);
*x1 = vector<float>{0, 2, 4, 6, 8, 10};
// Seed tensor; rewritten on every loop iteration below.
auto c = backend->make_parameterized_tensor_view<element::Float32>(shape);
*c = vector<float>{0, 0, 0, 0, 0, 0};
// Output tensors, bundled into the tuple the call frame writes to.
auto dx0 = backend->make_parameterized_tensor_view<element::Float32>(shape);
auto dx1 = backend->make_parameterized_tensor_view<element::Float32>(shape);
auto dx = backend->make_tuple({dx0, dx1});
size_t n = x0->get_vector().size();
vector<float> dx0_correct(n);
vector<float> dx1_correct(n);
for (size_t i = 0; i < n; i++)
{
// One-hot seed selecting element i.
c->get_vector().assign(n, 0);
c->get_vector()[i] = 1;
(*cf)({c, x0, x1}, {dx});
// Expected: only element i is non-zero, holding the other operand's value at i.
dx0_correct.assign(n, 0);
dx1_correct.assign(n, 0);
dx0_correct[i] = x1->get_vector()[i];
dx1_correct[i] = x0->get_vector()[i];
ASSERT_EQ(dx0->get_vector(), dx0_correct);
ASSERT_EQ(dx1->get_vector(), dx1_correct);
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment