Unverified Commit ffe3a631 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Cache functions so the backend does not need to recompile (#1209)

* Cache some generated functions in backwards tests to speed performance

* more caching
parent 9fecc560
This diff is collapsed.
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <unordered_map>
#include "ngraph/autodiff/adjoints.hpp" #include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
...@@ -31,6 +32,10 @@ namespace ngraph ...@@ -31,6 +32,10 @@ namespace ngraph
class Node; class Node;
class Function; class Function;
// Caches keyed by the original function so repeated backprop test runs reuse
// the generated derivative/clone functions instead of rebuilding them.
// NOTE(review): `static` at namespace scope in a header gives every
// translation unit that includes this file its own copy of these maps, so
// the cache is per-TU and the shared_ptr values live until static
// destruction — confirm this is intended (a C++17 `inline` variable would
// share a single instance across TUs).
static std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<Function>> s_df_map;
// Cached clone of the modified forward function ((y, cached) = f(X)).
static std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<Function>> s_clone_fwd_map;
// Cached clone of the modified backward function (f'(c, cached)).
static std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<Function>> s_clone_bwd_map;
namespace runtime namespace runtime
{ {
class Backend; class Backend;
...@@ -145,7 +150,6 @@ namespace ngraph ...@@ -145,7 +150,6 @@ namespace ngraph
for (auto x : indep_params) for (auto x : indep_params)
{ {
// add df/dx to df/dX* // add df/dx to df/dX*
auto x_shape = x->get_shape();
df_output_params.push_back(adjoints.backprop_node(x)); df_output_params.push_back(adjoints.backprop_node(x));
} }
...@@ -154,7 +158,11 @@ namespace ngraph ...@@ -154,7 +158,11 @@ namespace ngraph
df_input_params.insert(df_input_params.begin(), c_param); df_input_params.insert(df_input_params.begin(), c_param);
// df/dX* = f'(c, X) // df/dX* = f'(c, X)
auto df = std::make_shared<Function>(df_output_params, df_input_params); if (!s_df_map[f])
{
s_df_map[f] = std::make_shared<Function>(df_output_params, df_input_params);
}
auto df = s_df_map[f];
// (c, X) arguments // (c, X) arguments
std::vector<std::shared_ptr<runtime::TensorView>> df_input_args = f_input_args; std::vector<std::shared_ptr<runtime::TensorView>> df_input_args = f_input_args;
...@@ -184,11 +192,20 @@ namespace ngraph ...@@ -184,11 +192,20 @@ namespace ngraph
} }
// compile and run modified (y, cached) = f(x) // compile and run modified (y, cached) = f(x)
auto clone_fwd = clone_function(*fprop_cache.fprop); if (!s_clone_fwd_map[f])
{
s_clone_fwd_map[f] = clone_function(*fprop_cache.fprop);
}
auto clone_fwd = s_clone_fwd_map[f];
backend->call(clone_fwd, mod_f_output_args, f_input_args); backend->call(clone_fwd, mod_f_output_args, f_input_args);
// call modfied f'(c, cached) to get df/dX* // call modfied f'(c, cached) to get df/dX*
auto clone_bwd = clone_function(*fprop_cache.bprop); if (!s_clone_bwd_map[f])
{
s_clone_bwd_map[f] = clone_function(*fprop_cache.bprop);
}
auto clone_bwd = s_clone_bwd_map[f];
auto cache_dfdx = get_autodiff<T>(backend, clone_bwd, mod_df_input_args, indep_params); auto cache_dfdx = get_autodiff<T>(backend, clone_bwd, mod_df_input_args, indep_params);
const auto numpy_atol = 1e-5f; const auto numpy_atol = 1e-5f;
......
...@@ -26,7 +26,8 @@ ...@@ -26,7 +26,8 @@
// TODO: Always compute the numerical derivatives in double // TODO: Always compute the numerical derivatives in double
template <typename T> template <typename T>
bool autodiff_numeric_compare(const std::shared_ptr<ngraph::runtime::Backend>& backend, bool autodiff_numeric_compare(const std::shared_ptr<ngraph::runtime::Backend>& backend,
std::function<std::shared_ptr<ngraph::Function>()> make_graph, std::shared_ptr<ngraph::Function> f,
std::shared_ptr<ngraph::Function> g,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& args, const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& args,
T rtol, T rtol,
T atol) T atol)
...@@ -35,7 +36,6 @@ bool autodiff_numeric_compare(const std::shared_ptr<ngraph::runtime::Backend>& b ...@@ -35,7 +36,6 @@ bool autodiff_numeric_compare(const std::shared_ptr<ngraph::runtime::Backend>& b
// Use INTERPRETER to compute numerical derivatives // Use INTERPRETER to compute numerical derivatives
auto interpreter_backend = ngraph::runtime::Backend::create("INTERPRETER"); auto interpreter_backend = ngraph::runtime::Backend::create("INTERPRETER");
auto f = make_graph();
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> interpreter_args; std::vector<std::shared_ptr<ngraph::runtime::TensorView>> interpreter_args;
for (auto arg : args) for (auto arg : args)
...@@ -58,7 +58,6 @@ bool autodiff_numeric_compare(const std::shared_ptr<ngraph::runtime::Backend>& b ...@@ -58,7 +58,6 @@ bool autodiff_numeric_compare(const std::shared_ptr<ngraph::runtime::Backend>& b
interpreter_backend, f, interpreter_args, delta, f->get_parameters()); interpreter_backend, f, interpreter_args, delta, f->get_parameters());
// Use the backend being tested to compute symbolic derivatives // Use the backend being tested to compute symbolic derivatives
auto g = make_graph();
auto results_sym = auto results_sym =
ngraph::autodiff::backprop_derivative<T>(backend, g, args, g->get_parameters()); ngraph::autodiff::backprop_derivative<T>(backend, g, args, g->get_parameters());
...@@ -75,10 +74,21 @@ bool autodiff_numeric_compare(const std::shared_ptr<ngraph::runtime::Backend>& b ...@@ -75,10 +74,21 @@ bool autodiff_numeric_compare(const std::shared_ptr<ngraph::runtime::Backend>& b
return ngraph::test::all_close(results_num, interpreter_results_sym, rtol, atol); return ngraph::test::all_close(results_num, interpreter_results_sym, rtol, atol);
} }
/// Convenience overload: materializes two independent copies of the graph
/// from \p make_graph and forwards to the (f, g) overload above — one copy
/// is consumed for numerical differentiation, the other for symbolic.
template <typename T>
bool autodiff_numeric_compare(const std::shared_ptr<ngraph::runtime::Backend>& backend,
                              std::function<std::shared_ptr<ngraph::Function>()> make_graph,
                              const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& args,
                              T rtol,
                              T atol)
{
    auto numeric_f = make_graph();
    auto symbolic_g = make_graph();
    return autodiff_numeric_compare(backend, numeric_f, symbolic_g, args, rtol, atol);
}
template <typename T> template <typename T>
bool autodiff_numeric_compare_selective( bool autodiff_numeric_compare_selective(
const std::shared_ptr<ngraph::runtime::Backend>& backend, const std::shared_ptr<ngraph::runtime::Backend>& backend,
std::function<std::shared_ptr<ngraph::Function>()> make_graph, std::shared_ptr<ngraph::Function> f,
std::shared_ptr<ngraph::Function> g,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& args, const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& args,
T rtol, T rtol,
T atol, T atol,
...@@ -86,7 +96,6 @@ bool autodiff_numeric_compare_selective( ...@@ -86,7 +96,6 @@ bool autodiff_numeric_compare_selective(
{ {
// Use INTERPRETER to compute numerical derivatives // Use INTERPRETER to compute numerical derivatives
std::vector<std::shared_ptr<ngraph::op::Parameter>> f_indep_params; std::vector<std::shared_ptr<ngraph::op::Parameter>> f_indep_params;
auto f = make_graph();
size_t i = 0; size_t i = 0;
...@@ -123,7 +132,6 @@ bool autodiff_numeric_compare_selective( ...@@ -123,7 +132,6 @@ bool autodiff_numeric_compare_selective(
// Use the backend being tested to compute symbolic derivatives // Use the backend being tested to compute symbolic derivatives
std::vector<std::shared_ptr<ngraph::op::Parameter>> g_indep_params; std::vector<std::shared_ptr<ngraph::op::Parameter>> g_indep_params;
auto g = make_graph();
i = 0; i = 0;
...@@ -150,3 +158,16 @@ bool autodiff_numeric_compare_selective( ...@@ -150,3 +158,16 @@ bool autodiff_numeric_compare_selective(
return ngraph::test::all_close(results_num, interpreter_results_sym, rtol, atol); return ngraph::test::all_close(results_num, interpreter_results_sym, rtol, atol);
} }
/// Convenience overload: materializes two independent copies of the graph
/// from \p make_graph and forwards to the (f, g) overload above, passing the
/// independent-parameter mask through unchanged.
template <typename T>
bool autodiff_numeric_compare_selective(
    const std::shared_ptr<ngraph::runtime::Backend>& backend,
    std::function<std::shared_ptr<ngraph::Function>()> make_graph,
    const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& args,
    T rtol,
    T atol,
    const std::vector<bool>& indep_param_mask)
{
    auto numeric_f = make_graph();
    auto symbolic_g = make_graph();
    return autodiff_numeric_compare_selective(
        backend, numeric_f, symbolic_g, args, rtol, atol, indep_param_mask);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment