Commit dfb88350 authored by adstraw's avatar adstraw Committed by Scott Cyphers

modify existing autodiff unit tests to test fprop cache (#354)

* modify existing autodiff unit tests to test fprop cache

* cleanup

* fix compile error introduced with bad merge

* remove invalid negative/negative backwards power test
parent 8bccef1a
...@@ -988,12 +988,6 @@ TEST(${BACKEND_NAME}, backwards_power) ...@@ -988,12 +988,6 @@ TEST(${BACKEND_NAME}, backwards_power)
x0 = rng_pos.initialize(backend->make_primary_tensor_view<float>(shape)); x0 = rng_pos.initialize(backend->make_primary_tensor_view<float>(shape));
x1 = rng_neg.initialize(backend->make_primary_tensor_view<float>(shape)); x1 = rng_neg.initialize(backend->make_primary_tensor_view<float>(shape));
EXPECT_TRUE(
autodiff_numeric_compare<float>(manager, backend, make_graph, {x0, x1}, .01f, .01f));
x0 = rng_neg.initialize(backend->make_primary_tensor_view<float>(shape));
x1 = rng_neg.initialize(backend->make_primary_tensor_view<float>(shape));
EXPECT_TRUE( EXPECT_TRUE(
autodiff_numeric_compare<float>(manager, backend, make_graph, {x0, x1}, .01f, .01f)); autodiff_numeric_compare<float>(manager, backend, make_graph, {x0, x1}, .01f, .01f));
......
...@@ -30,15 +30,10 @@ namespace ngraph ...@@ -30,15 +30,10 @@ namespace ngraph
{ {
class Backend; class Backend;
class Manager; class Manager;
} } // namespace runtime
namespace autodiff namespace autodiff
{ {
/// @brief Returns a FunctionSpec for the backprop derivative of its argument.
/// @param f is f(X_i...)
/// @returns f'(X_i..., c) where f'(x_i, ..., c)_j is backprop for X_j
std::shared_ptr<Function> backprop_function(const std::shared_ptr<Function>& f);
template <typename T> template <typename T>
std::vector<std::shared_ptr<runtime::TensorView>> std::vector<std::shared_ptr<runtime::TensorView>>
backprop_derivative(const std::shared_ptr<runtime::Manager>& manager, backprop_derivative(const std::shared_ptr<runtime::Manager>& manager,
...@@ -47,66 +42,119 @@ namespace ngraph ...@@ -47,66 +42,119 @@ namespace ngraph
const std::vector<std::shared_ptr<runtime::TensorView>>& args, const std::vector<std::shared_ptr<runtime::TensorView>>& args,
const std::vector<std::shared_ptr<op::Parameter>>& indep_params) const std::vector<std::shared_ptr<op::Parameter>>& indep_params)
{ {
// y = f(X)
// using X (upper case) to denote all paramenters of f
// using x (lower case) to denote an individual paramemter of f a.k.a. Xj
// NOTE: using X* to denote all x "of interest" represented by indep_params
Shape y_shape = f->get_output_shape(0); Shape y_shape = f->get_output_shape(0);
// adjoint
auto c_param = std::make_shared<op::Parameter>(element::from<T>(), y_shape); auto c_param = std::make_shared<op::Parameter>(element::from<T>(), y_shape);
auto c_arg = backend->make_primary_tensor_view<T>(y_shape); auto c_arg = backend->make_primary_tensor_view<T>(y_shape);
auto params = f->get_parameters();
std::vector<std::shared_ptr<Node>> deriv_nodes; // df/dX*
std::vector<std::shared_ptr<runtime::TensorView>> bprops; // return value for f'(X, c)
std::vector<std::shared_ptr<Node>> df_output_params;
std::vector<std::shared_ptr<runtime::TensorView>> df_output_args;
// return value for this function
std::vector<std::shared_ptr<runtime::TensorView>> results; std::vector<std::shared_ptr<runtime::TensorView>> results;
for (auto param : indep_params) // for each x "of interest"
for (auto x : indep_params)
{ {
Shape s = y_shape; auto x_shape = x->get_shape();
auto param_shape = param->get_shape();
s.insert(s.end(), param_shape.begin(), param_shape.end()); // each element of y has a derivative with respect to each element of x
results.push_back(backend->make_primary_tensor_view<T>(s)); // hence, create a y by x sized tensor for this result
bprops.push_back(backend->make_primary_tensor_view<T>(param_shape)); auto y_by_x_shape = y_shape;
deriv_nodes.push_back(f->get_output_op(0)->backprop_node(param, c_param)); y_by_x_shape.insert(y_by_x_shape.end(), x_shape.begin(), x_shape.end());
results.push_back(backend->make_primary_tensor_view<T>(y_by_x_shape));
// add df/dx to df/dX*
df_output_params.push_back(f->get_output_op(0)->backprop_node(x, c_param));
df_output_args.push_back(backend->make_primary_tensor_view<T>(x_shape));
} }
std::vector<std::shared_ptr<op::Parameter>> df_params = params; // (X, c)
df_params.push_back(c_param); // input to f'(X, c)
auto df = std::make_shared<Function>(deriv_nodes, df_params); std::vector<std::shared_ptr<op::Parameter>> df_input_params = f->get_parameters();
df_input_params.push_back(c_param);
// df/dX* = f'(X, c)
auto df = std::make_shared<Function>(df_output_params, df_input_params);
auto external = manager->compile(df); // create fprop cache
// creates modified forward function -> (y, cached) = f(x)
// creates modified backward function -> df/dX* = f'(c, cached)
auto fprop_cache = cache_fprop(f, df, {c_param});
// modified f outputs
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> f_output_args;
f_output_args.push_back(backend->make_primary_tensor_view<T>(y_shape));
// modified f' inputs
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> df_input_args;
df_input_args.push_back(c_arg);
// add cached nodes to both modified f outputs and modified f' inputs
for (auto node : fprop_cache.fprop_output_nodes)
{
auto tv = backend->make_primary_tensor_view<T>(node->get_shape());
df_input_args.push_back(tv);
f_output_args.push_back(tv);
}
// compile and run modified (y, cached) = f(x)
auto cache_fwd = manager->compile(fprop_cache.fprop);
auto cache_fwd_cf = backend->make_call_frame(cache_fwd);
cache_fwd_cf->tensor_call(args, f_output_args);
// compile modified df/dX* = f'(c, cached)
auto external = manager->compile(fprop_cache.bprop);
auto cf = backend->make_call_frame(external); auto cf = backend->make_call_frame(external);
// We compute the derivatives chunk by chunk // create storage for results
std::vector<typename std::vector<T>::iterator> result_pos; // * outer vector size = number of x "of interest"
// * inner vector size = number of elements in y * number of elements in x
std::vector<std::vector<T>> result_vect; std::vector<std::vector<T>> result_vect;
std::vector<typename std::vector<T>::iterator> result_pos;
for (auto result : results) for (auto result : results)
{ {
result_vect.push_back(read_vector<T>(result)); // storage for results result_vect.push_back(read_vector<T>(result));
result_pos.push_back(result_vect.back().begin()); result_pos.push_back(result_vect.back().begin());
} }
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> args_tv; // get adjoint and force to all elements to zero
args_tv.insert(args_tv.begin(), args.begin(), args.end());
args_tv.push_back(c_arg);
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> bprops_tv;
bprops_tv.insert(bprops_tv.begin(), bprops.begin(), bprops.end());
auto c_vec = read_vector<T>(c_arg); auto c_vec = read_vector<T>(c_arg);
fill(c_vec.begin(), c_vec.end(), 0); fill(c_vec.begin(), c_vec.end(), 0);
// for each element of the adjoint
// same as saying for each element of y
for (size_t i = 0; i < c_vec.size(); i++) for (size_t i = 0; i < c_vec.size(); i++)
{ {
// set a single adjoint element
c_vec[i] = 1; c_vec[i] = 1;
write_vector(c_arg, c_vec); write_vector(c_arg, c_vec);
cf->tensor_call(args_tv, bprops_tv);
// call modified df/dX* = f'(c, cached)
cf->tensor_call(df_input_args, df_output_args);
// reset the adjoint element
c_vec[i] = 0; c_vec[i] = 0;
write_vector(c_arg, c_vec); write_vector(c_arg, c_vec);
// for each result
// same as saying for each x "of interest"
for (size_t j = 0; j < results.size(); j++) for (size_t j = 0; j < results.size(); j++)
{ {
auto bprop_vec = read_vector<T>(bprops[j]); // copy df/dx to storage for this element of y
result_pos[j] = std::copy(bprop_vec.begin(), bprop_vec.end(), result_pos[j]); auto dfdx = read_vector<T>(df_output_args[j]);
result_pos[j] = std::copy(dfdx.begin(), dfdx.end(), result_pos[j]);
} }
} }
// Copy results from temp to result vector // copy storage to results and return
for (size_t j = 0; j < results.size(); j++) for (size_t j = 0; j < results.size(); j++)
{ {
write_vector(results[j], result_vect[j]); write_vector(results[j], result_vect[j]);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment