Commit 804e381a authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Update some incorrect comments in Dot::generate_adjoints (#2045)

parent 40bcfdf7
...@@ -173,11 +173,11 @@ void op::Dot::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -173,11 +173,11 @@ void op::Dot::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
J_shape.insert(J_shape.begin(), y_shape.begin(), y_shape.begin() + m_reduction_axes_count); J_shape.insert(J_shape.begin(), y_shape.begin(), y_shape.begin() + m_reduction_axes_count);
K_shape.insert(K_shape.begin(), y_shape.begin() + J_shape.size(), y_shape.end()); K_shape.insert(K_shape.begin(), y_shape.begin() + J_shape.size(), y_shape.end());
auto y_reshaped = make_reshape_axes_to_front(y, J_shape, K_shape); // KI auto y_reshaped = make_reshape_axes_to_front(y, J_shape, K_shape); // KJ
auto delta_dot_y_reshaped = make_shared<Dot>(delta, y_reshaped, K_shape.size()); // JI auto delta_dot_y_reshaped = make_shared<Dot>(delta, y_reshaped, K_shape.size()); // IK.KJ->IJ
adjoints.add_delta(x, delta_dot_y_reshaped); adjoints.add_delta(x, delta_dot_y_reshaped);
auto x_reshaped = make_reshape_axes_to_front(x, I_shape, J_shape); // JI auto x_reshaped = make_reshape_axes_to_front(x, I_shape, J_shape); // JI
auto x_reshaped_dot_delta = make_shared<Dot>(x_reshaped, delta, I_shape.size()); // JK auto x_reshaped_dot_delta = make_shared<Dot>(x_reshaped, delta, I_shape.size()); // JI.IK->JK
adjoints.add_delta(y, x_reshaped_dot_delta); adjoints.add_delta(y, x_reshaped_dot_delta);
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment