Commit 11b992a7 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

bn bprop test fix, comments and throws (#1325)

parent 0599a628
......@@ -272,6 +272,14 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
auto goes = op::get_output_elements(shared_from_this());
mean = goes.at(1);
var = goes.at(2);
if (!mean)
{
throw ngraph_error("GetOutputElement for mean is missing");
};
if (!var)
{
throw ngraph_error("GetOutputElement for variance is missing");
}
}
else // BatchNorm Training with global stats
{
......
......@@ -56,12 +56,12 @@ void op::GetOutputElement::generate_adjoints(autodiff::Adjoints& adjoints, const
NodeVector op::get_output_elements(const shared_ptr<Node>& mon)
{
NodeVector goes;
NodeVector goes(mon->get_outputs().size());
for (size_t i = 0; i < mon->get_outputs().size(); i++)
for (auto goe_input : mon->get_output_inputs(0))
{
auto goe = make_shared<GetOutputElement>(mon, i);
goes.push_back(std::static_pointer_cast<Node>(goe));
auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(goe_input->get_node());
goes.at(goe->get_n()) = goe_input->get_node();
}
return goes;
}
......@@ -1589,12 +1589,20 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_three_outputs)
auto shape_in = Shape{2, 3, 1, 1};
auto shape_mean = Shape{3};
auto make_graph = [shape_in, shape_mean] {
//we need to keep GOEs for mean and variance alive
//even though those aren't used as outputs for fprop
//they are needed for a bprop pass
NodeVector goes;
auto make_graph = [&goes, shape_in, shape_mean] {
auto A = make_shared<op::Parameter>(element::f64, shape_in);
auto B = make_shared<op::Parameter>(element::f64, shape_mean);
auto C = make_shared<op::Parameter>(element::f64, shape_mean);
auto BN = make_shared<op::BatchNorm>(1e-3, B, C, A);
//make sure we create GOEs for mean and variance needed for bprop
goes.push_back(make_shared<op::GetOutputElement>(BN, 1));
goes.push_back(make_shared<op::GetOutputElement>(BN, 2));
auto f = make_shared<Function>(make_shared<op::GetOutputElement>(BN, 0),
op::ParameterVector{A, B, C});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment