Commit bc63f7bb authored by Matthew Brookhart's avatar Matthew Brookhart Committed by Scott Cyphers

Fprop Cache Util Function (#312)

* in progress

* working cache_fprop, no tests

* style fix

* all inputs to bprop (except adjoints) are cached from fprop

* fix typos, make sure to check count == 0

* fix code format
parent 8f3da6b8
...@@ -22,7 +22,10 @@ ...@@ -22,7 +22,10 @@
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/ops/xla_tuple.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "ngraph/xla_function.hpp"
using namespace std; using namespace std;
...@@ -302,25 +305,6 @@ std::list<std::shared_ptr<ngraph::Node>> ...@@ -302,25 +305,6 @@ std::list<std::shared_ptr<ngraph::Node>>
return result_list; return result_list;
} }
void ngraph::NodeMap::Add(std::shared_ptr<ngraph::Node> orig,
std::shared_ptr<ngraph::Node> replacement)
{
if (Exists(orig))
{
throw ngraph_error("NodeMap: key already exists");
}
node_map_[orig] = replacement;
}
std::shared_ptr<ngraph::Node> ngraph::NodeMap::operator[](std::shared_ptr<ngraph::Node> orig) const
{
if (!Exists(orig))
{
throw ngraph_error("NodeMap: key does not exist");
}
return node_map_.at(orig);
}
std::list<std::shared_ptr<ngraph::Node>> std::list<std::shared_ptr<ngraph::Node>>
ngraph::clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map) ngraph::clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map)
{ {
...@@ -328,7 +312,7 @@ std::list<std::shared_ptr<ngraph::Node>> ...@@ -328,7 +312,7 @@ std::list<std::shared_ptr<ngraph::Node>>
auto sorted_nodes = topological_sort(nodes); auto sorted_nodes = topological_sort(nodes);
for (auto node : sorted_nodes) for (auto node : sorted_nodes)
{ {
if (!node_map.Exists(node)) if (node_map.count(node) == 0)
{ {
// get (already) cloned arguments and clone the node // get (already) cloned arguments and clone the node
Nodes cloned_args; Nodes cloned_args;
...@@ -336,7 +320,7 @@ std::list<std::shared_ptr<ngraph::Node>> ...@@ -336,7 +320,7 @@ std::list<std::shared_ptr<ngraph::Node>>
{ {
cloned_args.push_back(node_map[arg]); cloned_args.push_back(node_map[arg]);
} }
node_map.Add(node, node->copy_with_new_args(cloned_args)); node_map[node] = node->copy_with_new_args(cloned_args);
} }
} }
...@@ -402,3 +386,100 @@ size_t ngraph::round_up(size_t size, size_t alignment) ...@@ -402,3 +386,100 @@ size_t ngraph::round_up(size_t size, size_t alignment)
return size + alignment - remainder; return size + alignment - remainder;
} }
ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::XLAFunction> fprop,
std::shared_ptr<ngraph::XLAFunction> bprop,
std::vector<std::shared_ptr<Node>> adjoints)
{
using namespace ngraph;
// Traverse fprop to make a map that stores parameters with the same
// shape and element type as the nodes in fprop
NodeMap node_param_map;
ngraph::traverse_nodes(fprop, [&node_param_map](std::shared_ptr<Node> node) {
node_param_map[node] =
std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape());
});
// Traverse bprop to find all of the nodes in the graph
std::unordered_set<std::shared_ptr<Node>> in_bprop;
ngraph::traverse_nodes(bprop, [&in_bprop](std::shared_ptr<Node> node) {
if (in_bprop.count(node) == 0)
{
in_bprop.insert(node);
}
});
// Get the input paramters of fprop
std::unordered_set<std::shared_ptr<Node>> fprop_params;
for (auto node : fprop->get_parameters())
{
if (fprop_params.count(node) == 0)
{
fprop_params.insert(node);
}
}
// Find all of the nodes that are intermediate values of fprop and used in
// bprop
// and store those nodes that aren't needed in bprop
FpropCache fprop_cache;
std::vector<std::shared_ptr<Node>> unused_nodes;
for (auto kv : node_param_map)
{
// if it's not in bprop, mark it unused
if (in_bprop.count(kv.first) == 0)
{
unused_nodes.push_back(kv.first);
}
// otherwise save in in the ouputs
else
{
fprop_cache.fprop_output_nodes.push_back(kv.first);
}
}
// erase all unused nodes form the map
for (auto node : unused_nodes)
{
node_param_map.erase(node);
}
// create the new outputs for fprop and the new fprop function
std::vector<std::shared_ptr<Node>> fprop_outputs{fprop->get_results()};
fprop_outputs.insert(fprop_outputs.end(),
fprop_cache.fprop_output_nodes.begin(),
fprop_cache.fprop_output_nodes.end());
auto outTuple = std::make_shared<op::XLATuple>(fprop_outputs);
auto outTupleType = outTuple->get_value_type();
fprop_cache.fprop =
std::make_shared<XLAFunction>(outTuple, outTupleType, fprop->get_parameters());
// clone the nodes in bprop, replacing fprop-related nodes with the
// intermediate parameters
ngraph::clone_nodes(bprop->get_ops(), node_param_map);
// get cloned bprop results
auto cloned_result = node_param_map[bprop->get_result()];
// get clone bprop parameters
op::Parameters bprop_input_params;
for (auto param : adjoints)
{
bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(node_param_map[param]));
}
// add the cached fprop nodes as inputs to bprop
for (auto x : fprop_cache.fprop_output_nodes)
{
bprop_input_params.push_back(std::dynamic_pointer_cast<op::Parameter>(node_param_map[x]));
}
// create the new bprop function
fprop_cache.bprop = std::make_shared<XLAFunction>(
cloned_result, cloned_result->get_value_type(), bprop_input_params);
return fprop_cache;
}
...@@ -30,8 +30,15 @@ namespace ngraph ...@@ -30,8 +30,15 @@ namespace ngraph
{ {
class Node; class Node;
class Function; class Function;
class XLAFunction;
class stopwatch; class stopwatch;
namespace runtime
{
class Backend;
class Value;
}
template <typename T> template <typename T>
std::string join(const T& v, const std::string& sep = ", ") std::string join(const T& v, const std::string& sep = ", ")
{ {
...@@ -244,28 +251,8 @@ namespace ngraph ...@@ -244,28 +251,8 @@ namespace ngraph
std::list<std::shared_ptr<Node>> std::list<std::shared_ptr<Node>>
topological_sort(const std::list<std::shared_ptr<Node>>& nodes); topological_sort(const std::list<std::shared_ptr<Node>>& nodes);
// maps original to replacement nodes e.g. for clone utilities using NodeMap =
// performs index checking on access std::unordered_map<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>;
class NodeMap
{
public:
// map original node to replcacement node
// throws ngraph_error if key already exists
void Add(std::shared_ptr<ngraph::Node> orig, std::shared_ptr<ngraph::Node> replacement);
// get replacement node from original node
// throws ngrah_error if key does not exist
std::shared_ptr<ngraph::Node> operator[](std::shared_ptr<ngraph::Node> orig) const;
// returns true if original node is already mapped
bool Exists(std::shared_ptr<ngraph::Node> orig) const
{
return (node_map_.count(orig) != 0);
}
private:
std::unordered_map<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>> node_map_;
};
// input nodes are cloned and returned // input nodes are cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes // NodeMap input may contain default node mapping i.e. pre-cloned nodes
...@@ -282,4 +269,29 @@ namespace ngraph ...@@ -282,4 +269,29 @@ namespace ngraph
void* aligned_alloc(size_t alignment, size_t size); void* aligned_alloc(size_t alignment, size_t size);
void aligned_free(void*); void aligned_free(void*);
size_t round_up(size_t size, size_t alignment); size_t round_up(size_t size, size_t alignment);
/*
* Return type struct for cache_fprop, with the modified fprop and bprop
* functions
* and a list of the nodes that have been appended to fprop output/bprop
* input
*/
struct FpropCache
{
std::shared_ptr<XLAFunction> fprop;
std::shared_ptr<XLAFunction> bprop;
std::vector<std::shared_ptr<Node>> fprop_output_nodes;
};
/**
* This utility takes forward-propogation and back-propogation XLAunctions
* and turns them into clone functions where the intermediate values of
* the forward prop are added to the output of fprop and the input of the bprop
* to avoid repeat calcualtions.
* The last argument is the adjoints coming into the bprop function, the output
* bprop function will have these nodes as the first N input parameters
**/
FpropCache cache_fprop(std::shared_ptr<XLAFunction> fprop,
std::shared_ptr<XLAFunction> bprop,
std::vector<std::shared_ptr<Node>> adjoints);
} // end namespace ngraph } // end namespace ngraph
...@@ -280,7 +280,7 @@ public: ...@@ -280,7 +280,7 @@ public:
auto cloneit = clone.begin(); auto cloneit = clone.begin();
while (origit != orig.end() && cloneit != clone.end()) while (origit != orig.end() && cloneit != clone.end())
{ {
if (*cloneit != nm[*origit]) if (*cloneit != nm.at(*origit))
{ {
return false; return false;
} }
...@@ -311,7 +311,7 @@ TEST_F(CloneTest, clone_nodes_partial) ...@@ -311,7 +311,7 @@ TEST_F(CloneTest, clone_nodes_partial)
{ {
// map A -> A' prior to clone // map A -> A' prior to clone
auto Aprime = make_shared<op::Parameter>(element::f32, shape); auto Aprime = make_shared<op::Parameter>(element::f32, shape);
node_map.Add(A, Aprime); node_map[A] = Aprime;
auto cloned_nodes = clone_nodes(nodes, node_map); auto cloned_nodes = clone_nodes(nodes, node_map);
ASSERT_TRUE(CompareNodes(nodes, cloned_nodes, node_map)); ASSERT_TRUE(CompareNodes(nodes, cloned_nodes, node_map));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment