Commit ee74e576 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

DEX Sigmoid Multiply (#1389)

* dex sigmoid multiply

* sigmoid multiply

* refactor compute logic into standalone kernels

* address Jayaram's feedback
parent 3d664abd
......@@ -14,11 +14,12 @@
* limitations under the License.
*******************************************************************************/
//#include "ngraph/runtime/cpu/kernel/avg_pool.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/sigmoid_multiply.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
using namespace std;
using namespace ngraph;
......@@ -107,8 +108,69 @@ namespace ngraph
functors.emplace_back(functor);
}
template <>
void Builder::BUILDER_DECL(ngraph::op::SigmoidMultiply)
{
    auto& functors = external_function->get_functors();
    auto& tensor_data = external_function->get_tensor_data();

    // Resolve storage slots for the two inputs and the single output.
    auto& input0 = tensor_data[args[0].get_name()];
    auto& input1 = tensor_data[args[1].get_name()];
    auto& output = tensor_data[out[0].get_name()];
    auto element_count = shape_size(args[0].get_shape());

    // Flatten the (func0, func1) activation pair into one dispatch index:
    // index = func0 * NumTypes + func1, matching the kernel's switch.
    auto sigmoid_mul = static_cast<const ngraph::op::SigmoidMultiply*>(node);
    const size_t num_types =
        static_cast<size_t>(ngraph::op::SigmoidMultiply::FunctionType::NumTypes);
    const size_t func0 = static_cast<size_t>(sigmoid_mul->get_input_func_type(0));
    const size_t func1 = static_cast<size_t>(sigmoid_mul->get_input_func_type(1));
    const size_t index = func0 * num_types + func1;

    // Tensor slots are captured by reference: their addresses are filled in
    // later, before the functor runs.
    auto functor = [&input0, &input1, &output, element_count, index](
        CPURuntimeContext* ctx) {
        ngraph::runtime::cpu::kernel::sigmoid_multiply(
            input0, input1, output, element_count, index);
    };
    functors.emplace_back(functor);
}
template <>
void Builder::BUILDER_DECL(ngraph::op::SigmoidMultiplyBackprop)
{
    auto& functors = external_function->get_functors();
    auto& tensor_data = external_function->get_tensor_data();

    // Inputs: the two forward inputs plus the incoming delta; outputs are the
    // gradients with respect to each forward input.
    auto& input0 = tensor_data[args[0].get_name()];
    auto& input1 = tensor_data[args[1].get_name()];
    auto& delta = tensor_data[args[2].get_name()];
    auto& input0_delta = tensor_data[out[0].get_name()];
    auto& input1_delta = tensor_data[out[1].get_name()];
    auto element_count = shape_size(args[0].get_shape());

    // Same (func0, func1) -> index encoding as the forward builder:
    // index = func0 * NumTypes + func1.
    auto sigmoid_mul = static_cast<const ngraph::op::SigmoidMultiplyBackprop*>(node);
    const size_t num_types =
        static_cast<size_t>(ngraph::op::SigmoidMultiply::FunctionType::NumTypes);
    const size_t func0 = static_cast<size_t>(sigmoid_mul->get_input_func_type(0));
    const size_t func1 = static_cast<size_t>(sigmoid_mul->get_input_func_type(1));
    const size_t index = func0 * num_types + func1;

    // Tensor slots are captured by reference so the functor sees the
    // addresses assigned at execution time.
    auto functor = [&input0, &input1, &delta, &input0_delta, &input1_delta,
                    element_count, index](CPURuntimeContext* ctx) {
        ngraph::runtime::cpu::kernel::sigmoid_multiply_backprop(input0,
                                                                input1,
                                                                delta,
                                                                input0_delta,
                                                                input1_delta,
                                                                element_count,
                                                                index);
    };
    functors.emplace_back(functor);
}
// Register the builder specializations above with the DEX op dispatcher.
REGISTER_OP_BUILDER(Sigmoid);
REGISTER_OP_BUILDER(SigmoidBackprop);
REGISTER_OP_BUILDER(SigmoidMultiply);
REGISTER_OP_BUILDER(SigmoidMultiplyBackprop);
}
}
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/kernel/eigen_thread_pool.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
// Views a raw buffer of `tensor_size` elements as a rank-1 Eigen TensorMap.
// No copy is made; the caller retains ownership of `data` and must keep it
// alive for the lifetime of the returned map.
template <typename ElementType>
static Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>>
    wrap_into_tensor_map(void* data, size_t tensor_size)
{
    Eigen::array<Eigen::Index, 1> shape;
    shape[0] = static_cast<Eigen::Index>(tensor_size);
    return Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>>(
        static_cast<ElementType*>(data), shape);
}
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace kernel
{
using namespace std;
using namespace ngraph;
// Computes out = f0(in0) * f1(in1) elementwise, where each f is one of
// Logistic (sigmoid), Tanh, or Identity, selected by `index`.
// `index` encodes the pair as f0 * NumTypes + f1 (NumTypes == 3), matching
// the dispatch index built by the CPU builder for SigmoidMultiply.
// All tensors are treated as flat float arrays of `tensor_size` elements;
// evaluation runs on the shared Eigen thread pool device.
// Identities used throughout:
//   sigmoid(x) = e^x / (e^x + 1)
//   tanh(x)    = (e^(2x) - 1) / (e^(2x) + 1)
// NOTE(review): the formulas use e^x directly, which overflows to inf for
// large positive inputs — presumably inputs stay in a moderate range; confirm
// if larger magnitudes are expected.
void sigmoid_multiply(void* arg0_tensor,
                      void* arg1_tensor,
                      void* out_tensor,
                      size_t tensor_size,
                      size_t index)
{
    // Wrap the raw buffers as flat float tensors (no copies).
    auto in0 = wrap_into_tensor_map<float>(arg0_tensor, tensor_size);
    auto in1 = wrap_into_tensor_map<float>(arg1_tensor, tensor_size);
    auto out_tm = wrap_into_tensor_map<float>(out_tensor, tensor_size);
    switch (index)
    {
    case 0 /*Logistic|Logistic*/:
    {
        // sigmoid(in0) * sigmoid(in1)
        auto c = (in0.exp() * in1.exp()) / ((in0.exp() + 1.f) * (in1.exp() + 1.f));
        out_tm.device(eigen::global_thread_pool_device) = c;
    }
    break;
    case 1 /*Logistic|Tanh*/:
    {
        // sigmoid(in0) * tanh(in1)
        auto c = (in0.exp() * ((in1 * 2.f).exp() - 1.f)) /
                 ((in0.exp() + 1.f) * ((in1 * 2.f).exp() + 1.f));
        out_tm.device(eigen::global_thread_pool_device) = c;
    }
    break;
    case 2 /*Logistic|Identity*/:
    {
        // sigmoid(in0) * in1
        auto c = (in0.exp() * in1) / (in0.exp() + 1.f);
        out_tm.device(eigen::global_thread_pool_device) = c;
    }
    break;
    case 3 /*Tanh|Logistic*/:
    {
        // tanh(in0) * sigmoid(in1)
        auto c = (((in0 * 2.f).exp() - 1.f) * in1.exp()) /
                 (((in0 * 2.f).exp() + 1.f) * (in1.exp() + 1.f));
        out_tm.device(eigen::global_thread_pool_device) = c;
    }
    break;
    case 4 /*Tanh|Tanh*/:
    {
        // tanh(in0) * tanh(in1)
        auto c = (((in0 * 2.f).exp() - 1.f) * ((in1 * 2.f).exp() - 1.f)) /
                 (((in0 * 2.f).exp() + 1.f) * ((in1 * 2.f).exp() + 1.f));
        out_tm.device(eigen::global_thread_pool_device) = c;
    }
    break;
    case 5 /*Tanh|Identity*/:
    {
        // tanh(in0) * in1
        auto c = (((in0 * 2.f).exp() - 1.f) * in1) / ((in0 * 2.f).exp() + 1.f);
        out_tm.device(eigen::global_thread_pool_device) = c;
    }
    break;
    case 6 /*Identity|Logistic*/:
    {
        // in0 * sigmoid(in1)
        auto c = (in0 * in1.exp()) / (in1.exp() + 1.f);
        out_tm.device(eigen::global_thread_pool_device) = c;
    }
    break;
    case 7 /*Identity|Tanh*/:
    {
        // in0 * tanh(in1)
        auto c = (in0 * ((in1 * 2.f).exp() - 1.f)) / ((in1 * 2.f).exp() + 1.f);
        out_tm.device(eigen::global_thread_pool_device) = c;
    }
    break;
    case 8 /*Identity|Identity*/:
    {
        // plain elementwise product
        auto c = (in0 * in1);
        out_tm.device(eigen::global_thread_pool_device) = c;
    }
    break;
    default: throw ngraph_error("unsupported combination for SigmoidMultiply");
    }
}
// Backprop for out = f0(in0) * f1(in1):
//   out0 (i0_delta) = delta * f1(in1) * f0'(in0)
//   out1 (i1_delta) = delta * f0(in0) * f1'(in1)
// with each f Logistic, Tanh, or Identity, selected by `index`
// (index = f0 * NumTypes + f1, NumTypes == 3 — same encoding as the forward
// kernel). All tensors are flat float arrays of `tensor_size` elements;
// evaluation runs on the shared Eigen thread pool device.
// Derivative identities used in the formulas below:
//   sigmoid'(x) = e^x / (e^x + 1)^2
//   tanh'(x)    = 4 e^(2x) / (e^(2x) + 1)^2
//   identity'   = 1
// NOTE(review): same e^x overflow caveat as the forward kernel for large
// positive inputs.
void sigmoid_multiply_backprop(void* arg0_tensor,
                               void* arg1_tensor,
                               void* arg2_tensor,
                               void* out0_tensor,
                               void* out1_tensor,
                               size_t tensor_size,
                               size_t index)
{
    // Wrap the raw buffers as flat float tensors (no copies).
    auto in0 = wrap_into_tensor_map<float>(arg0_tensor, tensor_size);
    auto in1 = wrap_into_tensor_map<float>(arg1_tensor, tensor_size);
    auto delta = wrap_into_tensor_map<float>(arg2_tensor, tensor_size);
    auto i0_delta = wrap_into_tensor_map<float>(out0_tensor, tensor_size);
    auto i1_delta = wrap_into_tensor_map<float>(out1_tensor, tensor_size);
    switch (index)
    {
    case 0 /*Logistic|Logistic*/:
    {
        // i0 = delta * sigmoid(in1) * sigmoid'(in0); i1 symmetric.
        auto i0 = delta * (in1.exp() * in0.exp()) /
                  ((in1.exp() + 1.f) * ((in0.exp() + 1.f) * (in0.exp() + 1.f)));
        auto i1 = delta * (in0.exp() * in1.exp()) /
                  ((in0.exp() + 1.f) * ((in1.exp() + 1.f) * (in1.exp() + 1.f)));
        i0_delta.device(eigen::global_thread_pool_device) = i0;
        i1_delta.device(eigen::global_thread_pool_device) = i1;
    }
    break;
    case 1 /*Logistic|Tanh*/:
    {
        // i0 = delta * tanh(in1) * sigmoid'(in0)
        auto i0 =
            delta * (((in1 * 2.f).exp() - 1.f) * in0.exp()) /
            (((in1 * 2.f).exp() + 1.f) * ((in0.exp() + 1.f) * (in0.exp() + 1.f)));
        // i1 = delta * sigmoid(in0) * tanh'(in1)
        auto i1 = delta * (in0.exp() * (4.f * (in1 * 2.f).exp())) /
                  ((in0.exp() + 1.f) *
                   (((in1 * 2.f).exp() + 1.f) * ((in1 * 2.f).exp() + 1.f)));
        i0_delta.device(eigen::global_thread_pool_device) = i0;
        i1_delta.device(eigen::global_thread_pool_device) = i1;
    }
    break;
    case 2 /*Logistic|Identity*/:
    {
        // i0 = delta * in1 * sigmoid'(in0); i1 = delta * sigmoid(in0)
        auto i0 =
            delta * (in1 * in0.exp()) / ((in0.exp() + 1.f) * (in0.exp() + 1.f));
        auto i1 = delta * in0.exp() / ((in0.exp() + 1.f));
        i0_delta.device(eigen::global_thread_pool_device) = i0;
        i1_delta.device(eigen::global_thread_pool_device) = i1;
    }
    break;
    case 3 /*Tanh|Logistic*/:
    {
        // i0 = delta * sigmoid(in1) * tanh'(in0)
        auto i0 = delta * (in1.exp() * (4.f * (in0 * 2.f).exp())) /
                  ((in1.exp() + 1.f) * ((in0 * 2.f).exp() + 1.f) *
                   ((in0 * 2.f).exp() + 1.f));
        // i1 = delta * tanh(in0) * sigmoid'(in1)
        auto i1 =
            delta * (((in0 * 2.f).exp() - 1.f) * in1.exp()) /
            (((in0 * 2.f).exp() + 1.f) * ((in1.exp() + 1.f) * (in1.exp() + 1.f)));
        i0_delta.device(eigen::global_thread_pool_device) = i0;
        i1_delta.device(eigen::global_thread_pool_device) = i1;
    }
    break;
    case 4 /*Tanh|Tanh*/:
    {
        // i0 = delta * tanh(in1) * tanh'(in0); i1 symmetric.
        auto i0 = delta * (((in1 * 2.f).exp() - 1.f) * (4.f * (in0 * 2.f).exp())) /
                  (((in1 * 2.f).exp() + 1.f) *
                   (((in0 * 2.f).exp() + 1.f) * ((in0 * 2.f).exp() + 1.f)));
        auto i1 = delta * (((in0 * 2.f).exp() - 1.f) * (4.f * (in1 * 2.f).exp())) /
                  (((in0 * 2.f).exp() + 1.f) *
                   (((in1 * 2.f).exp() + 1.f) * ((in1 * 2.f).exp() + 1.f)));
        i0_delta.device(eigen::global_thread_pool_device) = i0;
        i1_delta.device(eigen::global_thread_pool_device) = i1;
    }
    break;
    case 5 /*Tanh|Identity*/:
    {
        // i0 = delta * in1 * tanh'(in0); i1 = delta * tanh(in0)
        auto i0 = delta * (in1 * (4.f * (in0 * 2.f).exp())) /
                  (((in0 * 2.f).exp() + 1.f) * ((in0 * 2.f).exp() + 1.f));
        auto i1 = delta * ((in0 * 2.f).exp() - 1.f) / ((in0 * 2.f).exp() + 1.f);
        i0_delta.device(eigen::global_thread_pool_device) = i0;
        i1_delta.device(eigen::global_thread_pool_device) = i1;
    }
    break;
    case 6 /*Identity|Logistic*/:
    {
        // i0 = delta * sigmoid(in1); i1 = delta * in0 * sigmoid'(in1)
        auto i0 = delta * (in1.exp()) / (in1.exp() + 1.f);
        auto i1 =
            delta * (in0 * in1.exp()) / ((in1.exp() + 1.f) * (in1.exp() + 1.f));
        i0_delta.device(eigen::global_thread_pool_device) = i0;
        i1_delta.device(eigen::global_thread_pool_device) = i1;
    }
    break;
    case 7 /*Identity|Tanh*/:
    {
        // i0 = delta * tanh(in1); i1 = delta * in0 * tanh'(in1)
        auto i0 = delta * ((in1 * 2.f).exp() - 1.f) / ((in1 * 2.f).exp() + 1.f);
        auto i1 = delta * (in0 * (4.f * (in1 * 2.f).exp())) /
                  (((in1 * 2.f).exp() + 1.f) * ((in1 * 2.f).exp() + 1.f));
        i0_delta.device(eigen::global_thread_pool_device) = i0;
        i1_delta.device(eigen::global_thread_pool_device) = i1;
    }
    break;
    case 8 /*Identity|Identity*/:
    {
        // product rule for a plain product: i0 = delta*in1, i1 = delta*in0
        auto i0 = delta * in1;
        auto i1 = delta * in0;
        i0_delta.device(eigen::global_thread_pool_device) = i0;
        i1_delta.device(eigen::global_thread_pool_device) = i1;
    }
    break;
    default: throw ngraph_error("unsupported combination for SigmoidMultiply");
    }
}
}
}
}
}
......@@ -34,7 +34,8 @@ namespace ngraph
{
Logistic,
Tanh,
Identity
Identity,
NumTypes
};
/// Input nodes are expected to be actual inputs where the corresponding input
/// FunctionType will be applied to those inputs in the fused operation.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment