Commit 0c3bc7d0 authored by Matthew Brookhart's avatar Matthew Brookhart Committed by Scott Cyphers

Add autodiff for the arc trig ops (#935)

parent 833a05b2
......@@ -16,6 +16,19 @@
#include "ngraph/op/acos.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/shape.hpp"
#include <string>
#include <vector>
using namespace std;
using namespace ngraph;
......@@ -32,3 +45,19 @@ shared_ptr<Node> op::Acos::copy_with_new_args(const NodeVector& new_args) const
}
return make_shared<Acos>(new_args.at(0));
}
void op::Acos::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
auto delta = deltas.at(0);
auto x = get_inputs().at(0).get_output().get_node();
auto one = make_shared<op::Constant>(x->get_element_type(), Shape{}, vector<string>{"1"});
AxisSet axes;
for (size_t i = 0; i < x->get_shape().size(); i++)
axes.insert(i);
auto ones = make_shared<op::Broadcast>(one, x->get_shape(), axes);
adjoints.add_delta(x, -delta / make_shared<op::Sqrt>(ones - x * x));
}
......@@ -40,6 +40,10 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
}
}
......@@ -16,6 +16,18 @@
#include "ngraph/op/asin.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/shape.hpp"
#include <string>
#include <vector>
using namespace std;
using namespace ngraph;
......@@ -32,3 +44,19 @@ shared_ptr<Node> op::Asin::copy_with_new_args(const NodeVector& new_args) const
}
return make_shared<Asin>(new_args.at(0));
}
void op::Asin::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
auto delta = deltas.at(0);
auto x = get_inputs().at(0).get_output().get_node();
auto one = make_shared<op::Constant>(x->get_element_type(), Shape{}, vector<string>{"1"});
AxisSet axes;
for (size_t i = 0; i < x->get_shape().size(); i++)
axes.insert(i);
auto ones = make_shared<op::Broadcast>(one, x->get_shape(), axes);
adjoints.add_delta(x, delta / make_shared<op::Sqrt>(ones - x * x));
}
......@@ -40,6 +40,10 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
}
}
......@@ -16,6 +16,17 @@
#include "ngraph/op/atan.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/shape.hpp"
#include <string>
#include <vector>
using namespace std;
using namespace ngraph;
......@@ -32,3 +43,19 @@ shared_ptr<Node> op::Atan::copy_with_new_args(const NodeVector& new_args) const
}
return make_shared<Atan>(new_args.at(0));
}
void op::Atan::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
auto delta = deltas.at(0);
auto x = get_inputs().at(0).get_output().get_node();
auto one = make_shared<op::Constant>(x->get_element_type(), Shape{}, vector<string>{"1"});
AxisSet axes;
for (size_t i = 0; i < x->get_shape().size(); i++)
axes.insert(i);
auto ones = make_shared<op::Broadcast>(one, x->get_shape(), axes);
adjoints.add_delta(x, delta / (ones + x * x));
}
......@@ -40,6 +40,10 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
}
}
......@@ -378,6 +378,22 @@ TEST(${BACKEND_NAME}, backwards_abs)
}
}
TEST(${BACKEND_NAME}, backwards_acos)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
test::Uniform<float> rng(-0.9f, 0.9f);
Shape shape{2, 3};
auto x0 = rng.initialize(backend->create_tensor<float>(shape));
auto make_graph = [shape]() {
auto X0 = make_shared<op::Parameter>(element::f32, shape);
return make_shared<Function>(make_shared<op::Acos>(X0),
std::vector<std::shared_ptr<op::Parameter>>{X0});
};
EXPECT_TRUE(autodiff_numeric_compare<float>(backend, make_graph, {x0}, .01f, .01f));
}
TEST(${BACKEND_NAME}, backwards_add)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
......@@ -413,6 +429,38 @@ TEST(${BACKEND_NAME}, backwards_add_nested)
EXPECT_TRUE(autodiff_numeric_compare<float>(backend, make_graph, {x0, x1}, .01f, .01f));
}
TEST(${BACKEND_NAME}, backwards_asin)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
test::Uniform<float> rng(-0.9f, 0.9f);
Shape shape{2, 3};
auto x0 = rng.initialize(backend->create_tensor<float>(shape));
auto make_graph = [shape]() {
auto X0 = make_shared<op::Parameter>(element::f32, shape);
return make_shared<Function>(make_shared<op::Asin>(X0),
std::vector<std::shared_ptr<op::Parameter>>{X0});
};
EXPECT_TRUE(autodiff_numeric_compare<float>(backend, make_graph, {x0}, .01f, .01f));
}
TEST(${BACKEND_NAME}, backwards_atan)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
test::Uniform<float> rng(-10.0f, 10.0f);
Shape shape{2, 3};
auto x0 = rng.initialize(backend->create_tensor<float>(shape));
auto make_graph = [shape]() {
auto X0 = make_shared<op::Parameter>(element::f32, shape);
return make_shared<Function>(make_shared<op::Atan>(X0),
std::vector<std::shared_ptr<op::Parameter>>{X0});
};
EXPECT_TRUE(autodiff_numeric_compare<float>(backend, make_graph, {x0}, .01f, .01f));
}
TEST(${BACKEND_NAME}, backwards_broadcast0)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment