/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include <math.h>

#include "c_types_map.hpp"
#include "math_utils.hpp"
#include "mkldnn_thread.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
#include "bfloat16_utils.hpp"
#include "ref_pooling.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace nstl;
using namespace bf16_cvt_utils;

template <data_type_t data_type, data_type_t acc_type>
void ref_pooling_fwd_t<data_type, acc_type>::execute_forward() const {
    using namespace alg_kind;
    using namespace prop_kind;

    auto alg = pd()->desc()->alg_kind;

    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto dst = reinterpret_cast<data_t *>(this->memory(0));
    auto ws = alg == pooling_max && pd()->desc()->prop_kind == forward_training
        ? reinterpret_cast<unsigned char *>(this->memory(1)) : nullptr;

    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper dst_d(pd()->dst_pd());
    const memory_desc_wrapper ws_d(pd()->workspace_pd());
    const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;

    const int ID = pd()->ID();
    const int IH = pd()->IH();
    const int IW = pd()->IW();
    const int KD = pd()->KD();
    const int KH = pd()->KH();
    const int KW = pd()->KW();
    const int SD = pd()->KSD();
    const int SH = pd()->KSH();
    const int SW = pd()->KSW();
    const int padF = pd()->padFront();
    const int padT = pd()->padT();
    const int padL = pd()->padL();
    const int padB = pd()->padB();
    const int padR = pd()->padR();
    const int padBack = pd()->padBack();

    const int MB = pd()->MB();
    const int OC = pd()->C();
    const int OD = pd()->OD();
    const int OH = pd()->OH();
    const int OW = pd()->OW();

    const bool is_3d = pd()->desc()->src_desc.ndims == 5;

    auto set_ws = [=](int mb, int oc, int od, int oh, int ow, int value) {
        // value = -1 means that pool window is placed outside of source domain
        // for current {od, oh, ow} point
        if (ws) {
            assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
            size_t offset = is_3d
                ? ws_d.off(mb, oc, od, oh, ow) : ws_d.off(mb, oc, oh, ow);
            if (ws_dt == data_type::u8) {
                const int u8_max = numeric_limits<
                    typename prec_traits<data_type::u8>::type>::max();
                if (value == -1)
                    value = u8_max;
                assert(0 <= value && value <= u8_max);
                ws[offset] = value;
            } else
                reinterpret_cast<int *>(ws)[offset] = value;
        }
    };

    auto ker_max = [=](data_t *d, int mb, int oc, int od, int oh, int ow) {
        bool is_initialized = false;
        int current_pool_size = 0;
        for (int kd = 0; kd < KD; ++kd) {
            for (int kh = 0; kh < KH; ++kh) {
                for (int kw = 0; kw < KW; ++kw) {
                    const int id = od * SD - padF + kd;
                    const int ih = oh * SH - padT + kh;
                    const int iw = ow * SW - padL + kw;

                    if (id < 0 || id >= ID) continue;
                    if (ih < 0 || ih >= IH) continue;
                    if (iw < 0 || iw >= IW) continue;

                    const auto offset = is_3d
                        ? src_d.off(mb, oc, id, ih, iw)
                        : src_d.off(mb, oc, ih, iw);
                    const auto s = src[offset];
                    if (!is_initialized) {
                        d[0] = s;
                        set_ws(mb, oc, od, oh, ow, kd * KH * KW + kh*KW + kw);
                        is_initialized = true;
                    } else {
                        if (s > d[0]) {
                            d[0] = s;
                            set_ws(mb, oc, od, oh, ow, kd * KH * KW + kh * KW + kw);
                        }
                    }
                    current_pool_size++;
                }
            }
        }

        // corner case: pool window is outside of real input domain
        // for this point.
        if (current_pool_size == 0)
            set_ws(mb, oc, 1, oh, ow, -1);
    };
    auto ker_avg = [=](data_t *d, int mb, int oc, int od, int oh, int ow) {
        auto id_start = od*SD - padF;
        auto ih_start = oh*SH - padT;
        auto iw_start = ow*SW - padL;
        auto id_end = nstl::min(od*SD - padF + KD, ID + padBack);
        auto ih_end = nstl::min(oh*SH - padT + KH, IH + padB);
        auto iw_end = nstl::min(ow*SW - padL + KW, IW + padR);

        // case alg == pooling_avg_include_padding
        auto num_summands = (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start);

        id_start = nstl::max(id_start, 0);
        ih_start = nstl::max(ih_start, 0);
        iw_start = nstl::max(iw_start, 0);
        id_end = nstl::min(id_end, ID);
        ih_end = nstl::min(ih_end, IH);
        iw_end = nstl::min(iw_end, IW);

        if (alg == pooling_avg_exclude_padding)
            num_summands = (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start);
        if (num_summands == 0) return;

        acc_data_t dst = 0;
        for (int id = id_start; id < id_end; ++id) {
            for (int ih = ih_start; ih < ih_end; ++ih) {
                for (int iw = iw_start; iw < iw_end; ++iw) {
                    const auto offset = is_3d
                        ? src_d.off(mb, oc, id, ih, iw)
                        : src_d.off(mb, oc, ih, iw);
                    dst += src[offset];
                }
            }
        }

        d[0] = math::out_round<data_t>((float)dst / num_summands);
    };

    if (alg == pooling_max) {
        parallel_nd(MB, OC, OD, OH, OW,
            [&](int mb, int oc, int od, int oh, int ow) {
            data_t *d = is_3d
                ? &dst[dst_d.off(mb, oc, od, oh, ow)]
                : &dst[dst_d.off(mb, oc, oh, ow)];
                d[0] = (data_t)0;
                set_ws(mb, oc, od, oh, ow, 0);
                ker_max(d, mb, oc, od, oh, ow);
        });
    } else {
        parallel_nd(MB, OC, OD, OH, OW,
            [&](int mb, int oc, int od, int oh, int ow) {
            data_t *d = is_3d
                ? &dst[dst_d.off(mb, oc, od, oh, ow)]
                : &dst[dst_d.off(mb, oc, oh, ow)];
            d[0] = (data_t)0;
            ker_avg(d, mb, oc, od, oh, ow);
        });
    }
}

template <>
void ref_pooling_fwd_t<data_type::bf16, data_type::f32>::execute_forward() const {
    using namespace alg_kind;
    using namespace prop_kind;

    auto alg = pd()->desc()->alg_kind;

    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto dst = reinterpret_cast<data_t *>(this->memory(0));
    auto ws = alg == pooling_max && pd()->desc()->prop_kind == forward_training
        ? reinterpret_cast<unsigned char *>(this->memory(1)) : nullptr;

    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper dst_d(pd()->dst_pd());
    const memory_desc_wrapper ws_d(pd()->workspace_pd());
    const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;

    const int ID = pd()->ID();
    const int IH = pd()->IH();
    const int IW = pd()->IW();
    const int KD = pd()->KD();
    const int KH = pd()->KH();
    const int KW = pd()->KW();
    const int SD = pd()->KSD();
    const int SH = pd()->KSH();
    const int SW = pd()->KSW();
    const int padF = pd()->padFront();
    const int padT = pd()->padT();
    const int padL = pd()->padL();
    const int MB = pd()->MB();
    const int OC = pd()->C();
    const int OD = pd()->OD();
    const int OH = pd()->OH();
    const int OW = pd()->OW();
    const int padB = pd()->padB();
    const int padR = pd()->padR();
    const int padBack = pd()->padBack();

    const bool is_3d = pd()->desc()->src_desc.ndims == 5;

    auto set_ws = [=](int mb, int oc, int od, int oh, int ow, int value) {
        // value = -1 means that pool window is placed outside of source domain
        // for current {od, oh, ow} point
        if (ws) {
            assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
            size_t offset = is_3d
                ? ws_d.off(mb, oc, od, oh, ow) : ws_d.off(mb, oc, oh, ow);
            if (ws_dt == data_type::u8) {
                const int u8_max = numeric_limits<
                    typename prec_traits<data_type::u8>::type>::max();
                if (value == -1)
                    value = u8_max;
                assert(0 <= value && value <= u8_max);
                ws[offset] = value;
            } else
                reinterpret_cast<int *>(ws)[offset] = value;
        }
    };

    auto ker_max = [=](mkldnn_bfloat16_t *d,
            int mb, int oc, int od, int oh, int ow) {
        bool is_initialized = false;
        int current_pool_size = 0;
        float d_max = cvt_bfloat16_to_float(d[0]);
        for (int kd = 0; kd < KD; ++kd) {
            for (int kh = 0; kh < KH; ++kh) {
                for (int kw = 0; kw < KW; ++kw) {
                    const int id = od * SD - padF + kd;
                    const int ih = oh * SH - padT + kh;
                    const int iw = ow * SW - padL + kw;

                    if (id < 0 || id >= ID) continue;
                    if (ih < 0 || ih >= IH) continue;
                    if (iw < 0 || iw >= IW) continue;

                    const auto offset = is_3d
                        ? src_d.off(mb, oc, id, ih, iw)
                        : src_d.off(mb, oc, ih, iw);
                    float s = cvt_bfloat16_to_float(src[offset]);
                    if (!is_initialized) {
                        d_max = s;
                        set_ws(mb, oc, od, oh, ow, kd * KH * KW + kh*KW + kw);
                        is_initialized = true;
                    } else {
                        if (s > d_max) {
                            d_max = s;
                            set_ws(mb, oc, od, oh, ow, kd * KH * KW + kh*KW + kw);
                        }
                    }
                    current_pool_size++;
                }
            }
        }

        d[0] = cvt_float_to_bfloat16(d_max);

        // corner case: pool window is outside of real input domain
        // for this point.
        if (current_pool_size == 0)
            set_ws(mb, oc, 1, oh, ow, -1);
    };

    auto ker_avg = [=](mkldnn_bfloat16_t *d,
            int mb, int oc, int od, int oh, int ow) {
        auto id_start = od*SD - padF;
        auto ih_start = oh*SH - padT;
        auto iw_start = ow*SW - padL;
        auto id_end = nstl::min(od*SD - padF + KD, ID + padBack);
        auto ih_end = nstl::min(oh*SH - padT + KH, IH + padB);
        auto iw_end = nstl::min(ow*SW - padL + KW, IW + padR);

        // case alg == pooling_avg_include_padding
        auto num_summands = (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start);

        id_start = nstl::max(id_start, 0);
        ih_start = nstl::max(ih_start, 0);
        iw_start = nstl::max(iw_start, 0);
        id_end = nstl::min(id_end, ID);
        ih_end = nstl::min(ih_end, IH);
        iw_end = nstl::min(iw_end, IW);

        if (alg == pooling_avg_exclude_padding)
            num_summands = (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start);
        if (num_summands == 0) return;

        float dst = 0;
        for (int id = id_start; id < id_end; ++id) {
            for (int ih = ih_start; ih < ih_end; ++ih) {
                for (int iw = iw_start; iw < iw_end; ++iw) {
                    const auto offset = is_3d
                        ? src_d.off(mb, oc, id, ih, iw)
                        : src_d.off(mb, oc, ih, iw);
                    const auto s = cvt_bfloat16_to_float(src[offset]);
                    dst += s;
                }
            }
        }

        dst = math::out_round<float>((float)dst / num_summands);
        d[0] = cvt_float_to_bfloat16(dst);
    };

    if (alg == pooling_max) {
        parallel_nd(MB, OC, OD, OH, OW,
            [&](int mb, int oc, int od, int oh, int ow) {
            data_t *d = is_3d
                ? &dst[dst_d.off(mb, oc, od, oh, ow)]
                : &dst[dst_d.off(mb, oc, oh, ow)];
                d[0] = approx_bfloat16_lowest();
                set_ws(mb, oc, od, oh, ow, 0);
                ker_max(d, mb, oc, od, oh, ow);
        });
    } else {
        parallel_nd(MB, OC, OD, OH, OW,
            [&](int mb, int oc, int od, int oh, int ow) {
            data_t *d = is_3d
                ? &dst[dst_d.off(mb, oc, od, oh, ow)]
                : &dst[dst_d.off(mb, oc, oh, ow)];
            d[0] = 0;
            ker_avg(d, mb, oc, od, oh, ow);
        });
    }
}

template <data_type_t data_type>
void ref_pooling_bwd_t<data_type>::execute_backward() const {
    using namespace alg_kind;

    auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto ws = pd()->desc()->alg_kind != alg_kind::pooling_max ? nullptr
        : reinterpret_cast<const unsigned char *>(this->input_memory(1));
    auto diff_src = reinterpret_cast<data_t *>(this->memory(0));

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
    const memory_desc_wrapper ws_d(pd()->workspace_pd());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());

    const int ID = pd()->ID();
    const int IH = pd()->IH();
    const int IW = pd()->IW();
    const int KD = pd()->KD();
    const int KH = pd()->KH();
    const int KW = pd()->KW();
    const int SD = pd()->KSD();
    const int SH = pd()->KSH();
    const int SW = pd()->KSW();
    const int padF = pd()->padFront();
    const int padT = pd()->padT();
    const int padL = pd()->padL();
    const int MB = pd()->MB();
    const int OC = pd()->C();
    const int OD = pd()->OD();
    const int OH = pd()->OH();
    const int OW = pd()->OW();

    const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;

    auto alg = pd()->desc()->alg_kind;

    auto apply_offset = [=](int index, int offset) {
        return (index > offset) ? index - offset : 0;
    };

    auto ker_zero = [=](int _mb, int _oc) {
        for (int id = 0; id < ID; ++id) {
            for (int ih = 0; ih < IH; ++ih) {
                for (int iw = 0; iw < IW; ++iw) {
                    const auto offset = is_3d
                        ? diff_src_d.off(_mb, _oc, id, ih, iw)
                        : diff_src_d.off(_mb, _oc, ih, iw);
                    diff_src[offset] =
                        data_type_t(0);
                }
            }
        }
    };

    auto ker_max = [=](const data_t *d, int mb, int oc, int od, int oh,
            int ow) {
        const size_t ws_off = is_3d
            ? ws_d.off(mb, oc, od, oh, ow)
            : ws_d.off(mb, oc, oh, ow);
        const int index = ws_d.data_type() == data_type::u8
            ? (int)ws[ws_off] : ((int *)ws)[ws_off];

        const int invalid_index_value = ws_d.data_type() == data_type::u8
            ? numeric_limits<typename prec_traits<data_type::u8>::type>::max()
            : -1;
        if (index == invalid_index_value)
           return; // corner case: pool window is outside of real input domain
                   // for this point, do nothing

        const int kw = index % KW;
        const int kh = is_3d
            ? (index / KW) % KH
            : index / KW;
        const int kd = (index / KW) / KH;
        const int id = od * SD - padF + kd;
        const int ih = oh * SH - padT + kh;
        const int iw = ow * SW - padL + kw;

        // If padding area could fit the kernel,
        // then input displacement would be out of bounds.
        // No need to back propagate there as padding is
        // virtual in pooling_max case.
        if (id < 0 || id >= ID)
            return;
        if (ih < 0 || ih >= IH)
            return;
        if (iw < 0 || iw >= IW)
            return;

        const auto offset = is_3d
            ? diff_src_d.off(mb, oc, id, ih, iw)
            : diff_src_d.off(mb, oc, ih, iw);

        diff_src[offset] += d[0];
    };

    auto ker_avg = [=](const data_t *d, int mb, int oc, int od, int oh,
            int ow) {
        auto id_start = apply_offset(od*SD, padF);
        auto ih_start = apply_offset(oh*SH, padT);
        auto iw_start = apply_offset(ow*SW, padL);
        auto id_end = nstl::min(od*SD - padF + KD, ID);
        auto ih_end = nstl::min(oh*SH - padT + KH, IH);
        auto iw_end = nstl::min(ow*SW - padL + KW, IW);

        auto num_summands = (alg == pooling_avg_include_padding)
            ? KW * KH * KD
            : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start);
        assert(num_summands > 0);

        for (int id = id_start; id < id_end; ++id)
        for (int ih = ih_start; ih < ih_end; ++ih)
        for (int iw = iw_start; iw < iw_end; ++iw) {
            const auto offset = is_3d
                ? diff_src_d.off(mb, oc, id, ih, iw)
                : diff_src_d.off(mb, oc, ih, iw);
            diff_src[offset] += d[0] / num_summands;
        }
    };

    if (pd()->desc()->alg_kind == alg_kind::pooling_max) {
        parallel_nd(MB, OC, [&](int mb, int oc) {
            ker_zero(mb, oc);
            for (int od = 0; od < OD; ++od) {
                for (int oh = 0; oh < OH; ++oh) {
                    for (int ow = 0; ow < OW; ++ow) {
                        const data_t *d = is_3d
                            ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)]
                            : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)];
                        ker_max(d, mb, oc, od, oh, ow);
                    }
                }
            }
        });
    } else {
        parallel_nd(MB, OC, [&](int mb, int oc) {
            ker_zero(mb, oc);
            for (int od = 0; od < OD; ++od) {
                for (int oh = 0; oh < OH; ++oh) {
                    for (int ow = 0; ow < OW; ++ow) {
                        const data_t *d = is_3d
                            ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)]
                            : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)];
                        ker_avg(d, mb, oc, od, oh, ow);
                    }
                }
            }
        });
    }
}

template <>
void ref_pooling_bwd_t<data_type::bf16>::execute_backward() const {
    using namespace alg_kind;

    auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto ws = pd()->desc()->alg_kind != alg_kind::pooling_max ? nullptr
        : reinterpret_cast<const unsigned char *>(this->input_memory(1));
    auto diff_src = reinterpret_cast<data_t *>(this->memory(0));

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
    const memory_desc_wrapper ws_d(pd()->workspace_pd());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());

    const int ID = pd()->ID();
    const int IH = pd()->IH();
    const int IW = pd()->IW();
    const int KD = pd()->KD();
    const int KH = pd()->KH();
    const int KW = pd()->KW();
    const int SD = pd()->KSD();
    const int SH = pd()->KSH();
    const int SW = pd()->KSW();
    const int padF = pd()->padFront();
    const int padT = pd()->padT();
    const int padL = pd()->padL();
    const int MB = pd()->MB();
    const int OC = pd()->C();
    const int OD = pd()->OD();
    const int OH = pd()->OH();
    const int OW = pd()->OW();

    const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;

    auto alg = pd()->desc()->alg_kind;

    auto apply_offset = [=](int index, int offset) {
        return (index > offset) ? index - offset : 0;
    };

    auto ker_zero = [=](int _mb, int _oc) {
        for (int id = 0; id < ID; ++id) {
            for (int ih = 0; ih < IH; ++ih) {
                for (int iw = 0; iw < IW; ++iw) {
                    const auto offset = is_3d ?
                        diff_src_d.off(_mb, _oc, id, ih, iw)
                        : diff_src_d.off(_mb, _oc, ih, iw);
                    diff_src[offset] =
                        data_type_t(0);
                }
            }
        }
    };

    auto ker_max = [=](const mkldnn_bfloat16_t *d, int mb, int oc, int od, int oh,
            int ow) {
        const size_t ws_off = is_3d
            ? ws_d.off(mb, oc, od, oh, ow)
            : ws_d.off(mb, oc, oh, ow);
        const int index = ws_d.data_type() == data_type::u8
            ? (int)ws[ws_off] : ((int *)ws)[ws_off];

        const int invalid_index_value = ws_d.data_type() == data_type::u8
            ? numeric_limits<typename prec_traits<data_type::u8>::type>::max()
            : -1;
        if (index == invalid_index_value)
           return; // corner case: pool window is outside of real input domain
                   // for this point, do nothing

        const int kw = index % KW;
        const int kh = (index / KW) % KH;
        const int kd = (index / KW) / KH;
        const int id = od * SD - padF + kd;
        const int ih = oh * SH - padT + kh;
        const int iw = ow * SW - padL + kw;

        // If padding area could fit the kernel,
        // then input displacement would be out of bounds.
        // No need to back propagate there as padding is
        // virtual in pooling_max case.
        if (id < 0 || id >= ID)
            return;
        if (ih < 0 || ih >= IH)
            return;
        if (iw < 0 || iw >= IW)
            return;

        const auto offset = is_3d
            ? diff_src_d.off(mb, oc, id, ih, iw)
            : diff_src_d.off(mb, oc, ih, iw);

        float ds = cvt_bfloat16_to_float(diff_src[offset]);
        ds += cvt_bfloat16_to_float(d[0]); 
        diff_src[offset] = cvt_float_to_bfloat16(ds);
    };

    auto ker_avg = [=](const mkldnn_bfloat16_t *d, int mb, int oc, int od, int oh,
            int ow) {
        auto id_start = apply_offset(od*SD, padF);
        auto ih_start = apply_offset(oh*SH, padT);
        auto iw_start = apply_offset(ow*SW, padL);
        auto id_end = nstl::min(od*SD - padF + KD, ID);
        auto ih_end = nstl::min(oh*SH - padT + KH, IH);
        auto iw_end = nstl::min(ow*SW - padL + KW, IW);

        auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH*KD
            : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start);
        assert(num_summands > 0);

        for (int id = id_start; id < id_end; ++id)
        for (int ih = ih_start; ih < ih_end; ++ih)
        for (int iw = iw_start; iw < iw_end; ++iw) {
            const auto offset = is_3d
                ? diff_src_d.off(mb, oc, id, ih, iw)
                : diff_src_d.off(mb, oc, ih, iw);
            float ds = cvt_bfloat16_to_float(diff_src[offset]);
            ds += cvt_bfloat16_to_float(d[0]) / num_summands; 
            diff_src[offset] = cvt_float_to_bfloat16(ds);
        }
    };

    if (pd()->desc()->alg_kind == alg_kind::pooling_max) {
        parallel_nd(MB, OC, [&](int mb, int oc) {
            ker_zero(mb, oc);
            for (int od = 0; od < OD; ++od) {
                for (int oh = 0; oh < OH; ++oh) {
                    for (int ow = 0; ow < OW; ++ow) {
                        const data_t *d = is_3d
                            ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)]
                            : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)];
                        ker_max(d, mb, oc, od, oh, ow);
                    }
                }
            }
        });
    } else {
        parallel_nd(MB, OC, [&](int mb, int oc) {
            ker_zero(mb, oc);
            for (int od = 0; od < OD; ++od) {
                for (int oh = 0; oh < OH; ++oh) {
                    for (int ow = 0; ow < OW; ++ow) {
                        const data_t *d = is_3d
                            ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)]
                            : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)];
                        ker_avg(d, mb, oc, od, oh, ow);
                    }
                }
            }
        });
    }
}

template struct ref_pooling_fwd_t<data_type::f32>;
template struct ref_pooling_fwd_t<data_type::s32>;
template struct ref_pooling_fwd_t<data_type::bf16, data_type::f32>;
template struct ref_pooling_fwd_t<data_type::s16, data_type::s32>;
template struct ref_pooling_fwd_t<data_type::s8, data_type::s32>;
template struct ref_pooling_fwd_t<data_type::u8, data_type::s32>;

template struct ref_pooling_bwd_t<data_type::f32>;
template struct ref_pooling_bwd_t<data_type::s32>;
template struct ref_pooling_bwd_t<data_type::bf16>;
template struct ref_pooling_bwd_t<data_type::s16>;

}
}
}

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s