Commit 3b49dd1a authored by Matthew Brookhart's avatar Matthew Brookhart Committed by Scott Cyphers

refactor cache_prop to reuse bprop inputs (#1134)

parent b9a77a9d
...@@ -185,8 +185,7 @@ size_t ngraph::round_up(size_t size, size_t alignment) ...@@ -185,8 +185,7 @@ size_t ngraph::round_up(size_t size, size_t alignment)
} }
ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
std::shared_ptr<ngraph::Function> bprop, std::shared_ptr<ngraph::Function> bprop)
std::vector<std::shared_ptr<Node>> adjoints)
{ {
using namespace ngraph; using namespace ngraph;
...@@ -208,17 +207,21 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -208,17 +207,21 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
// shape and element type as the nodes in fprop // shape and element type as the nodes in fprop
FpropCache fprop_cache; FpropCache fprop_cache;
fprop_cache.node_param_map = std::make_shared<NodeMap>(); fprop_cache.node_param_map = std::make_shared<NodeMap>();
ngraph::traverse_nodes(fprop, [&fprop_cache, &in_bprop](std::shared_ptr<Node> node) { auto bprop_inputs = bprop->get_parameters();
if (in_bprop.count(node) != 0)
{ ngraph::traverse_nodes(
fprop_cache.node_param_map->add( fprop, [&fprop_cache, &in_bprop, &bprop_inputs](std::shared_ptr<Node> node) {
node, std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape())); if (in_bprop.count(node) != 0 &&
} std::find(bprop_inputs.begin(), bprop_inputs.end(), node) == bprop_inputs.end())
}); {
fprop_cache.node_param_map->add(
node,
std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape()));
}
});
// Find all of the nodes that are intermediate values of fprop and used in // Find all of the nodes that are intermediate values of fprop and used in
// bprop // bprop and store those nodes that aren't needed in bprop
// and store those nodes that aren't needed in bprop
std::vector<std::shared_ptr<Node>> unused_nodes; std::vector<std::shared_ptr<Node>> unused_nodes;
for (auto kv : fprop_cache.node_param_map->get_node_map()) for (auto kv : fprop_cache.node_param_map->get_node_map())
{ {
...@@ -262,7 +265,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -262,7 +265,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
// get clone bprop parameters // get clone bprop parameters
op::ParameterVector bprop_input_params; op::ParameterVector bprop_input_params;
for (auto param : adjoints) for (auto param : bprop_inputs)
{ {
bprop_input_params.push_back( bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(param))); std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(param)));
......
...@@ -259,7 +259,5 @@ namespace ngraph ...@@ -259,7 +259,5 @@ namespace ngraph
* The last argument is the adjoints coming into the bprop function, the output * The last argument is the adjoints coming into the bprop function, the output
* bprop function will have these nodes as the first N input parameters * bprop function will have these nodes as the first N input parameters
**/ **/
FpropCache cache_fprop(std::shared_ptr<Function> fprop, FpropCache cache_fprop(std::shared_ptr<Function> fprop, std::shared_ptr<Function> bprop);
std::shared_ptr<Function> bprop,
std::vector<std::shared_ptr<Node>> adjoints);
} // end namespace ngraph } // end namespace ngraph
...@@ -1533,7 +1533,7 @@ TEST(cpu_fusion, maxpool_with_indices_in_mxnet) ...@@ -1533,7 +1533,7 @@ TEST(cpu_fusion, maxpool_with_indices_in_mxnet)
auto maybe_bf = bfa.first; auto maybe_bf = bfa.first;
auto adjoints = bfa.second; auto adjoints = bfa.second;
optimize_graph(f, maybe_bf); optimize_graph(f, maybe_bf);
auto fprop_cache = ngraph::cache_fprop(f, maybe_bf, adjoints); auto fprop_cache = ngraph::cache_fprop(f, maybe_bf);
auto mpwi_bprop = fprop_cache.bprop->get_results().at(0)->get_argument(0); auto mpwi_bprop = fprop_cache.bprop->get_results().at(0)->get_argument(0);
ASSERT_TRUE(std::dynamic_pointer_cast<op::Parameter>(mpwi_bprop->get_argument(0))); ASSERT_TRUE(std::dynamic_pointer_cast<op::Parameter>(mpwi_bprop->get_argument(0)));
......
...@@ -166,15 +166,14 @@ namespace ngraph ...@@ -166,15 +166,14 @@ namespace ngraph
// create fprop cache // create fprop cache
// creates modified forward function -> (y, cached) = f(x) // creates modified forward function -> (y, cached) = f(x)
// creates modified backward function -> df/dX* = f'(c, cached) // creates modified backward function -> df/dX* = f'(c, cached)
auto fprop_cache = cache_fprop(f, df, {c_param}); auto fprop_cache = cache_fprop(f, df);
// (y, cached) arguments // (y, cached) arguments
std::vector<std::shared_ptr<runtime::TensorView>> mod_f_output_args; std::vector<std::shared_ptr<runtime::TensorView>> mod_f_output_args;
mod_f_output_args.push_back(backend->create_tensor<T>(y_shape)); mod_f_output_args.push_back(backend->create_tensor<T>(y_shape));
// (c, cached) arguments // (c, cached) arguments
std::vector<std::shared_ptr<runtime::TensorView>> mod_df_input_args; std::vector<std::shared_ptr<runtime::TensorView>> mod_df_input_args = df_input_args;
mod_df_input_args.push_back(c_arg);
// add cached nodes to both modified f output and modified f' input arguments // add cached nodes to both modified f output and modified f' input arguments
for (auto node : fprop_cache.fprop_output_nodes) for (auto node : fprop_cache.fprop_output_nodes)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment