Unverified Commit 995671ae authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

[v0.1.0] Multi-output fprop_cache tentative fix (#657)

Contains multiple fixes to GetOutputElement, BatchNorm, autodiff, fprop_cache to integrate multi-output batchnorm and fprop_cache 
parent feeaed57
...@@ -125,7 +125,7 @@ std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x) ...@@ -125,7 +125,7 @@ std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x)
void autodiff::Adjoints::add_delta(const std::shared_ptr<Node>& x, void autodiff::Adjoints::add_delta(const std::shared_ptr<Node>& x,
const std::shared_ptr<Node>& delta) const std::shared_ptr<Node>& delta)
{ {
if (!x->has_same_type(delta)) if (!x->has_same_type(delta) && delta->get_shape() != x->get_outputs().at(0).get_shape())
{ {
throw ngraph_error("Autodiff internal error: Mismatch on backprop and op in add_delta."); throw ngraph_error("Autodiff internal error: Mismatch on backprop and op in add_delta.");
} }
......
...@@ -148,8 +148,23 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -148,8 +148,23 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
auto gamma = get_input_op(0); auto gamma = get_input_op(0);
auto beta = get_input_op(1); auto beta = get_input_op(1);
auto input = get_input_op(2); auto input = get_input_op(2);
auto mean = std::make_shared<op::GetOutputElement>(shared_from_this(), 1);
auto var = std::make_shared<op::GetOutputElement>(shared_from_this(), 2); //Extract mean and variance outputs from BatchNorm
//as these are used by BatchNormBackprop.
//The users of the outputs (GetOutputElements' Inputs) aren't sorted
//and get_n() is used to sort the inputs in the same order as Batchnorm's outputs
//Next, Mean and Variance (`at(1)` and `at(2)`) are extracted
//Please see `add_output` in `BatchNorm::BatchNorm` for more details
std::vector<std::shared_ptr<Node>> goes(get_outputs().size());
for (auto _input : get_output_inputs(0))
{
auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(_input->get_node());
goes.at(goe->get_n()) = _input->get_node();
}
auto mean = goes.at(1);
auto var = goes.at(2);
auto bbn = std::make_shared<op::BatchNormBackprop>( auto bbn = std::make_shared<op::BatchNormBackprop>(
get_eps_value(), gamma, beta, input, mean, var, delta); get_eps_value(), gamma, beta, input, mean, var, delta);
auto dinput = std::make_shared<op::GetOutputElement>(bbn, 0); auto dinput = std::make_shared<op::GetOutputElement>(bbn, 0);
......
...@@ -68,6 +68,17 @@ namespace ngraph ...@@ -68,6 +68,17 @@ namespace ngraph
} }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override
{
//Filter out updates(deltas) from mean and variance (for batchnorm)
//as dinput is the only update required.
//This logic needs to be generalized as new multi-output ops are introduced
if (get_n() == 0)
{
adjoints.add_delta(get_inputs().at(0).get_output().get_node(), delta);
}
}
size_t m_n; size_t m_n;
}; };
} }
......
...@@ -189,56 +189,39 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -189,56 +189,39 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
{ {
using namespace ngraph; using namespace ngraph;
// Traverse fprop to make a map that stores parameters with the same
// shape and element type as the nodes in fprop
NodeMap node_param_map;
ngraph::traverse_nodes(fprop, [&node_param_map](std::shared_ptr<Node> node) {
node_param_map.add(
node, std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape()));
});
// Traverse bprop to find all of the nodes in the graph // Traverse bprop to find all of the nodes in the graph
std::unordered_set<std::shared_ptr<Node>> in_bprop; std::unordered_set<std::shared_ptr<Node>> in_bprop;
ngraph::traverse_nodes(bprop, [&in_bprop](std::shared_ptr<Node> node) { ngraph::traverse_nodes(bprop, [&in_bprop](std::shared_ptr<Node> node) {
if (in_bprop.count(node) == 0)
if (node->get_outputs().size() == 1)
{ {
in_bprop.insert(node); if (in_bprop.count(node) == 0)
{
in_bprop.insert(node);
}
} }
}); });
// Get the input paramters of fprop // Traverse fprop to make a map that stores parameters with the same
std::unordered_set<std::shared_ptr<Node>> fprop_params; // shape and element type as the nodes in fprop
for (auto node : fprop->get_parameters()) FpropCache fprop_cache;
{ fprop_cache.node_param_map = std::make_shared<NodeMap>();
if (fprop_params.count(node) == 0) ngraph::traverse_nodes(fprop, [&fprop_cache, &in_bprop](std::shared_ptr<Node> node) {
if (in_bprop.count(node) != 0)
{ {
fprop_params.insert(node); fprop_cache.node_param_map->add(
node, std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape()));
} }
} });
// Find all of the nodes that are intermediate values of fprop and used in // Find all of the nodes that are intermediate values of fprop and used in
// bprop // bprop
// and store those nodes that aren't needed in bprop // and store those nodes that aren't needed in bprop
FpropCache fprop_cache;
std::vector<std::shared_ptr<Node>> unused_nodes; std::vector<std::shared_ptr<Node>> unused_nodes;
for (auto kv : node_param_map.get_node_map()) for (auto kv : fprop_cache.node_param_map->get_node_map())
{
// if it's not in bprop, mark it unused
if (in_bprop.count(kv.first) == 0)
{
unused_nodes.push_back(kv.first);
}
// otherwise save in in the ouputs
else
{
fprop_cache.fprop_output_nodes.push_back(kv.first);
}
}
// erase all unused nodes form the map
for (auto node : unused_nodes)
{ {
node_param_map.get_node_map().erase(node); fprop_cache.fprop_output_nodes.push_back(kv.first);
} }
// create the new outputs for fprop and the new fprop function // create the new outputs for fprop and the new fprop function
...@@ -262,13 +245,13 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -262,13 +245,13 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
// clone the nodes in bprop, replacing fprop-related nodes with the // clone the nodes in bprop, replacing fprop-related nodes with the
// intermediate parameters // intermediate parameters
ngraph::clone_nodes(bprop->get_ops(), node_param_map); ngraph::clone_nodes(bprop->get_ops(), *(fprop_cache.node_param_map));
// get cloned bprop results // get cloned bprop results
ResultVector cloned_results; ResultVector cloned_results;
for (auto node : bprop->get_results()) for (auto node : bprop->get_results())
{ {
auto result = std::dynamic_pointer_cast<op::Result>(node_param_map.get(node)); auto result = std::dynamic_pointer_cast<op::Result>(fprop_cache.node_param_map->get(node));
if (!result) if (!result)
{ {
throw ngraph_error("Expected op::Result values for op::Result keys in node_param_map"); throw ngraph_error("Expected op::Result values for op::Result keys in node_param_map");
...@@ -281,14 +264,14 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -281,14 +264,14 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
for (auto param : adjoints) for (auto param : adjoints)
{ {
bprop_input_params.push_back( bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(node_param_map.get(param))); std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(param)));
} }
// add the cached fprop nodes as inputs to bprop // add the cached fprop nodes as inputs to bprop
for (auto x : fprop_cache.fprop_output_nodes) for (auto x : fprop_cache.fprop_output_nodes)
{ {
bprop_input_params.push_back( bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(node_param_map.get(x))); std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(x)));
} }
// create the new bprop function // create the new bprop function
......
...@@ -28,6 +28,7 @@ namespace ngraph ...@@ -28,6 +28,7 @@ namespace ngraph
{ {
class Node; class Node;
class Function; class Function;
class NodeMap;
class stopwatch; class stopwatch;
namespace runtime namespace runtime
...@@ -229,6 +230,7 @@ namespace ngraph ...@@ -229,6 +230,7 @@ namespace ngraph
std::shared_ptr<Function> fprop; std::shared_ptr<Function> fprop;
std::shared_ptr<Function> bprop; std::shared_ptr<Function> bprop;
std::vector<std::shared_ptr<Node>> fprop_output_nodes; std::vector<std::shared_ptr<Node>> fprop_output_nodes;
std::shared_ptr<NodeMap> node_param_map;
}; };
/** /**
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment