Commit eaa85e1c authored by Matthew Brookhart, committed by Scott Cyphers

add some fixes to speed up RN50 training (#1840)

parent ab28a4a1
@@ -45,9 +45,17 @@ std::shared_ptr<Node> make_zero(const std::shared_ptr<Node>& node)
 NodeVector make_zeros(std::shared_ptr<Node> x)
 {
     NodeVector zeros;
-    for (size_t i = 0; i < x->get_outputs().size(); ++i)
+    if (x->get_outputs().size() > 1)
     {
-        zeros.push_back(make_zero(get_output_element(x, i)));
+        auto goes = op::get_output_elements(x);
+        for (size_t i = 0; i < goes.size(); ++i)
+        {
+            zeros.push_back(make_zero(goes.at(i)));
+        }
+    }
+    else
+    {
+        zeros.push_back(make_zero(x));
     }
     return zeros;
 }
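For context: the old loop wrapped every node, even a single-output one, in a fresh GetOutputElement before zero-filling it; the new code takes a fast path for single-output nodes and reuses the existing output elements (via op::get_output_elements) for multi-output nodes, so fewer redundant nodes are built during backprop graph construction. Below is a minimal stand-alone sketch of the same "single-output fast path" pattern; ToyNode, ToyTensor, and the size-based make_zero are simplified stand-ins for illustration, not the nGraph API.

// Illustration of the pattern in the hunk above, using toy types.
#include <cstddef>
#include <memory>
#include <vector>

struct ToyNode
{
    std::vector<std::size_t> output_sizes; // one entry per output
};

using ToyTensor = std::vector<double>;

// Build a zero-filled tensor of the given size.
ToyTensor make_zero(std::size_t size)
{
    return ToyTensor(size, 0.0);
}

std::vector<ToyTensor> make_zeros(const std::shared_ptr<ToyNode>& x)
{
    std::vector<ToyTensor> zeros;
    if (x->output_sizes.size() > 1)
    {
        // Multi-output node: one zero tensor per output.
        for (std::size_t size : x->output_sizes)
        {
            zeros.push_back(make_zero(size));
        }
    }
    else
    {
        // Single-output fast path: no per-output indirection needed.
        zeros.push_back(make_zero(x->output_sizes.at(0)));
    }
    return zeros;
}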
@@ -48,6 +48,6 @@ void op::Relu::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
 {
     auto delta = deltas.at(0);
-    auto backprop = make_shared<op::ReluBackprop>(get_argument(0), delta);
+    auto backprop = make_shared<op::ReluBackprop>(shared_from_this(), delta);
     adjoints.add_delta(get_argument(0), backprop);
 }
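This hunk switches ReluBackprop to take the Relu's forward output (shared_from_this()) instead of its input (get_argument(0)). The gradient is unchanged, since relu(x) > 0 exactly when x > 0, so the backward mask can be read off the already-computed activation; presumably this lets the backward op depend on the forward output rather than keeping the input live, which is where the speedup would come from. The sketch below only demonstrates the equivalence of the two formulations; the plain std::vector tensors and function names are stand-ins, not nGraph code.

// Equivalence check: ReLU backprop from input x vs. from output y = relu(x).
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

std::vector<float> relu(const std::vector<float>& x)
{
    std::vector<float> y(x.size());
    std::transform(x.begin(), x.end(), y.begin(),
                   [](float v) { return std::max(v, 0.0f); });
    return y;
}

// Backprop driven by the forward input x.
std::vector<float> relu_backprop_from_input(const std::vector<float>& x,
                                            const std::vector<float>& delta)
{
    std::vector<float> dx(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
    {
        dx[i] = x[i] > 0.0f ? delta[i] : 0.0f;
    }
    return dx;
}

// Backprop driven by the forward output y = relu(x): same mask, same result.
std::vector<float> relu_backprop_from_output(const std::vector<float>& y,
                                             const std::vector<float>& delta)
{
    std::vector<float> dx(y.size());
    for (std::size_t i = 0; i < y.size(); ++i)
    {
        dx[i] = y[i] > 0.0f ? delta[i] : 0.0f;
    }
    return dx;
}

int main()
{
    std::vector<float> x = {-1.5f, 0.0f, 2.0f};
    std::vector<float> delta = {1.0f, 1.0f, 1.0f};
    assert(relu_backprop_from_input(x, delta) ==
           relu_backprop_from_output(relu(x), delta));
    return 0;
}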