Commit b1da7bc0 authored by Adam Rogowiec's avatar Adam Rogowiec

Cache intermediate calculation results.

parent 3fab85bb
......@@ -141,12 +141,15 @@ namespace ngraph
{
case 0 /*Logistic|Logistic*/:
{
auto i0 =
delta * (-in0).exp() /
(((-in1).exp() + 1.f) * (((-in0).exp() + 1.f) * ((-in0).exp() + 1.f)));
auto i1 =
delta * (-in1).exp() /
(((-in0).exp() + 1.f) * (((-in1).exp() + 1.f) * ((-in1).exp() + 1.f)));
auto in0_neg_exp = (-in0).exp();
auto in0_log_denominator = in0_neg_exp + 1.f;
auto in1_neg_exp = (-in1).exp();
auto in1_log_denominator = in1_neg_exp + 1.f;
auto i0 = delta * in0_neg_exp /
(in1_log_denominator * in0_log_denominator * in0_log_denominator);
auto i1 = delta * in1_neg_exp /
(in0_log_denominator * in1_log_denominator * in1_log_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
......@@ -155,12 +158,17 @@ namespace ngraph
break;
case 1 /*Logistic|Tanh*/:
{
auto i0 = delta * (((in1 * 2.f).exp() - 1.f) * (-in0).exp()) /
(((in1 * 2.f).exp() + 1.f) *
(((-in0).exp() + 1.f) * ((-in0).exp() + 1.f)));
auto i1 = delta * (4.f * (in1 * 2.f).exp()) /
(((-in0).exp() + 1.f) *
(((in1 * 2.f).exp() + 1.f) * ((in1 * 2.f).exp() + 1.f)));
auto in0_neg_exp = (-in0).exp();
auto in0_log_denominator = in0_neg_exp + 1.f;
auto in1_2exp = (in1 * 2.f).exp();
auto in1_tanh_denominator = in1_2exp + 1.f;
auto i0 = delta * ((in1_2exp - 1.f) * in0_neg_exp) /
(in1_tanh_denominator *
in0_log_denominator * in0_log_denominator);
auto i1 = delta * (4.f * in1_2exp) /
(in0_log_denominator *
in1_tanh_denominator * in1_tanh_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
......@@ -169,9 +177,12 @@ namespace ngraph
break;
case 2 /*Logistic|Identity*/:
{
auto i0 = delta * (in1 * (-in0).exp()) /
(((-in0).exp() + 1.f) * ((-in0).exp() + 1.f));
auto i1 = delta / (((-in0).exp() + 1.f));
auto in0_neg_exp = (-in0).exp();
auto in0_log_denominator = in0_neg_exp + 1.f;
auto i0 = delta * (in1 * in0_neg_exp) /
(in0_log_denominator * in0_log_denominator);
auto i1 = delta / in0_log_denominator;
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
......@@ -180,12 +191,17 @@ namespace ngraph
break;
case 3 /*Tanh|Logistic*/:
{
auto i0 = delta * (4.f * (in0 * 2.f).exp()) /
(((-in1).exp() + 1.f) * ((in0 * 2.f).exp() + 1.f) *
((in0 * 2.f).exp() + 1.f));
auto in0_2exp = (in0 * 2.f).exp();
auto in0_tanh_denominator = in0_2exp + 1.f;
auto in1_neg_exp = (-in1).exp();
auto in1_log_denominator = in1_neg_exp + 1.f;
auto i0 =
delta * (4.f * in0_2exp) /
(in1_log_denominator * in0_tanh_denominator * in0_tanh_denominator);
auto i1 =
delta * (((in0 * 2.f).exp() - 1.f) * in1.exp()) /
(((in0 * 2.f).exp() + 1.f) * ((in1.exp() + 1.f) * (in1.exp() + 1.f)));
delta * ((in0_2exp - 1.f) * in1_neg_exp) /
(in0_tanh_denominator * in1_log_denominator * in1_log_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
......@@ -194,12 +210,17 @@ namespace ngraph
break;
case 4 /*Tanh|Tanh*/:
{
auto i0 = delta * (((in1 * 2.f).exp() - 1.f) * (4.f * (in0 * 2.f).exp())) /
(((in1 * 2.f).exp() + 1.f) *
(((in0 * 2.f).exp() + 1.f) * ((in0 * 2.f).exp() + 1.f)));
auto i1 = delta * (((in0 * 2.f).exp() - 1.f) * (4.f * (in1 * 2.f).exp())) /
(((in0 * 2.f).exp() + 1.f) *
(((in1 * 2.f).exp() + 1.f) * ((in1 * 2.f).exp() + 1.f)));
auto in0_2exp = (in0 * 2.f).exp();
auto in0_tanh_denominator = in0_2exp + 1.f;
auto in1_2exp = (in1 * 2.f).exp();
auto in1_tanh_denominator = in1_2exp + 1.f;
auto i0 =
delta * (in1_2exp - 1.f) * 4.f * in0_2exp /
(in1_tanh_denominator * in0_tanh_denominator * in0_tanh_denominator);
auto i1 =
delta * (in0_2exp - 1.f) * 4.f * in1_2exp /
(in0_tanh_denominator * in1_tanh_denominator * in1_tanh_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
......@@ -208,9 +229,12 @@ namespace ngraph
break;
case 5 /*Tanh|Identity*/:
{
auto i0 = delta * (in1 * (4.f * (in0 * 2.f).exp())) /
(((in0 * 2.f).exp() + 1.f) * ((in0 * 2.f).exp() + 1.f));
auto i1 = delta * ((in0 * 2.f).exp() - 1.f) / ((in0 * 2.f).exp() + 1.f);
auto in0_2exp = (in0 * 2.f).exp();
auto in0_tanh_denominator = in0_2exp + 1.f;
auto i0 = delta * in1 * 4.f * in0_2exp /
(in0_tanh_denominator * in0_tanh_denominator);
auto i1 = delta * (in0_2exp - 1.f) / in0_tanh_denominator;
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
......@@ -219,9 +243,12 @@ namespace ngraph
break;
case 6 /*Identity|Logistic*/:
{
auto i0 = delta * 1.f / ((-in1).exp() + 1.f);
auto i1 = delta * (in0 * (-in1).exp()) /
(((-in1).exp() + 1.f) * ((-in1).exp() + 1.f));
auto in1_neg_exp = (-in1).exp();
auto in1_log_denominator = in1_neg_exp + 1.f;
auto i0 = delta * 1.f / in1_log_denominator;
auto i1 =
delta * in0 * in1_neg_exp / (in1_log_denominator * in1_log_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
......@@ -230,9 +257,12 @@ namespace ngraph
break;
case 7 /*Identity|Tanh*/:
{
auto i0 = delta * ((in1 * 2.f).exp() - 1.f) / ((in1 * 2.f).exp() + 1.f);
auto i1 = delta * (in0 * (4.f * (in1 * 2.f).exp())) /
(((in1 * 2.f).exp() + 1.f) * ((in1 * 2.f).exp() + 1.f));
auto in1_2exp = (in1 * 2.f).exp();
auto in1_tanh_denominator = in1_2exp + 1.f;
auto i0 = delta * (in1_2exp - 1.f) / in1_tanh_denominator;
auto i1 = delta * (in0 * (4.f * in1_2exp)) /
(in1_tanh_denominator * in1_tanh_denominator);
i0_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = i0;
i1_delta.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment