Commit 4a800845 authored by Adam Procter's avatar Adam Procter

Add test for apparent autodiff bug with nested adds

parent 3970b477
...@@ -90,14 +90,16 @@ template <typename T> ...@@ -90,14 +90,16 @@ template <typename T>
bool ngraph::test::all_close(const std::vector<T>& a, const std::vector<T>& b, T rtol, T atol) bool ngraph::test::all_close(const std::vector<T>& a, const std::vector<T>& b, T rtol, T atol)
{ {
assert(a.size() == b.size()); assert(a.size() == b.size());
bool retval = true;
for (size_t i = 0; i < a.size(); ++i) for (size_t i = 0; i < a.size(); ++i)
{ {
if (std::abs(a[i] - b[i]) > atol + rtol * std::abs(b[i])) if (std::abs(a[i] - b[i]) > atol + rtol * std::abs(b[i]))
{ {
return false; std::cout << "a[" << i << "] = " << a[i] << ", b[" << i << "] = " << b[i] << std::endl;
retval = false;
} }
} }
return true; return retval;
} }
template bool ngraph::test::all_close<float>(const std::vector<float>& a, template bool ngraph::test::all_close<float>(const std::vector<float>& a,
......
...@@ -65,6 +65,26 @@ TEST(backwards, add) ...@@ -65,6 +65,26 @@ TEST(backwards, add)
manager, backend, make_graph, {x0, x1}, .01f, .01f)); manager, backend, make_graph, {x0, x1}, .01f, .01f));
} }
TEST(backwards, add_nested)
{
auto manager = runtime::Manager::get("NGVM");
auto backend = manager->allocate_backend();
test::Uniform<element::Float32> rng(-1.0f, 1.0f);
auto shape = Shape{2, 3};
auto x0 = rng.initialize(backend->make_parameterized_tensor_view<element::Float32>(shape));
auto x1 = rng.initialize(backend->make_parameterized_tensor_view<element::Float32>(shape));
auto make_graph = [shape]() {
auto X0 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto X1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
return make_shared<Function>(
(X0+X1) + (X1+X0), nullptr, std::vector<std::shared_ptr<op::Parameter>>{X0, X1});
};
EXPECT_TRUE(autodiff_numeric_compare<element::Float32>(
manager, backend, make_graph, {x0, x1}, .01f, .01f));
}
TEST(backwards, broadcast0) TEST(backwards, broadcast0)
{ {
auto manager = runtime::Manager::get("NGVM"); auto manager = runtime::Manager::get("NGVM");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment