Commit efa2561e authored by Matthew Brookhart's avatar Matthew Brookhart Committed by Scott Cyphers

reduce fprop cache outputs (#1343)

* reduce fprop cache outputs

* refactor traverse nodes

* Slight refactor, add test, address PR comments

* fix formatting
parent f1c29c9c
......@@ -46,17 +46,33 @@ void ngraph::traverse_nodes(const std::shared_ptr<const Function> p,
// Traverses the graph of function `p`, handing `f` to the NodeVector overload.
//
// The traversal is seeded with every result of `p` followed by every
// parameter of `p`; the actual walk (and each call to `f`) is performed by
// traverse_nodes(const NodeVector&, ...).
//
// p: function whose results and parameters seed the traversal; it is
//    dereferenced unconditionally, so it must not be null.
// f: callback forwarded unchanged to the NodeVector overload.
void ngraph::traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f)
{
    // Only the seed list is needed here: the visited set and the work stack
    // are owned by the NodeVector overload, so no local copies are kept.
    NodeVector nodes;
    for (const auto& r : p->get_results())
    {
        nodes.push_back(r);
    }
    for (const auto& param : p->get_parameters())
    {
        nodes.push_back(param);
    }
    traverse_nodes(nodes, f);
}
// This version traverses directly from input/output nodes so that it can be applied to
// graphs that are not wrapped by functions. It is most useful for finding the parameters
// of a graph directly from the result nodes rather than from function parameters.
void ngraph::traverse_nodes(const NodeVector& io_nodes,
std::function<void(std::shared_ptr<Node>)> f)
{
std::unordered_set<std::shared_ptr<Node>> instances_seen;
std::deque<std::shared_ptr<Node>> stack;
for (auto r : io_nodes)
{
stack.push_front(r);
}
while (stack.size() > 0)
......
......@@ -44,6 +44,8 @@ namespace ngraph
std::function<void(std::shared_ptr<Node>)> f);
// Traverses the graph of function `p`: the results and parameters of `p` are
// collected as starting nodes and `f` is forwarded to the NodeVector overload.
void traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f);
// Traverses directly from the given input/output nodes, so it can be used on
// graphs that are not wrapped by a Function (e.g. to find the parameters of a
// graph starting from its result nodes).
void traverse_nodes(const NodeVector& io_nodes, std::function<void(std::shared_ptr<Node>)> f);
// NOTE(review): definition not visible here; presumably invokes `f` on the
// functions reachable from `p` -- confirm against the implementation.
void traverse_functions(std::shared_ptr<Function> p,
std::function<void(std::shared_ptr<Function>)> f);
......
......@@ -196,10 +196,13 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
{
using namespace ngraph;
// Traverse bprop to find all of the nodes in the graph
// Create a fprop_cache object to store the results of this analysis
FpropCache fprop_cache;
fprop_cache.node_param_map = std::make_shared<NodeMap>();
// Traverse bprop to find all of the nodes in the bprop graph
std::unordered_set<std::shared_ptr<Node>> in_bprop;
ngraph::traverse_nodes(bprop, [&in_bprop](std::shared_ptr<Node> node) {
if (node->get_outputs().size() == 1)
{
if (in_bprop.count(node) == 0)
......@@ -207,15 +210,12 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
in_bprop.insert(node);
}
}
});
// Traverse fprop to make a map that stores parameters with the same
// shape and element type as the nodes in fprop
FpropCache fprop_cache;
fprop_cache.node_param_map = std::make_shared<NodeMap>();
// shape and element type as the nodes in fprop iff they are in bprop
// and aren't inputs to bprop
auto bprop_inputs = bprop->get_parameters();
ngraph::traverse_nodes(
fprop, [&fprop_cache, &in_bprop, &bprop_inputs](std::shared_ptr<Node> node) {
if (in_bprop.count(node) != 0 &&
......@@ -227,39 +227,22 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
}
});
// Find all of the nodes that are intermediate values of fprop and used in
// bprop and store those nodes that aren't needed in bprop
std::vector<std::shared_ptr<Node>> unused_nodes;
for (auto kv : fprop_cache.node_param_map->get_node_map())
{
fprop_cache.fprop_output_nodes.push_back(kv.first);
}
// create the new outputs for fprop and the new fprop function
ResultVector fprop_outputs;
for (auto fpr : fprop->get_results())
{
fprop_outputs.push_back(fpr);
}
// clone the nodes in bprop, replacing fprop-related nodes with the
// intermediate parameters from fprop_cache. This breaks connections in the
// bprop graph such that only intermediate values from fprop needed by bprop
// are still connected to the bprop graph as parameters
ngraph::clone_nodes(bprop->get_ops(), *(fprop_cache.node_param_map));
for (auto fpir : fprop_cache.fprop_output_nodes)
// Invert the fprop_cache cloned node map for easy back-and-forth access.
std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>> inverted_node_map;
for (auto kv : fprop_cache.node_param_map->get_node_map())
{
if (std::dynamic_pointer_cast<op::Result>(fpir))
{
throw ngraph_error("Expected op::Result in fprop->get_results()");
}
fprop_outputs.push_back(std::make_shared<op::Result>(fpir));
inverted_node_map[kv.second] = kv.first;
}
fprop_cache.fprop = std::make_shared<Function>(fprop_outputs, fprop->get_parameters());
// clone the nodes in bprop, replacing fprop-related nodes with the
// intermediate parameters
ngraph::clone_nodes(bprop->get_ops(), *(fprop_cache.node_param_map));
// get cloned bprop results
ResultVector cloned_results;
NodeVector result_nodes;
for (auto node : bprop->get_results())
{
auto result = std::dynamic_pointer_cast<op::Result>(fprop_cache.node_param_map->get(node));
......@@ -268,25 +251,60 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
throw ngraph_error("Expected op::Result values for op::Result keys in node_param_map");
}
cloned_results.push_back(result);
result_nodes.push_back(result);
}
// get clone bprop parameters
op::ParameterVector bprop_input_params;
for (auto param : bprop_inputs)
{
bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(param)));
}
// Utility for getting bprop parameters with fprop cache.
auto get_bprop_params = [&bprop_inputs, &fprop_cache]() {
// get cloned bprop parameters
op::ParameterVector bprop_input_params;
for (auto param : bprop_inputs)
{
bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(param)));
}
// add the cached fprop nodes as inputs to bprop
for (auto x : fprop_cache.fprop_output_nodes)
{
bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(x)));
}
return bprop_input_params;
};
// Traverse the graph from the cloned results of bprop. If we find a parameter
// that's not an original input of bprop, this is an intermediate value of
// fprop that needs to be returned from fprop and sent to bprop
auto cloned_bprop_inputs = get_bprop_params();
ngraph::traverse_nodes(
result_nodes,
[&cloned_bprop_inputs, &fprop_cache, &inverted_node_map](std::shared_ptr<Node> node) {
auto pnode = std::dynamic_pointer_cast<op::Parameter>(node);
if (pnode != nullptr &&
std::find(cloned_bprop_inputs.begin(), cloned_bprop_inputs.end(), pnode) ==
cloned_bprop_inputs.end())
{
fprop_cache.fprop_output_nodes.push_back(inverted_node_map.at(node));
}
});
// add the cached fprop nodes as inputs to bprop
for (auto x : fprop_cache.fprop_output_nodes)
// create the new outputs for fprop and the new fprop function
ResultVector fprop_outputs = fprop->get_results();
for (auto fpir : fprop_cache.fprop_output_nodes)
{
bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(x)));
if (std::dynamic_pointer_cast<op::Result>(fpir))
{
throw ngraph_error("Expected op::Result in fprop->get_results()");
}
fprop_outputs.push_back(std::make_shared<op::Result>(fpir));
}
// create the new bprop function
fprop_cache.bprop = std::make_shared<Function>(cloned_results, bprop_input_params);
fprop_cache.fprop = std::make_shared<Function>(fprop_outputs, fprop->get_parameters());
// Create the new bprop function with cloned results and cached parameters.
fprop_cache.bprop = std::make_shared<Function>(cloned_results, get_bprop_params());
return fprop_cache;
}
......
......@@ -27,6 +27,7 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/serializer.hpp"
#include "util/all_close.hpp"
#include "util/autodiff/backprop_function.hpp"
#include "util/ndarray.hpp"
using namespace std;
......@@ -365,3 +366,21 @@ TEST(graph_util, get_subgraph_outputs_trivial_tests)
outputs = ngraph::get_subgraph_outputs(NodeVector{B, abs_b, abs_b_neg}, NodeVector{});
ASSERT_EQ(outputs, (NodeVector{B, abs_b_neg}));
}
TEST(util, test_fprop_cache)
{
    // Three 2x2 f32 inputs feeding the expression (p0 + p1) * p2 + p0.
    Shape tensor_shape{2, 2};
    auto p0 = make_shared<op::Parameter>(element::f32, tensor_shape);
    auto p1 = make_shared<op::Parameter>(element::f32, tensor_shape);
    auto p2 = make_shared<op::Parameter>(element::f32, tensor_shape);

    auto sum = p0 + p1;
    auto expr = sum * p2 + p0;

    // Forward function and the backprop function derived from it.
    auto forward = make_shared<Function>(NodeVector{expr}, op::ParameterVector{p0, p1, p2});
    auto backward = autodiff::backprop_function(forward);

    auto cached = cache_fprop(forward, backward);

    // The cached fprop is expected to expose 2 results and the cached bprop
    // is expected to take 5 parameters.
    EXPECT_EQ(cached.fprop->get_results().size(), 2);
    EXPECT_EQ(cached.bprop->get_parameters().size(), 5);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment