Commit eb5f6cce authored by adstraw's avatar adstraw Committed by Scott Cyphers

fix fprop cache test error by cloning functions to preserve metadata (#664)

* fix fprop cache test error by cloning functions to preserve metadata
parent 5e718498
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/types/element_type.hpp" #include "ngraph/types/element_type.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -187,13 +188,17 @@ namespace ngraph ...@@ -187,13 +188,17 @@ namespace ngraph
} }
// compile and run modified (y, cached) = f(x) // compile and run modified (y, cached) = f(x)
auto cache_fwd = manager->compile(fprop_cache.fprop); NodeMap nm1;
auto clone_fwd = clone_function(fprop_cache.fprop, nm1);
auto cache_fwd = manager->compile(clone_fwd);
auto cache_fwd_cf = backend->make_call_frame(cache_fwd); auto cache_fwd_cf = backend->make_call_frame(cache_fwd);
cache_fwd_cf->tensor_call(f_input_args, mod_f_output_args); cache_fwd_cf->tensor_call(f_input_args, mod_f_output_args);
// call modfied f'(c, cached) to get df/dX* // call modfied f'(c, cached) to get df/dX*
auto cache_dfdx = get_autodiff<T>( NodeMap nm2;
manager, backend, fprop_cache.bprop, mod_df_input_args, indep_params); auto clone_bwd = clone_function(fprop_cache.bprop, nm2);
auto cache_dfdx =
get_autodiff<T>(manager, backend, clone_bwd, mod_df_input_args, indep_params);
const auto numpy_atol = 1e-5f; const auto numpy_atol = 1e-5f;
const auto numpy_rtol = 1e-8f; const auto numpy_rtol = 1e-8f;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment