Commit f6c6daef authored by Matthew Brookhart's avatar Matthew Brookhart Committed by Scott Cyphers

Simplify Dot Autodiff (#412)

* Simplify Dot Autodiff

* remove commented code
parent 72a2ce72
......@@ -133,11 +133,9 @@ void op::Dot::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_
J_shape.insert(J_shape.begin(), y_shape.begin(), y_shape.begin() + m_reduction_axes_count);
K_shape.insert(K_shape.begin(), y_shape.begin() + J_shape.size(), y_shape.end());
auto delta_reshaped = make_reshape_axes_to_front(delta, I_shape, K_shape); // KI
auto delta_reshaped_dot_y = make_shared<Dot>(y, delta_reshaped, K_shape.size()); // JI
auto delta_reshaped_dot_y_reshaped =
make_reshape_axes_to_front(delta_reshaped_dot_y, J_shape, I_shape); // IJ
adjoints.add_delta(x, delta_reshaped_dot_y_reshaped);
auto y_reshaped = make_reshape_axes_to_front(y, J_shape, K_shape); // KI
auto delta_dot_y_reshaped = make_shared<Dot>(delta, y_reshaped, K_shape.size()); // JI
adjoints.add_delta(x, delta_dot_y_reshaped);
auto x_reshaped = make_reshape_axes_to_front(x, I_shape, J_shape); // JI
auto x_reshaped_dot_delta = make_shared<Dot>(x_reshaped, delta, I_shape.size()); // JK
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment