Unverified Commit dd41bb62 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Return correct value from autodiff tests. Make batchnorm autodiff test a little bigger. (#2218)

parent e526aeef
...@@ -1626,7 +1626,7 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2) ...@@ -1626,7 +1626,7 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2)
NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training) NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training)
{ {
const Shape input_shape{5, 3, 2, 2}; const Shape input_shape{10, 4, 5, 5};
const Shape channel_shape{input_shape.at(1)}; const Shape channel_shape{input_shape.at(1)};
const double eps = 1e-3; const double eps = 1e-3;
const element::Type& et = element::f32; const element::Type& et = element::f32;
...@@ -1647,19 +1647,19 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training) ...@@ -1647,19 +1647,19 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training)
goes.push_back(mean); goes.push_back(mean);
goes.push_back(variance); goes.push_back(variance);
// TODO autodiff testing with more than one result // TODO autodiff testing with more than one result
auto f = make_shared<Function>(ResultVector{normed_input /* , mean, variance*/}, auto f =
ParameterVector{input, gamma, beta}); make_shared<Function>(ResultVector{normed_input}, ParameterVector{input, gamma, beta});
return f; return f;
}; };
auto backend = runtime::Backend::create("${BACKEND_NAME}"); auto backend = runtime::Backend::create("${BACKEND_NAME}");
test::Uniform<T> rng(-1.0, 1.0); test::Uniform<T> rng(-5.0, 2.0);
auto input = rng.initialize(backend->create_tensor<T>(input_shape)); auto input = rng.initialize(backend->create_tensor<T>(input_shape));
auto gamma = rng.initialize(backend->create_tensor<T>(channel_shape)); auto gamma = rng.initialize(backend->create_tensor<T>(channel_shape));
auto beta = rng.initialize(backend->create_tensor<T>(channel_shape)); auto beta = rng.initialize(backend->create_tensor<T>(channel_shape));
EXPECT_TRUE( EXPECT_TRUE(
autodiff_numeric_compare<T>(backend.get(), make_graph, {input, gamma, beta}, .001, .001)); autodiff_numeric_compare<T>(backend.get(), make_graph, {input, gamma, beta}, .005, .005));
} }
NGRAPH_TEST(${BACKEND_NAME}, backwards_reverse_sequence_n3_c2_h3) NGRAPH_TEST(${BACKEND_NAME}, backwards_reverse_sequence_n3_c2_h3)
......
...@@ -25,12 +25,13 @@ ...@@ -25,12 +25,13 @@
// derivative does not work with int types // derivative does not work with int types
// TODO: Always compute the numerical derivatives in double // TODO: Always compute the numerical derivatives in double
template <typename T> template <typename T>
bool autodiff_numeric_compare(ngraph::runtime::Backend* backend, ::testing::AssertionResult
std::shared_ptr<ngraph::Function> f, autodiff_numeric_compare(ngraph::runtime::Backend* backend,
std::shared_ptr<ngraph::Function> g, std::shared_ptr<ngraph::Function> f,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& args, std::shared_ptr<ngraph::Function> g,
T rtol, const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& args,
T atol) T rtol,
T atol)
{ {
T delta = static_cast<T>(0.0009765625f); // Binary-representable number near 0.001 T delta = static_cast<T>(0.0009765625f); // Binary-representable number near 0.001
...@@ -75,17 +76,18 @@ bool autodiff_numeric_compare(ngraph::runtime::Backend* backend, ...@@ -75,17 +76,18 @@ bool autodiff_numeric_compare(ngraph::runtime::Backend* backend,
} }
template <typename T> template <typename T>
bool autodiff_numeric_compare(ngraph::runtime::Backend* backend, ::testing::AssertionResult
std::function<std::shared_ptr<ngraph::Function>()> make_graph, autodiff_numeric_compare(ngraph::runtime::Backend* backend,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& args, std::function<std::shared_ptr<ngraph::Function>()> make_graph,
T rtol, const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& args,
T atol) T rtol,
T atol)
{ {
return autodiff_numeric_compare(backend, make_graph(), make_graph(), args, rtol, atol); return autodiff_numeric_compare(backend, make_graph(), make_graph(), args, rtol, atol);
} }
template <typename T> template <typename T>
bool autodiff_numeric_compare_selective( ::testing::AssertionResult autodiff_numeric_compare_selective(
ngraph::runtime::Backend* backend, ngraph::runtime::Backend* backend,
std::shared_ptr<ngraph::Function> f, std::shared_ptr<ngraph::Function> f,
std::shared_ptr<ngraph::Function> g, std::shared_ptr<ngraph::Function> g,
...@@ -160,7 +162,7 @@ bool autodiff_numeric_compare_selective( ...@@ -160,7 +162,7 @@ bool autodiff_numeric_compare_selective(
} }
template <typename T> template <typename T>
bool autodiff_numeric_compare_selective( ::testing::AssertionResult autodiff_numeric_compare_selective(
ngraph::runtime::Backend* backend, ngraph::runtime::Backend* backend,
std::function<std::shared_ptr<ngraph::Function>()> make_graph, std::function<std::shared_ptr<ngraph::Function>()> make_graph,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& args, const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& args,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment