Commit b1da7bc0 authored by Adam Rogowiec's avatar Adam Rogowiec

Cache intermediate calcucations results.

parent 3fab85bb
...@@ -141,12 +141,15 @@ namespace ngraph ...@@ -141,12 +141,15 @@ namespace ngraph
{ {
case 0 /*Logistic|Logistic*/: case 0 /*Logistic|Logistic*/:
{ {
auto i0 = auto in0_neg_exp = (-in0).exp();
delta * (-in0).exp() / auto in0_log_denominator = in0_neg_exp + 1.f;
(((-in1).exp() + 1.f) * (((-in0).exp() + 1.f) * ((-in0).exp() + 1.f))); auto in1_neg_exp = (-in1).exp();
auto i1 = auto in1_log_denominator = in1_neg_exp + 1.f;
delta * (-in1).exp() /
(((-in0).exp() + 1.f) * (((-in1).exp() + 1.f) * ((-in1).exp() + 1.f))); auto i0 = delta * in0_neg_exp /
(in1_log_denominator * in0_log_denominator * in0_log_denominator);
auto i1 = delta * in1_neg_exp /
(in0_log_denominator * in1_log_denominator * in1_log_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0; arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
...@@ -155,12 +158,17 @@ namespace ngraph ...@@ -155,12 +158,17 @@ namespace ngraph
break; break;
case 1 /*Logistic|Tanh*/: case 1 /*Logistic|Tanh*/:
{ {
auto i0 = delta * (((in1 * 2.f).exp() - 1.f) * (-in0).exp()) / auto in0_neg_exp = (-in0).exp();
(((in1 * 2.f).exp() + 1.f) * auto in0_log_denominator = in0_neg_exp + 1.f;
(((-in0).exp() + 1.f) * ((-in0).exp() + 1.f))); auto in1_2exp = (in1 * 2.f).exp();
auto i1 = delta * (4.f * (in1 * 2.f).exp()) / auto in1_tanh_denominator = in1_2exp + 1.f;
(((-in0).exp() + 1.f) *
(((in1 * 2.f).exp() + 1.f) * ((in1 * 2.f).exp() + 1.f))); auto i0 = delta * ((in1_2exp - 1.f) * in0_neg_exp) /
(in1_tanh_denominator *
in0_log_denominator * in0_log_denominator);
auto i1 = delta * (4.f * in1_2exp) /
(in0_log_denominator *
in1_tanh_denominator * in1_tanh_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0; arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
...@@ -169,9 +177,12 @@ namespace ngraph ...@@ -169,9 +177,12 @@ namespace ngraph
break; break;
case 2 /*Logistic|Identity*/: case 2 /*Logistic|Identity*/:
{ {
auto i0 = delta * (in1 * (-in0).exp()) / auto in0_neg_exp = (-in0).exp();
(((-in0).exp() + 1.f) * ((-in0).exp() + 1.f)); auto in0_log_denominator = in0_neg_exp + 1.f;
auto i1 = delta / (((-in0).exp() + 1.f));
auto i0 = delta * (in1 * in0_neg_exp) /
(in0_log_denominator * in0_log_denominator);
auto i1 = delta / in0_log_denominator;
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0; arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
...@@ -180,12 +191,17 @@ namespace ngraph ...@@ -180,12 +191,17 @@ namespace ngraph
break; break;
case 3 /*Tanh|Logistic*/: case 3 /*Tanh|Logistic*/:
{ {
auto i0 = delta * (4.f * (in0 * 2.f).exp()) / auto in0_2exp = (in0 * 2.f).exp();
(((-in1).exp() + 1.f) * ((in0 * 2.f).exp() + 1.f) * auto in0_tanh_denominator = in0_2exp + 1.f;
((in0 * 2.f).exp() + 1.f)); auto in1_neg_exp = (-in1).exp();
auto in1_log_denominator = in1_neg_exp + 1.f;
auto i0 =
delta * (4.f * in0_2exp) /
(in1_log_denominator * in0_tanh_denominator * in0_tanh_denominator);
auto i1 = auto i1 =
delta * (((in0 * 2.f).exp() - 1.f) * in1.exp()) / delta * ((in0_2exp - 1.f) * in1_neg_exp) /
(((in0 * 2.f).exp() + 1.f) * ((in1.exp() + 1.f) * (in1.exp() + 1.f))); (in0_tanh_denominator * in1_log_denominator * in1_log_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0; arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
...@@ -194,12 +210,17 @@ namespace ngraph ...@@ -194,12 +210,17 @@ namespace ngraph
break; break;
case 4 /*Tanh|Tanh*/: case 4 /*Tanh|Tanh*/:
{ {
auto i0 = delta * (((in1 * 2.f).exp() - 1.f) * (4.f * (in0 * 2.f).exp())) / auto in0_2exp = (in0 * 2.f).exp();
(((in1 * 2.f).exp() + 1.f) * auto in0_tanh_denominator = in0_2exp + 1.f;
(((in0 * 2.f).exp() + 1.f) * ((in0 * 2.f).exp() + 1.f))); auto in1_2exp = (in1 * 2.f).exp();
auto i1 = delta * (((in0 * 2.f).exp() - 1.f) * (4.f * (in1 * 2.f).exp())) / auto in1_tanh_denominator = in1_2exp + 1.f;
(((in0 * 2.f).exp() + 1.f) *
(((in1 * 2.f).exp() + 1.f) * ((in1 * 2.f).exp() + 1.f))); auto i0 =
delta * (in1_2exp - 1.f) * 4.f * in0_2exp /
(in1_tanh_denominator * in0_tanh_denominator * in0_tanh_denominator);
auto i1 =
delta * (in0_2exp - 1.f) * 4.f * in1_2exp /
(in0_tanh_denominator * in1_tanh_denominator * in1_tanh_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0; arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
...@@ -208,9 +229,12 @@ namespace ngraph ...@@ -208,9 +229,12 @@ namespace ngraph
break; break;
case 5 /*Tanh|Identity*/: case 5 /*Tanh|Identity*/:
{ {
auto i0 = delta * (in1 * (4.f * (in0 * 2.f).exp())) / auto in0_2exp = (in0 * 2.f).exp();
(((in0 * 2.f).exp() + 1.f) * ((in0 * 2.f).exp() + 1.f)); auto in0_tanh_denominator = in0_2exp + 1.f;
auto i1 = delta * ((in0 * 2.f).exp() - 1.f) / ((in0 * 2.f).exp() + 1.f);
auto i0 = delta * in1 * 4.f * in0_2exp /
(in0_tanh_denominator * in0_tanh_denominator);
auto i1 = delta * (in0_2exp - 1.f) / in0_tanh_denominator;
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0; arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
...@@ -219,9 +243,12 @@ namespace ngraph ...@@ -219,9 +243,12 @@ namespace ngraph
break; break;
case 6 /*Identity|Logistic*/: case 6 /*Identity|Logistic*/:
{ {
auto i0 = delta * 1.f / ((-in1).exp() + 1.f); auto in1_neg_exp = (-in1).exp();
auto i1 = delta * (in0 * (-in1).exp()) / auto in1_log_denominator = in1_neg_exp + 1.f;
(((-in1).exp() + 1.f) * ((-in1).exp() + 1.f));
auto i0 = delta * 1.f / in1_log_denominator;
auto i1 =
delta * in0 * in1_neg_exp / (in1_log_denominator * in1_log_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0; arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
...@@ -230,9 +257,12 @@ namespace ngraph ...@@ -230,9 +257,12 @@ namespace ngraph
break; break;
case 7 /*Identity|Tanh*/: case 7 /*Identity|Tanh*/:
{ {
auto i0 = delta * ((in1 * 2.f).exp() - 1.f) / ((in1 * 2.f).exp() + 1.f); auto in1_2exp = (in1 * 2.f).exp();
auto i1 = delta * (in0 * (4.f * (in1 * 2.f).exp())) / auto in1_tanh_denominator = in1_2exp + 1.f;
(((in1 * 2.f).exp() + 1.f) * ((in1 * 2.f).exp() + 1.f));
auto i0 = delta * (in1_2exp - 1.f) / in1_tanh_denominator;
auto i1 = delta * (in0 * (4.f * in1_2exp)) /
(in1_tanh_denominator * in1_tanh_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0; arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device( i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment