Unverified Commit 6e1b308d authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Autodiff for atan2 (#3865)

* Autodiff for atan2

* atan2 not supported on PlaidML
parent 41a44f92
......@@ -15,6 +15,12 @@
//*****************************************************************************
#include "ngraph/op/atan2.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
using namespace std;
using namespace ngraph;
......@@ -39,5 +45,9 @@ void op::Atan2::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
throw ngraph_error("Autodiff not supported for Atan2");
auto y = input_value(0);
auto x = input_value(1);
auto delta_over_r = deltas.at(0) / (x * x + y * y);
adjoints.add_delta(y, x * delta_over_r);
adjoints.add_delta(x, -y * delta_over_r);
}
......@@ -54,6 +54,7 @@ top_k_opset_11_const_k_smallest # No plans to implement TopK
# unsupported op: `Atan2`
atan2
backwards_atan2
# unsupported op: `Erf`
erf
......
......@@ -521,6 +521,23 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_atan)
EXPECT_TRUE(autodiff_numeric_compare<float>(backend.get(), make_graph, {x0}, .01f, .01f));
}
NGRAPH_TEST(${BACKEND_NAME}, backwards_atan2)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
Shape shape{30};
test::Uniform<float> rng(-5.0f, 5.0f);
auto y = rng.initialize(backend->create_tensor<float>(shape));
auto x = rng.initialize(backend->create_tensor<float>(shape));
auto make_graph = [shape]() {
auto X = make_shared<op::Parameter>(element::f32, shape);
auto Y = make_shared<op::Parameter>(element::f32, shape);
return make_shared<Function>(make_shared<op::Atan2>(Y, X), ParameterVector{Y, X});
};
EXPECT_TRUE(autodiff_numeric_compare<float>(backend.get(), make_graph, {y, x}, .01f, .01f));
}
NGRAPH_TEST(${BACKEND_NAME}, backwards_broadcast0)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment