Commit 3b49dd1a authored by Matthew Brookhart's avatar Matthew Brookhart Committed by Scott Cyphers

refactor cache_prop to reuse bprop inputs (#1134)

parent b9a77a9d
......@@ -185,8 +185,7 @@ size_t ngraph::round_up(size_t size, size_t alignment)
}
ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
std::shared_ptr<ngraph::Function> bprop,
std::vector<std::shared_ptr<Node>> adjoints)
std::shared_ptr<ngraph::Function> bprop)
{
using namespace ngraph;
......@@ -208,17 +207,21 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
// shape and element type as the nodes in fprop
FpropCache fprop_cache;
fprop_cache.node_param_map = std::make_shared<NodeMap>();
ngraph::traverse_nodes(fprop, [&fprop_cache, &in_bprop](std::shared_ptr<Node> node) {
if (in_bprop.count(node) != 0)
{
fprop_cache.node_param_map->add(
node, std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape()));
}
});
auto bprop_inputs = bprop->get_parameters();
ngraph::traverse_nodes(
fprop, [&fprop_cache, &in_bprop, &bprop_inputs](std::shared_ptr<Node> node) {
if (in_bprop.count(node) != 0 &&
std::find(bprop_inputs.begin(), bprop_inputs.end(), node) == bprop_inputs.end())
{
fprop_cache.node_param_map->add(
node,
std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape()));
}
});
// Find all of the nodes that are intermediate values of fprop and used in
// bprop
// and store those nodes that aren't needed in bprop
// bprop and store those nodes that aren't needed in bprop
std::vector<std::shared_ptr<Node>> unused_nodes;
for (auto kv : fprop_cache.node_param_map->get_node_map())
{
......@@ -262,7 +265,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
// get clone bprop parameters
op::ParameterVector bprop_input_params;
for (auto param : adjoints)
for (auto param : bprop_inputs)
{
bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(param)));
......
......@@ -259,7 +259,5 @@ namespace ngraph
* The last argument is the adjoints coming into the bprop function, the output
* bprop function will have these nodes as the first N input parameters
**/
FpropCache cache_fprop(std::shared_ptr<Function> fprop,
std::shared_ptr<Function> bprop,
std::vector<std::shared_ptr<Node>> adjoints);
FpropCache cache_fprop(std::shared_ptr<Function> fprop, std::shared_ptr<Function> bprop);
} // end namespace ngraph
......@@ -1533,7 +1533,7 @@ TEST(cpu_fusion, maxpool_with_indices_in_mxnet)
auto maybe_bf = bfa.first;
auto adjoints = bfa.second;
optimize_graph(f, maybe_bf);
auto fprop_cache = ngraph::cache_fprop(f, maybe_bf, adjoints);
auto fprop_cache = ngraph::cache_fprop(f, maybe_bf);
auto mpwi_bprop = fprop_cache.bprop->get_results().at(0)->get_argument(0);
ASSERT_TRUE(std::dynamic_pointer_cast<op::Parameter>(mpwi_bprop->get_argument(0)));
......
......@@ -166,15 +166,14 @@ namespace ngraph
// create fprop cache
// creates modified forward function -> (y, cached) = f(x)
// creates modified backward function -> df/dX* = f'(c, cached)
auto fprop_cache = cache_fprop(f, df, {c_param});
auto fprop_cache = cache_fprop(f, df);
// (y, cached) arguments
std::vector<std::shared_ptr<runtime::TensorView>> mod_f_output_args;
mod_f_output_args.push_back(backend->create_tensor<T>(y_shape));
// (c, cached) arguments
std::vector<std::shared_ptr<runtime::TensorView>> mod_df_input_args;
mod_df_input_args.push_back(c_arg);
std::vector<std::shared_ptr<runtime::TensorView>> mod_df_input_args = df_input_args;
// add cached nodes to both modified f output and modified f' input arguments
for (auto node : fprop_cache.fprop_output_nodes)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment