Unverified Commit c11644ec authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Fix numeric instability in batchnorm bprop (#2246)

* Fix numeric instability in batchnorm bprop

* Another instability
parent b7f097ec
...@@ -155,12 +155,10 @@ namespace ngraph ...@@ -155,12 +155,10 @@ namespace ngraph
// gamma[C] // gamma[C]
// beta[C] // beta[C]
// mu[c:C] = sum(input[., c, ...])/elements_per_channel // mu[c:C] = sum(input[., c, ...])/elements_per_channel
// centered[., c:C, ...] = input[., c, ...] - mu[c] // var[c:C] = sum((input[., c, ...] - mu[c])^2)/elements_per_channel
// square[., c:C, ...] = centered[., c, ...]^2
// var[c:C] = sum(centered[., c, ...]^2)/elements_per_channel
// inv_sqrt[c:C] = 1/sqrt(var[c]+epsilon) // inv_sqrt[c:C] = 1/sqrt(var[c]+epsilon)
// gammad[c:C] = gamma[c]*inv_sqrt[c] // gammad[c:C] = gamma[c]*inv_sqrt[c]
// normed[., c:C, ...] = centered[., c, ...]*gammad[c]+beta[c] // normed[., c:C, ...] = (input[., c, ...]-mu[c])*gammad[c]+beta[c]
for (auto c = 0; c < num_channels; ++c) for (auto c = 0; c < num_channels; ++c)
{ {
...@@ -182,9 +180,8 @@ namespace ngraph ...@@ -182,9 +180,8 @@ namespace ngraph
auto idx = input_transform.index(input_coord); auto idx = input_transform.index(input_coord);
auto delta_idx = delta_normed[idx]; auto delta_idx = delta_normed[idx];
auto input_idx = input[idx]; auto input_idx = input[idx];
auto centered = input_idx - mu;
delta_beta_sum += delta_idx; delta_beta_sum += delta_idx;
delta_gammad += centered * delta_idx; delta_gammad += (input_idx - mu) * delta_idx;
T delta_centered = gammad * delta_idx; T delta_centered = gammad * delta_idx;
delta_input[idx] = delta_centered; delta_input[idx] = delta_centered;
delta_mu -= delta_centered; delta_mu -= delta_centered;
...@@ -192,22 +189,21 @@ namespace ngraph ...@@ -192,22 +189,21 @@ namespace ngraph
delta_beta[c] = delta_beta_sum; delta_beta[c] = delta_beta_sum;
delta_gamma[c] = delta_gammad * inv_sqrt_var_eps; delta_gamma[c] = delta_gammad * inv_sqrt_var_eps;
T delta_inv_sqrt = gamma[c] * delta_gammad; T delta_inv_sqrt = gamma[c] * delta_gammad;
// y = x^(-1/2)
// dy = -(1/2)x^(-3/2) dx = -y/(2x) dx
T delta_var = -delta_inv_sqrt * inv_sqrt_var_eps / (2 * var_eps); T delta_var = -delta_inv_sqrt * inv_sqrt_var_eps / (2 * var_eps);
T delta_two_var_sum = 2 * delta_var / elements_per_channel; T delta_two_var_sum = 2 * delta_var / elements_per_channel;
for (Coordinate input_coord : input_transform)
{
auto idx = input_transform.index(input_coord);
auto two_centered = (input[idx] - mu) * delta_two_var_sum;
delta_input[idx] += two_centered;
delta_mu -= two_centered;
}
T delta_mu_over_n = delta_mu / elements_per_channel; T delta_mu_over_n = delta_mu / elements_per_channel;
for (Coordinate input_coord : input_transform) for (Coordinate input_coord : input_transform)
{ {
// v = 1/N sum(x_i - mu)^2
// dv = 2/N sum[(x_i - mu)dx_i] - 2/N sum[(x_i - mu) dmu]
// = 2/N sum[(x_i - mu)dx_i] - 2/N (Nmu-Nmu) dmu
// = 2/N sum[(x_i - mu)dx_i]
auto idx = input_transform.index(input_coord); auto idx = input_transform.index(input_coord);
delta_input[idx] += delta_mu_over_n; // These two values mostly cancel out so add them first
auto val = delta_input[idx] + delta_mu_over_n;
delta_input[idx] = val + (input[idx] - mu) * delta_two_var_sum;
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment