Commit eaa85e1c authored by Matthew Brookhart's avatar Matthew Brookhart Committed by Scott Cyphers

add some fixes to speed up RN50 training (#1840)

parent ab28a4a1
......@@ -45,9 +45,17 @@ std::shared_ptr<Node> make_zero(const std::shared_ptr<Node>& node)
NodeVector make_zeros(std::shared_ptr<Node> x)
{
NodeVector zeros;
for (size_t i = 0; i < x->get_outputs().size(); ++i)
if (x->get_outputs().size() > 1)
{
zeros.push_back(make_zero(get_output_element(x, i)));
auto goes = op::get_output_elements(x);
for (size_t i = 0; i < goes.size(); ++i)
{
zeros.push_back(make_zero(goes.at(i)));
}
}
else
{
zeros.push_back(make_zero(x));
}
return zeros;
}
......
......@@ -48,6 +48,6 @@ void op::Relu::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{
auto delta = deltas.at(0);
auto backprop = make_shared<op::ReluBackprop>(get_argument(0), delta);
auto backprop = make_shared<op::ReluBackprop>(shared_from_this(), delta);
adjoints.add_delta(get_argument(0), backprop);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment