Commit 11b992a7 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

bn bprop test fix, comments and throws (#1325)

parent 0599a628
...@@ -272,6 +272,14 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -272,6 +272,14 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
auto goes = op::get_output_elements(shared_from_this()); auto goes = op::get_output_elements(shared_from_this());
mean = goes.at(1); mean = goes.at(1);
var = goes.at(2); var = goes.at(2);
if (!mean)
{
throw ngraph_error("GetOutputElement for mean is missing");
};
if (!var)
{
throw ngraph_error("GetOutputElement for variance is missing");
}
} }
else // BatchNorm Training with global stats else // BatchNorm Training with global stats
{ {
......
...@@ -56,12 +56,12 @@ void op::GetOutputElement::generate_adjoints(autodiff::Adjoints& adjoints, const ...@@ -56,12 +56,12 @@ void op::GetOutputElement::generate_adjoints(autodiff::Adjoints& adjoints, const
NodeVector op::get_output_elements(const shared_ptr<Node>& mon) NodeVector op::get_output_elements(const shared_ptr<Node>& mon)
{ {
NodeVector goes; NodeVector goes(mon->get_outputs().size());
for (size_t i = 0; i < mon->get_outputs().size(); i++) for (auto goe_input : mon->get_output_inputs(0))
{ {
auto goe = make_shared<GetOutputElement>(mon, i); auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(goe_input->get_node());
goes.push_back(std::static_pointer_cast<Node>(goe)); goes.at(goe->get_n()) = goe_input->get_node();
} }
return goes; return goes;
} }
...@@ -1589,12 +1589,20 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_three_outputs) ...@@ -1589,12 +1589,20 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_three_outputs)
auto shape_in = Shape{2, 3, 1, 1}; auto shape_in = Shape{2, 3, 1, 1};
auto shape_mean = Shape{3}; auto shape_mean = Shape{3};
auto make_graph = [shape_in, shape_mean] { //we need to keep GOEs for mean and variance alive
//even though those aren't used as outputs for fprop
//they are needed for a bprop pass
NodeVector goes;
auto make_graph = [&goes, shape_in, shape_mean] {
auto A = make_shared<op::Parameter>(element::f64, shape_in); auto A = make_shared<op::Parameter>(element::f64, shape_in);
auto B = make_shared<op::Parameter>(element::f64, shape_mean); auto B = make_shared<op::Parameter>(element::f64, shape_mean);
auto C = make_shared<op::Parameter>(element::f64, shape_mean); auto C = make_shared<op::Parameter>(element::f64, shape_mean);
auto BN = make_shared<op::BatchNorm>(1e-3, B, C, A); auto BN = make_shared<op::BatchNorm>(1e-3, B, C, A);
//make sure we create GOEs for mean and variance needed for bprop
goes.push_back(make_shared<op::GetOutputElement>(BN, 1));
goes.push_back(make_shared<op::GetOutputElement>(BN, 2));
auto f = make_shared<Function>(make_shared<op::GetOutputElement>(BN, 0), auto f = make_shared<Function>(make_shared<op::GetOutputElement>(BN, 0),
op::ParameterVector{A, B, C}); op::ParameterVector{A, B, C});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment