Commit e349b31d authored by Adam Rogowiec's avatar Adam Rogowiec

Use sigmoid definition resistant to overflow in exponent function.

parent 01a944a5
...@@ -57,28 +57,28 @@ namespace ngraph ...@@ -57,28 +57,28 @@ namespace ngraph
{ {
case 0 /*Logistic|Logistic*/: case 0 /*Logistic|Logistic*/:
{ {
auto c = (in0.exp() * in1.exp()) / ((in0.exp() + 1.f) * (in1.exp() + 1.f)); auto c = 1.f / (((-in0).exp() + 1.f) * ((-in1).exp() + 1.f));
out_tm.device( out_tm.device(
ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = c; ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = c;
} }
break; break;
case 1 /*Logistic|Tanh*/: case 1 /*Logistic|Tanh*/:
{ {
auto c = (in0.exp() * in1.tanh()) / (in0.exp() + 1.f); auto c = in1.tanh() / ((-in0).exp() + 1.f);
out_tm.device( out_tm.device(
ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = c; ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = c;
} }
break; break;
case 2 /*Logistic|Identity*/: case 2 /*Logistic|Identity*/:
{ {
auto c = (in0.exp() * in1) / (in0.exp() + 1.f); auto c = in1 / ((-in0).exp() + 1.f);
out_tm.device( out_tm.device(
ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = c; ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = c;
} }
break; break;
case 3 /*Tanh|Logistic*/: case 3 /*Tanh|Logistic*/:
{ {
auto c = (in0.tanh() * in1.exp()) / (in1.exp() + 1.f); auto c = in0.tanh() / ((-in1).exp() + 1.f);
out_tm.device( out_tm.device(
ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = c; ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = c;
} }
...@@ -99,7 +99,7 @@ namespace ngraph ...@@ -99,7 +99,7 @@ namespace ngraph
break; break;
case 6 /*Identity|Logistic*/: case 6 /*Identity|Logistic*/:
{ {
auto c = (in0 * in1.exp()) / (in1.exp() + 1.f); auto c = in0 / ((-in1).exp() + 1.f);
out_tm.device( out_tm.device(
ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = c; ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = c;
} }
...@@ -141,10 +141,12 @@ namespace ngraph ...@@ -141,10 +141,12 @@ namespace ngraph
{ {
case 0 /*Logistic|Logistic*/: case 0 /*Logistic|Logistic*/:
{ {
auto i0 = delta * (in1.exp() * in0.exp()) / auto i0 =
((in1.exp() + 1.f) * ((in0.exp() + 1.f) * (in0.exp() + 1.f))); delta * (-in0).exp() /
auto i1 = delta * (in0.exp() * in1.exp()) / (((-in1).exp() + 1.f) * (((-in0).exp() + 1.f) * ((-in0).exp() + 1.f)));
((in0.exp() + 1.f) * ((in1.exp() + 1.f) * (in1.exp() + 1.f))); auto i1 =
delta * (-in1).exp() /
(((-in0).exp() + 1.f) * (((-in1).exp() + 1.f) * ((-in1).exp() + 1.f)));
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0; arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
...@@ -153,11 +155,11 @@ namespace ngraph ...@@ -153,11 +155,11 @@ namespace ngraph
break; break;
case 1 /*Logistic|Tanh*/: case 1 /*Logistic|Tanh*/:
{ {
auto i0 = auto i0 = delta * (((in1 * 2.f).exp() - 1.f) * (-in0).exp()) /
delta * (((in1 * 2.f).exp() - 1.f) * in0.exp()) / (((in1 * 2.f).exp() + 1.f) *
(((in1 * 2.f).exp() + 1.f) * ((in0.exp() + 1.f) * (in0.exp() + 1.f))); (((-in0).exp() + 1.f) * ((-in0).exp() + 1.f)));
auto i1 = delta * (in0.exp() * (4.f * (in1 * 2.f).exp())) / auto i1 = delta * (4.f * (in1 * 2.f).exp()) /
((in0.exp() + 1.f) * (((-in0).exp() + 1.f) *
(((in1 * 2.f).exp() + 1.f) * ((in1 * 2.f).exp() + 1.f))); (((in1 * 2.f).exp() + 1.f) * ((in1 * 2.f).exp() + 1.f)));
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0; arena)) = i0;
...@@ -167,9 +169,9 @@ namespace ngraph ...@@ -167,9 +169,9 @@ namespace ngraph
break; break;
case 2 /*Logistic|Identity*/: case 2 /*Logistic|Identity*/:
{ {
auto i0 = auto i0 = delta * (in1 * (-in0).exp()) /
delta * (in1 * in0.exp()) / ((in0.exp() + 1.f) * (in0.exp() + 1.f)); (((-in0).exp() + 1.f) * ((-in0).exp() + 1.f));
auto i1 = delta * in0.exp() / ((in0.exp() + 1.f)); auto i1 = delta / (((-in0).exp() + 1.f));
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0; arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
...@@ -178,8 +180,8 @@ namespace ngraph ...@@ -178,8 +180,8 @@ namespace ngraph
break; break;
case 3 /*Tanh|Logistic*/: case 3 /*Tanh|Logistic*/:
{ {
auto i0 = delta * (in1.exp() * (4.f * (in0 * 2.f).exp())) / auto i0 = delta * (4.f * (in0 * 2.f).exp()) /
((in1.exp() + 1.f) * ((in0 * 2.f).exp() + 1.f) * (((-in1).exp() + 1.f) * ((in0 * 2.f).exp() + 1.f) *
((in0 * 2.f).exp() + 1.f)); ((in0 * 2.f).exp() + 1.f));
auto i1 = auto i1 =
delta * (((in0 * 2.f).exp() - 1.f) * in1.exp()) / delta * (((in0 * 2.f).exp() - 1.f) * in1.exp()) /
...@@ -217,9 +219,9 @@ namespace ngraph ...@@ -217,9 +219,9 @@ namespace ngraph
break; break;
case 6 /*Identity|Logistic*/: case 6 /*Identity|Logistic*/:
{ {
auto i0 = delta * (in1.exp()) / (in1.exp() + 1.f); auto i0 = delta * 1.f / ((-in1).exp() + 1.f);
auto i1 = auto i1 = delta * (in0 * (-in1).exp()) /
delta * (in0 * in1.exp()) / ((in1.exp() + 1.f) * (in1.exp() + 1.f)); (((-in1).exp() + 1.f) * ((-in1).exp() + 1.f));
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0; arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment