/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "memory_pd.hpp"
#include "mkldnn_traits.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "format_traits.hpp"

#include "cpu_memory.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl;
using namespace mkldnn::impl::data_type;
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;

using dk = data_kind_t;
using bf = block_format_t;

/* Zeroes the channel tail (positions dims[1] % blksize .. blksize - 1) of
 * the last channel block for blocked data formats (nC{w,hw,dhw}Xc). */
template <data_type_t dt, memory_format_t fmt>
typename utils::enable_if<format_traits<fmt>::data_kind == dk::data>::type
typed_zero_pad_data(
    const memory_desc_wrapper &m_d, typename prec_traits<dt>::type *data) {
    constexpr int blksize = format_traits<fmt>::blk_size;

    const auto &dims = m_d.dims();
    const auto &pdims = m_d.blocking_desc().padding_dims;

    /* index of the last (partially filled) channel block */
    const int last_blk = pdims[1] / blksize - 1;
    /* first channel position inside that block that lies in the padding */
    const int tail_start = dims[1] % blksize;
    assert(tail_start != 0); // caller dispatches here only if padding exists
    /* number of spatial points beyond the outermost spatial dimension */
    const size_t inner_sp = utils::array_product(dims + 3, m_d.ndims() - 3);

    parallel_nd(dims[0], dims[2], [&](int n, int sp0) {
        auto *d = &data[m_d.blk_off(n, last_blk, sp0)];
        for (size_t sp = 0; sp < inner_sp; ++sp)
            for (int c = tail_start; c < blksize; ++c)
                d[sp * blksize + c] = 0;
    });
}

/* Zero-pads the output-channel tail of the last oc-block for weight formats
 * blocked over output channels only (e.g. Oihw16o, gOdhwi8o). */
template <data_type_t dt, memory_format_t fmt>
typename utils::enable_if<false
|| format_traits<fmt>::blk_fmt == bf::_4o
|| format_traits<fmt>::blk_fmt == bf::_8o
|| format_traits<fmt>::blk_fmt == bf::_16o
>::type typed_zero_pad_weights(const memory_desc_wrapper &m_d,
        typename prec_traits<dt>::type *data) {
    /* grouped formats carry an extra leading G dimension */
    static constexpr int w_groups = format_traits<fmt>::data_kind == dk::gwei;
    constexpr int is_1d = format_traits<fmt>::ndims_sp == 1;
    constexpr int is_3d = format_traits<fmt>::ndims_sp == 3;
    constexpr int blksize = format_traits<fmt>::blk_size;

    const auto &dims = m_d.dims();
    const auto &pdims = m_d.blocking_desc().padding_dims;

    const int G = w_groups ? dims[0] : 1;
    const int NB_OC = pdims[w_groups + 0] / blksize; // number of oc blocks
    const int IC = dims[w_groups + 1];
    const int D = is_3d ? dims[w_groups + 2] : 1;
    const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
    const int W = dims[w_groups + 3 - is_1d + is_3d];

    /* first oc position (within the last block) that lies in the padding */
    const int oc_pad_start
        = blksize - (pdims[w_groups + 0] - dims[w_groups + 0]);

    parallel_nd(G, IC, D, H, W, [&](int g, int ic, int d, int h, int w) {
        auto x = &data[wei_blk_off_like_gwei3D<fmt>(
                m_d, g, NB_OC - 1, ic, d, h, w)];
        for (int oc = oc_pad_start; oc < blksize; ++oc)
            x[oc] = 0;
    });
}

/* Zero-pads the input-channel tail of the last ic-block for weight formats
 * blocked over input channels only (e.g. oIhw16i, oIdhw8i). */
template <data_type_t dt, memory_format_t fmt>
typename utils::enable_if<false
|| format_traits<fmt>::blk_fmt == bf::_8i
|| format_traits<fmt>::blk_fmt == bf::_16i
>::type typed_zero_pad_weights(const memory_desc_wrapper &m_d,
        typename prec_traits<dt>::type *data) {
    /* grouped formats carry an extra leading G dimension */
    static constexpr int w_groups = format_traits<fmt>::data_kind == dk::gwei;
    constexpr int is_1d = format_traits<fmt>::ndims_sp == 1;
    constexpr int is_3d = format_traits<fmt>::ndims_sp == 3;
    constexpr int blksize = format_traits<fmt>::blk_size;

    const auto &dims = m_d.dims();
    const auto &pdims = m_d.blocking_desc().padding_dims;

    const int G = w_groups ? dims[0] : 1;
    const int OC = dims[w_groups + 0];
    const int NB_IC = pdims[w_groups + 1] / blksize; // number of ic blocks
    const int D = is_3d ? dims[w_groups + 2] : 1;
    const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
    /* fix: subtract is_1d like the sibling overloads do; the 2d/3d formats
     * currently dispatched here have is_1d == 0, so behavior is unchanged,
     * but a future 1d blocked-ic format would index W correctly */
    const int W = dims[w_groups + 3 - is_1d + is_3d];

    /* number of padded (to-be-zeroed) ic positions in the last block */
    const int ic_tail = pdims[w_groups + 1] - dims[w_groups + 1];

    parallel_nd(G, OC, D, H, W,
        [&](int g, int oc, int d, int h, int w) {
        auto x = &data[wei_blk_off_like_gwei3D<fmt>(m_d,
                g, oc, NB_IC - 1, d, h, w)];
        for (int ic = blksize - ic_tail; ic < blksize; ++ic)
            x[ic] = 0;
    });
}

/* Zero-pads both the oc- and ic-tails for weight formats blocked over both
 * output and input channels (blk_ndims == 2, e.g. OIhw16i16o, gOIw8o8i). */
template <data_type_t dt, memory_format_t fmt>
typename utils::enable_if<
block_format_traits<format_traits<fmt>::blk_fmt>::blk_ndims == 2>::type
typed_zero_pad_weights(const memory_desc_wrapper &m_d,
        typename prec_traits<dt>::type *data) {
    using data_t = typename prec_traits<dt>::type;
    /* grouped formats carry an extra leading G dimension */
    static constexpr int w_groups = format_traits<fmt>::data_kind == dk::gwei;
    constexpr int is_1d = format_traits<fmt>::ndims_sp == 1;
    constexpr int is_3d = format_traits<fmt>::ndims_sp == 3;
    constexpr int blksize = format_traits<fmt>::blk_size;
    const auto &dims = m_d.dims();
    const auto &pdims = m_d.blocking_desc().padding_dims;

    const int G = w_groups ? dims[0] : 1;
    const int NB_OC = pdims[w_groups + 0] / blksize; // number of oc blocks
    const int NB_IC = pdims[w_groups + 1] / blksize; // number of ic blocks
    const int D = is_3d ? dims[w_groups + 2] : 1;
    const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
    const int W = dims[w_groups + 3 - is_1d + is_3d];

    /* zeroes the tail of one (blksize x blksize) oc/ic block:
     * pass 1 clears the padded ic columns of the valid oc rows,
     * pass 2 clears all ic positions of the padded oc rows */
    auto ker = [&](data_t *d, const int oc_tail, const int ic_tail) {
        /* block-local element offset for the concrete OI block format */
#       define blk_off OI_blk_off<format_traits<fmt>::blk_fmt>
        int oc = 0;
        for (; oc < blksize - oc_tail; ++oc) {
            for (int ic = blksize - ic_tail; ic < blksize; ++ic)
                d[blk_off(oc, ic)] = 0;
        }
        for (; oc < blksize; ++oc)
            for (int ic = 0; ic < blksize; ++ic)
                d[blk_off(oc, ic)] = 0;
#       undef blk_off
    };

    const int oc_tail = pdims[w_groups + 0] - dims[w_groups + 0];
    const int ic_tail = pdims[w_groups + 1] - dims[w_groups + 1];

    /* the two passes may both touch the (last oc block, last ic block)
     * corner; both write zeros there, so the overlap is harmless */
    if (ic_tail) {
        parallel_nd(G, NB_OC, D, H, W,
            [&](int g, int nb_oc, int d, int h, int w) {
            auto x = &data[wei_blk_off_like_gwei3D<fmt>(m_d,
                    g, nb_oc, NB_IC - 1, d, h, w)];
            ker(x, 0, ic_tail);
        });
    }

    if (oc_tail) {
        parallel_nd(G, NB_IC, D, H, W,
            [&](int g, int nb_ic, int d, int h, int w) {
            auto x = &data[wei_blk_off_like_gwei3D<fmt>(m_d,
                    g, NB_OC - 1, nb_ic, d, h, w)];
            ker(x, oc_tail, 0);
        });
    }
}

/* Zero-pads the group tail of the last g-block for depthwise weight formats
 * blocked over groups (Goihw8g / Goihw16g / Goiw16g). */
template <data_type_t dt, memory_format_t fmt>
typename utils::enable_if<false
|| format_traits<fmt>::blk_fmt == bf::_8g
|| format_traits<fmt>::blk_fmt == bf::_16g
>::type typed_zero_pad_weights(const memory_desc_wrapper &m_d,
        typename prec_traits<dt>::type *data) {
    constexpr int blksize = format_traits<fmt>::blk_size;

    const auto &dims = m_d.dims();
    const auto &pdims = m_d.blocking_desc().padding_dims;

    /* index of the last (partially filled) group block */
    const int last_blk = pdims[0] / blksize - 1;
    /* first group position inside that block that lies in the padding */
    const int tail_start = dims[0] % blksize;
    assert(tail_start != 0); // caller dispatches here only if padding exists
    /* number of elements per group (product of all remaining dims) */
    const ptrdiff_t inner_sz
        = (ptrdiff_t)utils::array_product(dims + 1, m_d.ndims() - 1);

    auto *d = &data[m_d.blk_off(last_blk)];

    parallel_nd(inner_sz, [&](ptrdiff_t s) {
        for (int g = tail_start; g < blksize; ++g)
            d[s * blksize + g] = 0;
    });
}

/* Generic (slow) zero-padding fallback for any blocked layout: walks the
 * padded logical space and zeroes every element whose coordinate falls in
 * the padding of some dimension. */
template <data_type_t dt>
void typed_zero_pad_generic_blocked(const memory_desc_wrapper &m_d,
        typename prec_traits<dt>::type *data) {
    const int ndims = m_d.ndims();
    const auto &dims = m_d.dims();
    const auto &pdims = m_d.blocking_desc().padding_dims;

    /* total number of elements including padding */
    const ptrdiff_t nelems = (ptrdiff_t)m_d.nelems(true);

    /* [D_0] .. [D_k][D_k+1] .. [D_ndim - 1]
     *            |  \                     /
     *            |   ---------------------
     *           has        contiguous
     *         padding
     *
     * step     <-- D_k+1 * ... * D_ndims-1
     * step_dim <-- k
     */

    /* find the innermost run of dims with no padding; elements within such
     * a run share the same need-zero decision, so check once per `step` */
    ptrdiff_t step = 1;
    int step_dim = ndims - 1;
    for (; step_dim >= 0; --step_dim) {
        if (dims[step_dim] != pdims[step_dim]) break;
        step *= dims[step_dim];
    }

    /* callers only dispatch here when padding exists (nelems differ) */
    assert(step_dim >= 0 && "no zero padding is required");
    if (step_dim < 0) return;

    parallel_nd(nelems / step, [&](ptrdiff_t e1) {
        bool need_zero = false;

        /* decode e1 into padded coordinates for dims 0..step_dim and check
         * whether any coordinate lies in the padded region */
        ptrdiff_t idx = e1;
        for (int d = step_dim; d >= 0; --d) {
            if (idx % pdims[d] >= dims[d]) {
                need_zero = true;
                break;
            }
            idx /= pdims[d];
        }

        if (need_zero) {
            /* zero the whole contiguous inner run; off_l maps the logical
             * (padded) index to the physical offset */
            for (ptrdiff_t e0 = 0; e0 < step; ++e0)
                data[m_d.off_l(e1 * step + e0, true)] = 0;
        }
    });
}

/* Dispatches zero-padding to the fastest specialized kernel for the memory
 * format, falling back to the generic blocked implementation. Returns
 * success when padding was performed (or none was needed), unimplemented
 * for formats with no padding support. */
template <data_type_t dt>
status_t cpu_memory_t::typed_zero_pad() const {
    const memory_desc_wrapper mpd(pd());

    // FIXME: guard this check for non-blocked layout
    /* nothing to do if the padded and unpadded element counts coincide */
    if (mpd.nelems(false) == mpd.nelems(true))
        return success;

    auto *data = (typename prec_traits<dt>::type *)data_;
    const auto fmt = mpd.format();

    /* data */
#   define MAYBE_DATA(f) if (fmt == f) \
    { typed_zero_pad_data<dt, f>(mpd, data); return success; }
    MAYBE_DATA(nCw4c);
    MAYBE_DATA(nCw8c);
    MAYBE_DATA(nCw16c);
    MAYBE_DATA(nChw4c);
    MAYBE_DATA(nChw8c);
    MAYBE_DATA(nCdhw4c);
    MAYBE_DATA(nCdhw8c);
    MAYBE_DATA(nChw16c);
    MAYBE_DATA(nCdhw16c);
#   undef MAYBE_DATA

    /* weights */
#   define MAYBE_WEIGHTS(f) if (fmt == f) \
    { typed_zero_pad_weights<dt, f>(mpd, data); return success; }
    MAYBE_WEIGHTS(OIdhw4i4o);
    MAYBE_WEIGHTS(OIdhw8i8o);
    MAYBE_WEIGHTS(OIdhw8o8i);
    MAYBE_WEIGHTS(OIdhw16i16o);
    MAYBE_WEIGHTS(OIdhw16o16i);
    MAYBE_WEIGHTS(Oidhw4o);
    MAYBE_WEIGHTS(Oidhw16o);
    MAYBE_WEIGHTS(Odhwi16o);
    MAYBE_WEIGHTS(Odhwi8o);
    MAYBE_WEIGHTS(oIhw8i);
    MAYBE_WEIGHTS(oIhw16i);
    MAYBE_WEIGHTS(oIdhw8i);
    MAYBE_WEIGHTS(oIdhw16i);
    MAYBE_WEIGHTS(OIhw4i4o);
    MAYBE_WEIGHTS(OIhw8i8o);
    MAYBE_WEIGHTS(OIhw16i16o);
    MAYBE_WEIGHTS(OIhw4i16o4i);
    MAYBE_WEIGHTS(OIhw4i16o4i_s8s8);
    MAYBE_WEIGHTS(OIw4i4o);
    MAYBE_WEIGHTS(Owi8o);
    MAYBE_WEIGHTS(OIw8i8o);
    MAYBE_WEIGHTS(OIw8o8i);
    MAYBE_WEIGHTS(OIw16i16o);
    MAYBE_WEIGHTS(OIw16o16i);
    MAYBE_WEIGHTS(Oiw4o);
    MAYBE_WEIGHTS(Oiw16o);
    MAYBE_WEIGHTS(Owi16o);
    MAYBE_WEIGHTS(OIw8i16o2i);
    MAYBE_WEIGHTS(OIw8o16i2o);
    MAYBE_WEIGHTS(IOw8o16i2o);
    MAYBE_WEIGHTS(IOw16o16i);
    MAYBE_WEIGHTS(OIw4i16o4i);
    MAYBE_WEIGHTS(OIw4i16o4i_s8s8);
    MAYBE_WEIGHTS(OIhw8i16o2i);
    MAYBE_WEIGHTS(OIhw8o16i2o);
    MAYBE_WEIGHTS(IOhw8o16i2o);
    MAYBE_WEIGHTS(OIdhw8i16o2i);
    MAYBE_WEIGHTS(OIdhw8o16i2o);
    MAYBE_WEIGHTS(IOdhw8o16i2o);
    MAYBE_WEIGHTS(OIhw8o8i);
    MAYBE_WEIGHTS(OIhw16o16i);
    MAYBE_WEIGHTS(IOhw16o16i);
    MAYBE_WEIGHTS(Oihw4o);
    MAYBE_WEIGHTS(Oihw16o);
    MAYBE_WEIGHTS(Ohwi8o);
    MAYBE_WEIGHTS(Ohwi4o);
    MAYBE_WEIGHTS(Ohwi16o);
    /* NOTE(review): the original listed gOIhw4o4i_s8s8 twice; the second
     * occurrence was unreachable and has been removed */
    MAYBE_WEIGHTS(gOIhw4o4i_s8s8);
    MAYBE_WEIGHTS(gOIhw4i4o);
    MAYBE_WEIGHTS(gOIhw8i8o);
    MAYBE_WEIGHTS(gOIhw16i16o);
    MAYBE_WEIGHTS(gOIhw4i16o4i);
    MAYBE_WEIGHTS(gOIhw4i16o4i_s8s8);
    MAYBE_WEIGHTS(gOIhw2i8o4i);
    MAYBE_WEIGHTS(gOIhw2i8o4i_s8s8);
    MAYBE_WEIGHTS(gOIw4i4o);
    MAYBE_WEIGHTS(gOwi8o);
    MAYBE_WEIGHTS(gOIw8i8o);
    MAYBE_WEIGHTS(gOIw8o8i);
    MAYBE_WEIGHTS(gOIw16i16o);
    MAYBE_WEIGHTS(gOIw16o16i);
    MAYBE_WEIGHTS(gOiw4o);
    MAYBE_WEIGHTS(gOiw16o);
    MAYBE_WEIGHTS(gOwi16o);
    MAYBE_WEIGHTS(gOIw8i16o2i);
    MAYBE_WEIGHTS(gOIw8o16i2o);
    MAYBE_WEIGHTS(gIOw8o16i2o);
    MAYBE_WEIGHTS(gIOw16o16i);
    MAYBE_WEIGHTS(gOIw4i16o4i);
    MAYBE_WEIGHTS(gOIw4i16o4i_s8s8);
    MAYBE_WEIGHTS(gOIhw8i16o2i);
    MAYBE_WEIGHTS(gOIhw8o16i2o);
    MAYBE_WEIGHTS(gIOhw8o16i2o);
    MAYBE_WEIGHTS(gOIdhw8i16o2i);
    MAYBE_WEIGHTS(gOIdhw8o16i2o);
    MAYBE_WEIGHTS(gIOdhw8o16i2o);
    MAYBE_WEIGHTS(gOIhw8o8i);
    MAYBE_WEIGHTS(gOIhw16o16i);
    MAYBE_WEIGHTS(gIOhw16o16i);
    MAYBE_WEIGHTS(gOihw4o);
    MAYBE_WEIGHTS(gOihw16o);
    MAYBE_WEIGHTS(gOhwi8o);
    MAYBE_WEIGHTS(gOhwi4o);
    MAYBE_WEIGHTS(gOhwi16o);
    MAYBE_WEIGHTS(gOIdhw4i4o);
    MAYBE_WEIGHTS(gOIdhw8i8o);
    MAYBE_WEIGHTS(gOIdhw8o8i);
    MAYBE_WEIGHTS(gOIdhw16i16o);
    MAYBE_WEIGHTS(gOIdhw16o16i);
    MAYBE_WEIGHTS(gOidhw4o);
    MAYBE_WEIGHTS(gOidhw16o);
    MAYBE_WEIGHTS(gOdhwi16o);
    MAYBE_WEIGHTS(gOdhwi8o);
    MAYBE_WEIGHTS(Goihw8g);
    MAYBE_WEIGHTS(Goihw16g);
    MAYBE_WEIGHTS(Goiw16g);
#   undef MAYBE_WEIGHTS

    // the last line of defence
    if (types::format_normalize(fmt) == blocked) {
        typed_zero_pad_generic_blocked<dt>(mpd, data);
        return success;
    }

    return unimplemented;
}

/* Physically zeroes the padded area of this memory, dispatching on the
 * data type. Types of the same element size share one instantiation, since
 * zero-padding only writes zero bytes. Returns success when nothing needs
 * to be done. */
status_t cpu_memory_t::zero_pad() const {
    memory_desc_wrapper md(pd());
    /* skip when there is no buffer, no dims, or no blocking descriptor */
    const bool skip_zeroing = false
        || data_ == nullptr
        || md.is_zero()
        || !md.is_blocking_desc();
    if (skip_zeroing) return success;

    switch (md.data_type()) {
        case f32: return typed_zero_pad<f32>();
        case s32: return typed_zero_pad<s32>();
        case s16: return typed_zero_pad<s16>();
        /* bf16 is 16-bit wide: reuse the s16 (16-bit) instantiation */
        case bf16: return typed_zero_pad<s16>();
        case s8: return typed_zero_pad<s8>();
        case u8: return typed_zero_pad<u8>();
        /* bin data lives in 8-bit storage: reuse the u8 instantiation */
        case bin: return typed_zero_pad<u8>();
        default: assert(!"memory is undefined"); return unimplemented;
    }
    /* all switch paths return; the original trailing return was dead code */
}

}
}
}