/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cmath>

#include "mkldnn_test_common.hpp"
#include "gtest/gtest.h"

#include "mkldnn.hpp"

#define ENGINE engine::kind::cpu
#define INST_TEST_CASE(str, ...) INSTANTIATE_TEST_CASE_P( \
        str, bnorm_test, ::testing::Values(__VA_ARGS__));

namespace mkldnn {

/// Logical tensor sizes for one batch normalization test case.
///
/// The fixture hands these to memory::desc as { mb, c, d, h, w },
/// { mb, c, h, w } or { mb, c }, depending on the data format and on the
/// number of dimensions requested by the test parameters. The reference
/// checks always iterate over all five extents.
struct test_bnorm_sizes_t {
    int mb;
    int c;
    int d;
    int h;
    int w;
};

/// Memory formats used by one batch normalization test case.
struct test_bnorm_formats_t {
    /// Format of the memory descriptor shared by the src and dst tensors.
    mkldnn::memory::format data_format;
    /// Format of the memory descriptor shared by the diff_src and diff_dst
    /// tensors.
    mkldnn::memory::format diff_format;
};

/// Full parameter set for one instantiation of the batch normalization test.
struct test_bnorm_params_t {
    /// Engine kind to create; the fixture asserts that it is the CPU engine.
    mkldnn::engine::kind engine_kind;
    test_bnorm_formats_t formats;
    test_bnorm_sizes_t sizes;
    /// Epsilon passed to the primitive descriptors and added to the variance
    /// in the reference computations.
    float epsilon;
    /// Number of tensor dimensions: for formats other than mkldnn_nc, 5
    /// selects the { mb, c, d, h, w } dims and any other value selects
    /// { mb, c, h, w }.
    int ndims;
    /// Forwarded to catch_expected_failures(): whether the test body is
    /// expected to fail, and the status it is expected to fail with.
    bool expect_to_fail;
    mkldnn_status_t expected_status;
};

/// Fills the whole buffer behind `m` with test data, viewing it as an array
/// of `T`.
///
/// The element count is the byte size reported by the memory's primitive
/// descriptor divided by sizeof(T); the values themselves are produced by
/// fill_data<T>().
template <typename T> void fill(memory &m) {
    T *const buffer = reinterpret_cast<T *>(m.get_data_handle());
    const auto count = m.get_primitive_desc().get_size() / sizeof(T);
    fill_data<T>(count, buffer);
}

/// Common fixture for batch normalization tests, parameterized on the element
/// type of the data tensors (`data_t`).
///
/// For each parameter set the fixture builds the data / diff-data memory
/// descriptors, runs the forward primitive (and, for f32 only, the backward
/// primitive) over a fixed list of prop_kind / flag combinations, and compares
/// every result against a plain reference implementation defined below.
template <typename data_t>
class bnorm_test_common : public ::testing::TestWithParam<test_bnorm_params_t> {
private:
    std::shared_ptr<test_memory> src;
    std::shared_ptr<test_memory> dst;
    std::shared_ptr<test_memory> diff_src;
    std::shared_ptr<test_memory> diff_dst;
    std::shared_ptr<memory> weights;
    std::shared_ptr<memory> diff_weights;
    std::shared_ptr<memory> mean;
    std::shared_ptr<memory> variance;
    std::shared_ptr<memory::desc> data_d;
    std::shared_ptr<memory::desc> diff_data_d;
    std::shared_ptr<batch_normalization_forward::primitive_desc> bnorm_fwd_pd;
    std::shared_ptr<batch_normalization_backward::primitive_desc> bnorm_bwd_pd;
    test_bnorm_params_t p;
    std::shared_ptr<engine> eng;
    // Initialized in-class so the member is always valid; it used to be left
    // uninitialized because the methods declared shadowing locals instead.
    memory::data_type data_type = data_traits<data_t>::data_type;

    /// Returns the offset of element (n, c, d, h, w) in a dense
    /// n-c-d-h-w ordering whose channel extent is `padded_c`.
    ///
    /// The result is what the reference checks pass to map_index() together
    /// with a memory descriptor. All arithmetic is done in size_t.
    static size_t logical_offset(const test_bnorm_sizes_t &bs, size_t padded_c,
            int n, int c, int d, int h, int w) {
        const size_t sd = static_cast<size_t>(bs.d);
        const size_t sh = static_cast<size_t>(bs.h);
        const size_t sw = static_cast<size_t>(bs.w);
        size_t off = static_cast<size_t>(n) * padded_c
                + static_cast<size_t>(c);
        off = off * sd + static_cast<size_t>(d);
        off = off * sh + static_cast<size_t>(h);
        off = off * sw + static_cast<size_t>(w);
        return off;
    }

protected:
    /// Reads the test parameters and runs Test() through
    /// catch_expected_failures(), forwarding the failure expectation
    /// (expect_to_fail / expected_status) from the parameters.
    void SetUp() override {
        p = ::testing::TestWithParam<decltype(p)>::GetParam();
        // Capture `this` explicitly: implicit capture of `this` through [=]
        // is deprecated since C++20.
        catch_expected_failures([this]() { Test(); }, p.expect_to_fail,
                    p.expected_status);
    }

    /// Builds the engine, descriptors and data memories for the current
    /// parameters and runs every forward / backward configuration that is
    /// exercised for `data_t`.
    void Test() {
        // Re-read the parameters so Test() does not depend on SetUp() having
        // populated `p`.
        p = ::testing::TestWithParam<decltype(p)>::GetParam();

        ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
        eng.reset(new engine(p.engine_kind, 0));
        ASSERT_TRUE(isF32(data_type) || isS8(data_type));

        const test_bnorm_sizes_t bs = p.sizes;
        const bool has_spatial = (p.formats.data_format != mkldnn_nc);

        // Pick the logical dims once; data and diff descriptors differ only
        // in their memory format.
        memory::dims dims;
        if (!has_spatial)
            dims = { bs.mb, bs.c };
        else if (p.ndims == 5)
            dims = { bs.mb, bs.c, bs.d, bs.h, bs.w };
        else
            dims = { bs.mb, bs.c, bs.h, bs.w };

        data_d.reset(new memory::desc(dims, data_type,
                    p.formats.data_format));
        diff_data_d.reset(new memory::desc(dims, data_type,
                    p.formats.diff_format));

        src.reset(new test_memory(*data_d, *eng));
        dst.reset(new test_memory(*data_d, *eng));
        diff_src.reset(new test_memory(*diff_data_d, *eng));
        diff_dst.reset(new test_memory(*diff_data_d, *eng));

        auto training = prop_kind::forward_training;
        auto inference = prop_kind::forward_inference;

        if (isF32(data_type)) {
            Forward(training);
            Forward(training, use_global_stats);
            Forward(training, use_scale_shift);
            Forward(training, use_scale_shift | use_global_stats);
            Forward(inference);
            Forward(inference, use_global_stats);
            Forward(inference, use_scale_shift);

            Backward(backward_data);
            Backward(backward_data, use_global_stats);
            Backward(backward_data, use_scale_shift);
            Backward(backward_data, use_scale_shift | use_global_stats);
            Backward(backward, use_scale_shift);
            Backward(backward, use_scale_shift | use_global_stats);
        } else if (isS8(data_type)) {
            // s8 is only exercised for inference with global statistics.
            Forward(inference, use_global_stats);
            Forward(inference, use_global_stats | use_scale_shift);
        }
    }

    /// Creates and executes a forward batch normalization primitive for the
    /// given propagation kind and flags, then validates dst (and, when
    /// training, the computed statistics) with check_bnorm_fwd().
    void Forward(prop_kind pk, unsigned flags = 0u) {
        bool useScaleShift = flags & use_scale_shift;
        bool useGlobalStats = flags & use_global_stats;
        bool isTraining = pk == prop_kind::forward_training;

        auto bnorm_fwd_d = batch_normalization_forward::desc(pk,
                    *data_d, p.epsilon, flags);

        bnorm_fwd_pd.reset(new batch_normalization_forward::primitive_desc(
                    bnorm_fwd_d, *eng));

        weights.reset(new memory(bnorm_fwd_pd->weights_primitive_desc()));
        // NOTE(review): for inference without global stats mean / variance
        // are not re-created, so the objects dereferenced for
        // check_bnorm_fwd() below are whatever an earlier call left behind.
        // check_bnorm_fwd() does not read their data in that configuration.
        if (isTraining || useGlobalStats) {
            mean.reset(new memory(bnorm_fwd_pd->mean_primitive_desc()));
            variance.reset(
                    new memory(bnorm_fwd_pd->variance_primitive_desc()));
        }

        fill<data_t>(src->get());
        fill<data_t>(dst->get());
        if (useScaleShift)
            fill<float>(*weights);
        if (useGlobalStats) {
            fill<float>(*mean);
            fill<float>(*variance);
        }
        // Presumably prepares (1) / verifies (0) the zero padding of blocked
        // formats — confirm against check_zero_tail() in the common header.
        check_zero_tail<data_t>(1, src->get());
        check_zero_tail<data_t>(1, dst->get());

        auto bn = createBnormFwd(isTraining, useGlobalStats, useScaleShift);

        std::vector<primitive> pipeline;
        pipeline.push_back(bn);
        stream(stream::kind::lazy).submit(pipeline).wait();

        check_zero_tail<data_t>(0, dst->get());

        check_bnorm_fwd(p, src->get(), *mean, *variance, *weights, dst->get(),
                flags, pk);
    }

    /// Creates and executes a backward batch normalization primitive for the
    /// given propagation kind and flags, then validates diff_src (and
    /// diff_weights for prop_kind backward) with check_bnorm_bwd().
    ///
    /// All buffers touched here are filled and checked as float; Test() only
    /// calls this method for f32 data.
    void Backward(prop_kind pk, unsigned flags = 0u) {
        bool useScaleShift = flags & use_scale_shift;

        // A forward-training primitive descriptor is needed as the hint for
        // the backward primitive descriptor.
        auto bnorm_fwd_d = batch_normalization_forward::desc(
                prop_kind::forward_training, *data_d, p.epsilon, flags);
        bnorm_fwd_pd.reset(new batch_normalization_forward::primitive_desc(
                    bnorm_fwd_d, *eng));

        auto bnorm_bwd_d = batch_normalization_backward::desc(
                pk, *diff_data_d, *data_d, p.epsilon, flags);
        bnorm_bwd_pd.reset(
                new batch_normalization_backward::primitive_desc(
                bnorm_bwd_d, *eng, *bnorm_fwd_pd));

        // Without scale/shift `weights` keeps the memory left by an earlier
        // call; check_bnorm_bwd() does not inspect it in that case.
        if (useScaleShift)
            weights.reset(new memory(
                    bnorm_bwd_pd->weights_primitive_desc()));
        diff_weights.reset(new memory(bnorm_bwd_pd->diff_weights_primitive_desc()));
        mean.reset(new memory(bnorm_bwd_pd->mean_primitive_desc()));
        variance.reset(new memory(
                    bnorm_bwd_pd->variance_primitive_desc()));

        if (useScaleShift)
            fill<float>(*weights);
        fill<float>(diff_src->get());
        fill<float>(diff_dst->get());
        fill<float>(*mean);
        fill<float>(*variance);
        check_zero_tail<data_t>(1, diff_src->get());
        check_zero_tail<data_t>(1, diff_dst->get());

        auto bnorm_bwd = createBnormBwd(useScaleShift, pk);

        std::vector<primitive> pipeline;
        pipeline.push_back(bnorm_bwd);
        stream(stream::kind::lazy).submit(pipeline).wait();

        check_bnorm_bwd(p, src->get(), diff_dst->get(), *mean, *variance,
                *weights, diff_src->get(), *diff_weights, flags, pk);
        check_zero_tail<data_t>(0, diff_src->get());
    }

    /// True when `data_type` is f32.
    inline bool isF32(memory::data_type data_type) {
        return data_type == mkldnn::memory::data_type::f32;
    }

    /// True when `data_type` is s8.
    inline bool isS8(memory::data_type data_type) {
        return data_type == mkldnn::memory::data_type::s8;
    }

    /// Builds the forward primitive from the current primitive descriptor,
    /// choosing the constructor that matches the configuration:
    ///  - inference without global stats: no mean / variance arguments;
    ///  - global stats: mean / variance passed before dst;
    ///  - training without global stats: mean / variance passed after dst.
    /// Weights are passed only when `useScaleShift` is set.
    primitive createBnormFwd(bool isTraining, bool useGlobalStats,
            bool useScaleShift)
    {
        if (!isTraining && !useGlobalStats) {
            return useScaleShift
                ? batch_normalization_forward(*bnorm_fwd_pd,
                    src->get(), *weights, dst->get())
                : batch_normalization_forward(*bnorm_fwd_pd, src->get(),
                        dst->get());
        } else {
            if (useGlobalStats) {
                // The explicit primitive::at casts presumably select the
                // overload taking the statistics as inputs — confirm against
                // mkldnn.hpp.
                return useScaleShift
                    ? batch_normalization_forward(*bnorm_fwd_pd,
                        src->get(), (const primitive::at)*mean,
                        (const primitive::at)*variance, *weights, dst->get())
                    : batch_normalization_forward(*bnorm_fwd_pd,
                        src->get(), (const primitive::at)*mean,
                        (const primitive::at)*variance, dst->get());
            } else {
                return useScaleShift
                    ? batch_normalization_forward(*bnorm_fwd_pd,
                        src->get(), *weights, dst->get(), *mean, *variance)
                    : batch_normalization_forward(*bnorm_fwd_pd,
                        src->get(), dst->get(), *mean, *variance);
            }
        }
    }

    /// Builds the backward primitive from the current primitive descriptor.
    /// With scale/shift, diff_weights is passed only when `pk` is not
    /// backward_data; without scale/shift neither weights nor diff_weights
    /// are passed.
    primitive createBnormBwd(bool useScaleShift, prop_kind pk)
    {
        if (useScaleShift) {
            return pk == prop_kind::backward_data
                ? batch_normalization_backward(*bnorm_bwd_pd,
                    src->get(), *mean, *variance, diff_dst->get(), *weights,
                    diff_src->get())
                : batch_normalization_backward(*bnorm_bwd_pd,
                    src->get(), *mean, *variance, diff_dst->get(), *weights,
                    diff_src->get(), *diff_weights);
        } else {
            return batch_normalization_backward(*bnorm_bwd_pd, src->get(),
                    *mean, *variance, diff_dst->get(), diff_src->get());
        }
    }

    /// Reference check for the forward pass.
    ///
    /// Per channel: takes mean / variance from the supplied memories when
    /// use_global_stats is set, otherwise computes them from src (and, when
    /// training, compares them against the primitive's output statistics);
    /// then recomputes every dst element and compares it with the primitive's
    /// output using a relative tolerance.
    ///
    /// @param p        test parameters (sizes and epsilon are used)
    /// @param src      input data
    /// @param mean     per-channel mean; read only with global stats or when
    ///                 training
    /// @param variance per-channel variance; same condition as `mean`
    /// @param weights  scale (first c values) and shift (next c values); read
    ///                 only when use_scale_shift is set
    /// @param dst      primitive output to validate
    /// @param flags    batch normalization flags used for the primitive
    /// @param pk       propagation kind used for the primitive
    void check_bnorm_fwd(const test_bnorm_params_t &p, const memory &src,
            const memory &mean, const memory &variance, const memory &weights,
            const memory &dst, unsigned flags, prop_kind pk) {
        const test_bnorm_sizes_t &bp = p.sizes;
        // Nothing to verify for empty tensors.
        if (bp.mb * bp.c * bp.d * bp.h * bp.w == 0)
            return;

        const bool use_weights = flags & use_scale_shift;
        const bool calculate_stats = !(flags & use_global_stats);
        const bool is_training = pk == prop_kind::forward_training;

        const data_t *src_data = (const data_t *)src.get_data_handle();
        const data_t *dst_data = (const data_t *)dst.get_data_handle();
        const float *weights_data = use_weights ?
                (const float *)weights.get_data_handle() : nullptr;
        const float *mean_data = (!calculate_stats || is_training) ?
                (const float *)mean.get_data_handle() : nullptr;
        const float *variance_data = (!calculate_stats || is_training) ?
                (const float *)variance.get_data_handle() : nullptr;

        const memory::desc src_d = src.get_primitive_desc().desc();
        const memory::desc dst_d = dst.get_primitive_desc().desc();
        const memory::desc weights_d = use_weights ?
                weights.get_primitive_desc().desc() : zero_md();

        // Tolerance grows with the number of elements reduced per channel.
        float eps = static_cast<float>(1.e-4 * bp.mb * bp.d * bp.h * bp.w);

        size_t padded_c = src.get_primitive_desc().desc().data.layout_desc
            .blocking.padding_dims[1];

        mkldnn::impl::parallel_nd(bp.c, [&](int c) {
            float ref_mean = calculate_stats ? float(0) : mean_data[c];
            float ref_variance = calculate_stats ? float(0) :
                    variance_data[c];
            if (calculate_stats) {
                // First pass: mean over mb * d * h * w.
                for (int n = 0; n < bp.mb; n++)
                for (int d = 0; d < bp.d; d++)
                for (int h = 0; h < bp.h; h++)
                for (int w = 0; w < bp.w; w++) {
                    const size_t sidx
                        = logical_offset(bp, padded_c, n, c, d, h, w);
                    ref_mean += src_data[map_index(src_d, sidx)];
                }
                ref_mean /= bp.mb * bp.d * bp.h * bp.w;
                if (is_training) {
                    float mean_norm_max =
                        std::max(std::abs(mean_data[c]), std::abs(ref_mean));
                    // Fall back to an absolute comparison near zero.
                    if (mean_norm_max < eps)
                        mean_norm_max = float(1);
                    EXPECT_NEAR((mean_data[c] - ref_mean) / mean_norm_max,
                        0., eps);
                }

                // Second pass: (biased) variance around the reference mean.
                for (int n = 0; n < bp.mb; n++)
                for (int d = 0; d < bp.d; d++)
                for (int h = 0; h < bp.h; h++)
                for (int w = 0; w < bp.w; w++) {
                    const size_t sidx
                        = logical_offset(bp, padded_c, n, c, d, h, w);
                    float tmp = src_data[map_index(src_d, sidx)] - ref_mean;
                    ref_variance += tmp * tmp;
                }
                ref_variance /= bp.mb * bp.d * bp.h * bp.w;
                if (is_training) {
                    float variance_norm_max = std::max(
                            std::abs(variance_data[c]), std::abs(ref_variance));
                    if (variance_norm_max < eps)
                        variance_norm_max = float(1);
                    EXPECT_NEAR((variance_data[c] - ref_variance) /
                            variance_norm_max, 0., eps);
                }
            }
            float ref_sqrt_variance =
                static_cast<float>(sqrt(ref_variance + p.epsilon));
            float ref_rsqrt_variance = float(1) / (ref_sqrt_variance);

            for (int n = 0; n < bp.mb; n++)
            for (int d = 0; d < bp.d; d++)
            for (int h = 0; h < bp.h; h++)
            for (int w = 0; w < bp.w; w++) {
                const size_t sdidx
                    = logical_offset(bp, padded_c, n, c, d, h, w);
                data_t ref_dst = data_t(0);
                float tmp_dst = float(0);
                if (use_weights) {
                    // weights hold the scale at [c] and the shift at
                    // [bp.c + c].
                    tmp_dst = weights_data[map_index(weights_d, c)] *
                        ((float)src_data[map_index(src_d, sdidx)] - ref_mean) *
                        ref_rsqrt_variance +
                        weights_data[map_index(weights_d, bp.c + c)];
                } else {
                    tmp_dst = ((float)src_data[map_index(src_d, sdidx)] -
                        ref_mean) * ref_rsqrt_variance;
                }

                if (isF32(data_type)) {
                    ref_dst = tmp_dst;
                } else if (isS8(data_type)) {
                    ref_dst = out_round<data_t>(
                        saturate<data_t, float>(tmp_dst));
                }

                data_t out = dst_data[map_index(dst_d, sdidx)];
                float norm_max = std::max(std::abs(out), std::abs(ref_dst));
                // s8 results and tiny f32 values are compared absolutely.
                if (norm_max < 1e-2 || isS8(data_type))
                    norm_max = 1.;
                EXPECT_NEAR((out - ref_dst) / norm_max, 0., eps);
            }
        });
    }

    /// Reference check for the backward pass (float data only).
    ///
    /// Per channel: accumulates the reference diff_gamma / diff_beta from src,
    /// mean, variance and diff_dst; for prop_kind backward compares them with
    /// diff_weights; then recomputes every diff_src element and compares it
    /// with the primitive's output using a relative tolerance. For empty
    /// tensors only diff_weights is checked (expected to be zero) when `pk`
    /// is backward.
    ///
    /// @param p            test parameters (sizes and epsilon are used)
    /// @param src          forward input data
    /// @param diff_dst     incoming gradient
    /// @param mean         per-channel mean
    /// @param variance     per-channel variance
    /// @param weights      scale / shift; inspected only when use_scale_shift
    ///                     is set
    /// @param diff_src     primitive output to validate
    /// @param diff_weights primitive output with diff_gamma (first c values)
    ///                     and diff_beta (next c values); read only when `pk`
    ///                     is backward
    /// @param flags        batch normalization flags used for the primitive
    /// @param pk           propagation kind used for the primitive
    void check_bnorm_bwd(const test_bnorm_params_t &p, const memory &src,
            const memory &diff_dst, const memory &mean, const memory &variance,
            const memory &weights, const memory &diff_src,
            const memory &diff_weights, unsigned flags, prop_kind pk) {
        const test_bnorm_sizes_t &bp = p.sizes;
        const bool use_weights = flags & use_scale_shift;
        const bool calculate_diff_stats = !(flags & use_global_stats);

        const float *src_data = (const float *)src.get_data_handle();
        const float *weights_data = use_weights ?
            (const float *)weights.get_data_handle() : nullptr;
        const float *diff_dst_data =
            (const float *)diff_dst.get_data_handle();
        const float *mean_data = (const float *)mean.get_data_handle();
        const float *variance_data =
            (const float *)variance.get_data_handle();
        const float *diff_src_data = (float *)diff_src.get_data_handle();
        const float *diff_weights_data = (pk == prop_kind::backward) ?
            (float *)diff_weights.get_data_handle() : nullptr;

        const memory::desc src_d = src.get_primitive_desc().desc();
        const memory::desc diff_dst_d = diff_dst.get_primitive_desc().desc();
        // Only query the weights descriptor when weights are in use, matching
        // check_bnorm_fwd(); Backward() does not re-create `weights` otherwise.
        const memory::desc weights_d = use_weights ?
            weights.get_primitive_desc().desc() : zero_md();
        const memory::desc diff_src_d = diff_src.get_primitive_desc().desc();
        const memory::desc diff_weights_d =
            diff_weights.get_primitive_desc().desc();

        if (bp.mb * bp.c * bp.d * bp.h * bp.w == 0) {
            // Empty tensor: the weight gradients must come out as zero.
            if (pk == backward) {
                for (int c = 0; c < bp.c; ++c) {
                    auto dg = diff_weights_data[map_index(diff_weights_d, c)];
                    auto db =
                        diff_weights_data[map_index(diff_weights_d, bp.c + c)];
                    EXPECT_NEAR(dg, 0., 1e-7);
                    EXPECT_NEAR(db, 0., 1e-7);
                }
            }
            return;
        }

        // Tolerance grows with the number of elements reduced per channel.
        const float eps =
            static_cast<float>(1.e-4 * bp.mb * bp.d * bp.h * bp.w);

        size_t padded_c = src.get_primitive_desc().desc().data.layout_desc.
            blocking.padding_dims[1];
        mkldnn::impl::parallel_nd(bp.c, [&](int c) {
            float ref_diff_gamma = float(0);
            float ref_diff_beta = float(0);

            auto v_mean = mean_data[c];
            auto v_variance = variance_data[c];
            // Reciprocal of the standard deviation, 1 / sqrt(var + epsilon).
            const float rsqrt_variance = 1.0f / sqrt(v_variance + p.epsilon);

            auto gamma =
                use_weights ? weights_data[map_index(weights_d, c)] : 1;

            for (int n = 0; n < bp.mb; n++)
            for (int d = 0; d < bp.d; d++)
            for (int h = 0; h < bp.h; h++)
            for (int w = 0; w < bp.w; w++) {
                const size_t sidx
                    = logical_offset(bp, padded_c, n, c, d, h, w);
                ref_diff_gamma += (src_data[map_index(src_d, sidx)] - v_mean)
                    * diff_dst_data[map_index(diff_dst_d, sidx)];
                ref_diff_beta += diff_dst_data[map_index(diff_dst_d, sidx)];
            }
            ref_diff_gamma *= rsqrt_variance;

            if (pk == backward) {
                auto diff_gamma =
                    diff_weights_data[map_index(diff_weights_d, c)];
                float norm_max =
                    std::max(std::abs(diff_gamma), std::abs(ref_diff_gamma));
                if (norm_max < 10e-3)
                    norm_max = float(1);
                EXPECT_NEAR((diff_gamma - ref_diff_gamma) / norm_max, 0., eps);

                auto diff_beta =
                    diff_weights_data[map_index(diff_weights_d, bp.c + c)];
                norm_max =
                    std::max(std::abs(diff_beta), std::abs(ref_diff_beta));
                if (norm_max < 10e-3)
                    norm_max = float(1);
                EXPECT_NEAR((diff_beta - ref_diff_beta) / norm_max, 0., eps);
            }

            for (int n = 0; n < bp.mb; n++)
            for (int d = 0; d < bp.d; d++)
            for (int h = 0; h < bp.h; h++)
            for (int w = 0; w < bp.w; w++) {
                const size_t sidx
                    = logical_offset(bp, padded_c, n, c, d, h, w);
                float ref_diff_src =
                    diff_dst_data[map_index(diff_dst_d, sidx)];
                // With global stats the statistics do not depend on src, so
                // the correction terms are skipped.
                if (calculate_diff_stats) {
                    ref_diff_src -= ref_diff_beta/(bp.mb*bp.d*bp.h*bp.w)
                    + (src_data[map_index(src_d, sidx)] - v_mean)
                    *ref_diff_gamma*rsqrt_variance/(bp.mb*bp.d*bp.h*bp.w);
                }
                ref_diff_src *= gamma*rsqrt_variance;
                float out_diff_src =
                    diff_src_data[map_index(diff_src_d, sidx)];
                float norm_max =
                    std::max(std::abs(out_diff_src), std::abs(ref_diff_src));
                if (norm_max < eps)
                    norm_max = float(1);
                EXPECT_NEAR((out_diff_src - ref_diff_src) / norm_max, 0., eps);
            }
        });
    }
};

}