Commit ca06e6c3 authored by Matthew Brookhart's avatar Matthew Brookhart Committed by Scott Cyphers

rework autodiff of tanh for stability on LC (#591)

parent 22347363
...@@ -15,16 +15,13 @@ ...@@ -15,16 +15,13 @@
*******************************************************************************/ *******************************************************************************/
#include "ngraph/ops/tanh.hpp" #include "ngraph/ops/tanh.hpp"
#include "ngraph/ops/cosh.hpp"
#include "ngraph/ops/divide.hpp"
#include "ngraph/ops/multiply.hpp" #include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/subtract.hpp"
void ngraph::op::Tanh::generate_adjoints(autodiff::Adjoints& adjoints, void ngraph::op::Tanh::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) const std::shared_ptr<Node>& delta)
{ {
auto x = get_input_op(0); auto x = get_input_op(0);
auto c = std::make_shared<op::Cosh>(x); adjoints.add_delta(x, delta - (delta * (shared_from_this() * shared_from_this())));
adjoints.add_delta(x, delta / (c * c));
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment