d_K[tid * ns * no + i * k_stride] = d_P[tid * ns * ns + i * p_stride] * (1 / d_S[tid * no * no + i * s_stride]);
}
else{
d_K[tid * ns * no + no * no + (i - no) * k_stride] = d_P[tid * ns * ns + ns * no + (i - no) * p_stride] * (1 / d_S[tid * no * no + (i - no) * s_stride]);