/*******************************************************************************
* Copyright 2016-2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef TEST_ELTWISE_HPP
#define TEST_ELTWISE_HPP

#include "gtest/gtest.h"
#include "mkldnn_test_common.hpp"

#include "mkldnn.hpp"

namespace mkldnn {

template <typename T, typename A> inline T relu_fwd(T s, A alpha) {
    return s > 0 ? s : static_cast<T>(s * alpha);
}
template <typename T, typename A> inline T relu_bwd(T dd, T s, A alpha) {
    return s > 0 ? dd : static_cast<T>(dd * alpha);
}
template <typename T> T tanh_fwd(T s) {
    return static_cast<T>(::tanhf((float)s));
}
template <typename T> T tanh_bwd(T dd, T s) {
    const float th = ::tanhf((float)s);
    return static_cast<T>(dd * (1 - th) * (1 + th));
}

template <typename T, typename A> T elu_fwd(T s, A alpha) {
    return s > 0 ? s : static_cast<T>(alpha * (::expf(s) - 1));
}
template <typename T, typename A> T elu_bwd(T dd, T s, A alpha) {
    return static_cast<T>(dd * (s > 0 ? 1 : alpha * ::expf(s)));
}

template <typename T>
T square_fwd(T s) {
    return s * s;
}

template <typename T>
T square_bwd(T dd, T s) {
    return dd * 2*s;
}

template <typename T>
T abs_fwd(T s) {
    return s > 0 ? s : -s;;
}

template <typename T>
T abs_bwd(T dd, T s) {
    return dd * (s > 0 ? 1 : s < 0 ? -1 : 0);
}

template <typename T>
T sqrt_fwd(T s) {
    return s > 0 ? ::sqrtf(s) : 0;
}

template <typename T>
T sqrt_bwd(T dd, T s) {
    return s > 0 ? dd / (2 * ::sqrtf(s)) : 0;
}

template <typename T, typename A>
T linear_fwd(T s, A alpha, A beta) {
    return alpha * s + beta;
}

template <typename T, typename A>
T linear_bwd(T dd, T s, A alpha, A beta) {
    (void) s;
    (void) beta;
    return dd * alpha;
}

template <typename T, typename A>
T bounded_relu_fwd(T s, A alpha) {
    s = s > 0 ? s : 0;
    return s > alpha ? alpha : s;
}

template <typename T, typename A>
T bounded_relu_bwd(T dd, T s, A alpha) {
    return dd * ((0 < s && s < alpha) ? 1 : 0);
}

template <typename T>
T soft_relu_fwd(T s) {
    return s < (T)logf(FLT_MAX) ? log1pf(::expf(s)) : s;
}

template <typename T>
T soft_relu_bwd(T dd, T s) {
    return dd / (1 + ::expf(-s));
}

template <typename T>
T logistic_fwd(T s) {
    T v = (T)(::expf(- (float)s));
    return 1 / (1 + v);
}

template <typename T>
T logistic_bwd(T dd, T s) {
    T v = logistic_fwd<T>(s);
    return dd * v * (1 - v);
}

template <typename T, typename A>
inline T clamp_fwd(T s, A alpha, A beta) {
    return s > alpha ? alpha : s < beta ? beta : s;
}

template <typename T, typename A>
inline T clamp_bwd(T dd, T s, A alpha, A beta) {
    return dd * (beta < s && s < alpha ? 1 : 0);
}

template <typename T>
inline T exp_fwd(T s) {
    return (T)(::expf((float)s));
}

template <typename T>
 inline T exp_bwd(T dd, T s) {
    return (T)(::expf((float)s));
}

template <typename T>
inline T not_fwd(T s) {
    return (T)(!s);
}

struct eltwise_test_params {
    engine::kind engine_kind;
    algorithm alg_kind;
    memory::format data_format;
    memory::format diff_format;
    float alpha, beta;
    memory::dims dims;
    bool expect_to_fail;
    mkldnn_status_t expected_status;
};

size_t n_elems(const memory::desc &md) {
    size_t p = 1;
    const ptrdiff_t *pdims = md.data.layout_desc.blocking.padding_dims;
    for (int i = 0; i < md.data.ndims; ++i)
        p *= (size_t)(pdims[i]);
    return p;
}

template <typename data_t>
void ref_eltwise_fwd(const eltwise_test_params &p,
        const memory::desc &md, const memory &src, const memory &dst)
{
    data_t *src_data = (data_t *)src.get_data_handle();
    data_t *dst_data = (data_t *)dst.get_data_handle();

    size_t n = n_elems(md);
    for (size_t i = 0; i < n; ++i) {
        data_t s = src_data[i];
        data_t ref_d = 0;
        switch (p.alg_kind) {
        case eltwise_relu:        ref_d = relu_fwd(s, p.alpha);           break;
        case eltwise_tanh:        ref_d = tanh_fwd(s);                    break;
        case eltwise_elu:         ref_d = elu_fwd(s, p.alpha);            break;
        case eltwise_square:      ref_d = square_fwd(s);                  break;
        case eltwise_abs:         ref_d = abs_fwd(s);                     break;
        case eltwise_sqrt:        ref_d = sqrt_fwd(s);                    break;
        case eltwise_linear:      ref_d = linear_fwd(s, p.alpha, p.beta); break;
        case eltwise_bounded_relu: ref_d = bounded_relu_fwd(s, p.alpha);  break;
        case eltwise_soft_relu:   ref_d = soft_relu_fwd(s);               break;
        case eltwise_logistic:    ref_d = logistic_fwd(s);                break;
        case eltwise_clamp:       ref_d = clamp_fwd(s, p.alpha, p.beta);  break;
        case eltwise_exp:         ref_d = exp_fwd(s);                     break;
        case eltwise_not:         ref_d = not_fwd(s);                     break;
        default: assert(!"unknown alg_kind");
        }
        dst_data[i] = ref_d;
    }
}

template <typename data_t>
void compare_eltwise_fwd(const eltwise_test_params &p,
        const memory::desc &md, const memory &dst, const memory &ref_dst,
        const float eps)
{
    data_t *ref_dst_data = (data_t *)ref_dst.get_data_handle();
    data_t *dst_data = (data_t *)dst.get_data_handle();

    size_t n = n_elems(md);
    for (size_t i = 0; i < n; ++i) {
        float diff_err = dst_data[i] - ref_dst_data[i];
        float rel_err = std::abs(
                (std::min)(std::abs((float)ref_dst_data[i]), std::abs(diff_err)) > 1e-5
                ? diff_err / ref_dst_data[i]
                : diff_err);
        if (p.alg_kind == eltwise_soft_relu){
            EXPECT_NEAR(rel_err, 0, 2 * eps);
        }
        else{
            EXPECT_NEAR(rel_err, 0, eps);
        }
    }
}


template <typename data_t>
void check_eltwise_bwd(const eltwise_test_params &p,
        const memory::desc &md, const memory &src, const memory &diff_dst,
        const memory &diff_src, const float eps)
{
    data_t *src_data = (data_t *)src.get_data_handle();
    data_t *diff_dst_data = (data_t *)diff_dst.get_data_handle();
    data_t *diff_src_data = (data_t *)diff_src.get_data_handle();

    const memory::desc data_d = src.get_primitive_desc().desc();
    const memory::desc diff_data_d = diff_src.get_primitive_desc().desc();

    ASSERT_EQ(md.data.data_type, memory::data_type::f32);

    size_t n = n_elems(md);
    for (size_t i = 0; i < n; ++i) {
        data_t ref_s = src_data[map_index(data_d, i)];
        data_t ref_dd = diff_dst_data[map_index(diff_data_d, i)];
        data_t ref_ds = 0;
        switch (p.alg_kind) {
        case eltwise_relu:   ref_ds = relu_bwd(ref_dd, ref_s, p.alpha); break;
        case eltwise_tanh:   ref_ds = tanh_bwd(ref_dd, ref_s); break;
        case eltwise_elu:    ref_ds = elu_bwd(ref_dd, ref_s, p.alpha); break;
        case eltwise_square: ref_ds = square_bwd(ref_dd, ref_s); break;
        case eltwise_abs:    ref_ds = abs_bwd(ref_dd, ref_s); break;
        case eltwise_sqrt:   ref_ds = sqrt_bwd(ref_dd, ref_s); break;
        case eltwise_linear:
            ref_ds = linear_bwd(ref_dd, ref_s, p.alpha, p.beta);
            break;
        case eltwise_bounded_relu:
            ref_ds = bounded_relu_bwd(ref_dd, ref_s, p.alpha);
            break;
        case eltwise_soft_relu:
            ref_ds = soft_relu_bwd(ref_dd, ref_s);
            break;
        case eltwise_logistic: ref_ds = logistic_bwd(ref_dd, ref_s); break;
        case eltwise_clamp: ref_ds = clamp_bwd(ref_dd, ref_s, p.alpha, p.beta); break;
        case eltwise_exp: ref_ds = exp_bwd(ref_dd, ref_s); break;
        default: assert(!"unknown alg_kind");
        }
        float diff_err = diff_src_data[map_index(diff_data_d, i)] - ref_ds;
        float rel_err = std::abs(
                (std::min)(std::abs((float)ref_ds), std::abs(diff_err)) > 1e-6
                ? diff_err / ref_ds
                : diff_err);
        EXPECT_NEAR(rel_err, 0, eps);
    }
}

}

#endif