Commit d5991e9a authored by Scott Cyphers's avatar Scott Cyphers

Move autodiff test functions into test/util

parent f5768063
...@@ -13,9 +13,6 @@ ...@@ -13,9 +13,6 @@
set (SRC set (SRC
autodiff/adjoints.cpp autodiff/adjoints.cpp
autodiff/backprop_derivative.cpp
autodiff/backprop_function.cpp
autodiff/numeric_derivative.cpp
descriptor/input.cpp descriptor/input.cpp
descriptor/layout/dense_tensor_view_layout.cpp descriptor/layout/dense_tensor_view_layout.cpp
descriptor/layout/tensor_view_layout.cpp descriptor/layout/tensor_view_layout.cpp
......
...@@ -64,30 +64,5 @@ namespace ngraph ...@@ -64,30 +64,5 @@ namespace ngraph
/// @param f is f(X_i...) /// @param f is f(X_i...)
/// @returns f'(X_i..., c) where f'(x_i, ..., c)_j is backprop for X_j /// @returns f'(X_i..., c) where f'(x_i, ..., c)_j is backprop for X_j
std::shared_ptr<Function> backprop_function(const std::shared_ptr<Function>& f); std::shared_ptr<Function> backprop_function(const std::shared_ptr<Function>& f);
template <typename ET>
std::vector<std::shared_ptr<runtime::ParameterizedTensorView<ET>>> backprop_derivative(
const std::shared_ptr<runtime::Manager>& manager,
const std::shared_ptr<runtime::Backend>& backend,
const std::shared_ptr<Function>& f,
const std::vector<std::shared_ptr<runtime::ParameterizedTensorView<ET>>>& args);
extern template std::vector<
std::shared_ptr<runtime::ParameterizedTensorView<ngraph::element::Float32>>>
backprop_derivative<ngraph::element::Float32>(
const std::shared_ptr<runtime::Manager>& manager,
const std::shared_ptr<runtime::Backend>& backend,
const std::shared_ptr<Function>& f,
const std::vector<
std::shared_ptr<runtime::ParameterizedTensorView<element::Float32>>>& args);
extern template std::vector<
std::shared_ptr<runtime::ParameterizedTensorView<ngraph::element::Float64>>>
backprop_derivative<ngraph::element::Float64>(
const std::shared_ptr<runtime::Manager>& manager,
const std::shared_ptr<runtime::Backend>& backend,
const std::shared_ptr<Function>& f,
const std::vector<
std::shared_ptr<runtime::ParameterizedTensorView<element::Float64>>>& args);
} }
} }
...@@ -37,6 +37,9 @@ set (SRC ...@@ -37,6 +37,9 @@ set (SRC
topological_sort.cpp topological_sort.cpp
type_prop.cpp type_prop.cpp
util/all_close.cpp util/all_close.cpp
util/autodiff/backprop_derivative.cpp
util/autodiff/backprop_function.cpp
util/autodiff/numeric_derivative.cpp
util/test_tools.cpp util/test_tools.cpp
util.cpp util.cpp
uuid.cpp uuid.cpp
......
...@@ -19,11 +19,11 @@ ...@@ -19,11 +19,11 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/autodiff/backprop_derivative.hpp"
#include "ngraph/autodiff/backprop_function.hpp"
#include "ngraph/autodiff/numeric_derivative.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "util/all_close.hpp" #include "util/all_close.hpp"
#include "util/autodiff/backprop_derivative.hpp"
#include "util/autodiff/backprop_function.hpp"
#include "util/autodiff/numeric_derivative.hpp"
#include "util/random.hpp" #include "util/random.hpp"
using namespace std; using namespace std;
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/autodiff/backprop_derivative.hpp" #include "backprop_derivative.hpp"
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/ops/tuple.hpp" #include "ngraph/ops/tuple.hpp"
#include "ngraph/runtime/backend.hpp" #include "ngraph/runtime/backend.hpp"
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "backprop_function.hpp"
#include "ngraph/autodiff/adjoints.hpp" #include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
......
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include "ngraph/autodiff/numeric_derivative.hpp"
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/ops/tuple.hpp" #include "ngraph/ops/tuple.hpp"
#include "ngraph/runtime/call_frame.hpp" #include "ngraph/runtime/call_frame.hpp"
#include "numeric_derivative.hpp"
using namespace ngraph; using namespace ngraph;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment