Unverified Commit 995671ae authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

[v0.1.0] Multi-output fprop_cache tentative fix (#657)

Contains multiple fixes to GetOutputElement, BatchNorm, autodiff, and fprop_cache to integrate multi-output BatchNorm with fprop_cache.
parent feeaed57
......@@ -125,7 +125,7 @@ std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x)
void autodiff::Adjoints::add_delta(const std::shared_ptr<Node>& x,
const std::shared_ptr<Node>& delta)
{
if (!x->has_same_type(delta))
if (!x->has_same_type(delta) && delta->get_shape() != x->get_outputs().at(0).get_shape())
{
throw ngraph_error("Autodiff internal error: Mismatch on backprop and op in add_delta.");
}
......
......@@ -148,8 +148,23 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
auto gamma = get_input_op(0);
auto beta = get_input_op(1);
auto input = get_input_op(2);
auto mean = std::make_shared<op::GetOutputElement>(shared_from_this(), 1);
auto var = std::make_shared<op::GetOutputElement>(shared_from_this(), 2);
//Extract mean and variance outputs from BatchNorm
//as these are used by BatchNormBackprop.
//The users of the outputs (GetOutputElements' Inputs) aren't sorted,
//so get_n() is used to order them the same as BatchNorm's outputs.
//Next, mean and variance (`at(1)` and `at(2)`) are extracted.
//Please see `add_output` in `BatchNorm::BatchNorm` for more details.
std::vector<std::shared_ptr<Node>> goes(get_outputs().size());
for (auto _input : get_output_inputs(0))
{
auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(_input->get_node());
goes.at(goe->get_n()) = _input->get_node();
}
auto mean = goes.at(1);
auto var = goes.at(2);
auto bbn = std::make_shared<op::BatchNormBackprop>(
get_eps_value(), gamma, beta, input, mean, var, delta);
auto dinput = std::make_shared<op::GetOutputElement>(bbn, 0);
......
......@@ -68,6 +68,17 @@ namespace ngraph
}
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                               const std::shared_ptr<Node>& delta) override
{
    // For a multi-output producer (e.g. BatchNorm), only the element at
    // index 0 (dinput) carries a gradient back to the producing node;
    // deltas arriving at the mean/variance elements are dropped here.
    // NOTE(review): this special-casing needs generalizing as new
    // multi-output ops are introduced.
    if (get_n() != 0)
    {
        return;
    }
    adjoints.add_delta(get_inputs().at(0).get_output().get_node(), delta);
}
size_t m_n;
};
}
......
......@@ -189,57 +189,40 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
{
using namespace ngraph;
// Traverse fprop to make a map that stores parameters with the same
// shape and element type as the nodes in fprop
NodeMap node_param_map;
ngraph::traverse_nodes(fprop, [&node_param_map](std::shared_ptr<Node> node) {
node_param_map.add(
node, std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape()));
});
// Traverse bprop to find all of the nodes in the graph
std::unordered_set<std::shared_ptr<Node>> in_bprop;
ngraph::traverse_nodes(bprop, [&in_bprop](std::shared_ptr<Node> node) {
if (node->get_outputs().size() == 1)
{
if (in_bprop.count(node) == 0)
{
in_bprop.insert(node);
}
}
});
// Get the input paramters of fprop
std::unordered_set<std::shared_ptr<Node>> fprop_params;
for (auto node : fprop->get_parameters())
{
if (fprop_params.count(node) == 0)
// Traverse fprop to make a map that stores parameters with the same
// shape and element type as the nodes in fprop
FpropCache fprop_cache;
fprop_cache.node_param_map = std::make_shared<NodeMap>();
ngraph::traverse_nodes(fprop, [&fprop_cache, &in_bprop](std::shared_ptr<Node> node) {
if (in_bprop.count(node) != 0)
{
fprop_params.insert(node);
}
fprop_cache.node_param_map->add(
node, std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape()));
}
});
// Find all of the nodes that are intermediate values of fprop and used in
// bprop
// and store those nodes that aren't needed in bprop
FpropCache fprop_cache;
std::vector<std::shared_ptr<Node>> unused_nodes;
for (auto kv : node_param_map.get_node_map())
{
// if it's not in bprop, mark it unused
if (in_bprop.count(kv.first) == 0)
{
unused_nodes.push_back(kv.first);
}
// otherwise save it in the outputs
else
for (auto kv : fprop_cache.node_param_map->get_node_map())
{
fprop_cache.fprop_output_nodes.push_back(kv.first);
}
}
// erase all unused nodes form the map
for (auto node : unused_nodes)
{
node_param_map.get_node_map().erase(node);
}
// create the new outputs for fprop and the new fprop function
ResultVector fprop_outputs;
......@@ -262,13 +245,13 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
// clone the nodes in bprop, replacing fprop-related nodes with the
// intermediate parameters
ngraph::clone_nodes(bprop->get_ops(), node_param_map);
ngraph::clone_nodes(bprop->get_ops(), *(fprop_cache.node_param_map));
// get cloned bprop results
ResultVector cloned_results;
for (auto node : bprop->get_results())
{
auto result = std::dynamic_pointer_cast<op::Result>(node_param_map.get(node));
auto result = std::dynamic_pointer_cast<op::Result>(fprop_cache.node_param_map->get(node));
if (!result)
{
throw ngraph_error("Expected op::Result values for op::Result keys in node_param_map");
......@@ -281,14 +264,14 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
for (auto param : adjoints)
{
bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(node_param_map.get(param)));
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(param)));
}
// add the cached fprop nodes as inputs to bprop
for (auto x : fprop_cache.fprop_output_nodes)
{
bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(node_param_map.get(x)));
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(x)));
}
// create the new bprop function
......
......@@ -28,6 +28,7 @@ namespace ngraph
{
class Node;
class Function;
class NodeMap;
class stopwatch;
namespace runtime
......@@ -229,6 +230,7 @@ namespace ngraph
std::shared_ptr<Function> fprop;
std::shared_ptr<Function> bprop;
std::vector<std::shared_ptr<Node>> fprop_output_nodes;
std::shared_ptr<NodeMap> node_param_map;
};
/**
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment