/******************************************************************************* * Copyright 2018 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ #include <assert.h> #include "c_types_map.hpp" #include "memory_desc_wrapper.hpp" #include "mkldnn_debug.h" #include "nstl.hpp" #include "type_helpers.hpp" #include "utils.hpp" #include "cpu_primitive.hpp" #include "cpu_reorder_pd.hpp" #include "jit_uni_reorder.hpp" using namespace mkldnn::impl::types; using namespace mkldnn::impl::status; namespace mkldnn { namespace impl { namespace cpu { namespace tr { /** ad-hoc structure to describe blocked memory layout */ struct layout_desc_t { data_type_t dt; int ndims; dims_t id; dims_t dims; strides_t strides; }; status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, layout_desc_t &ld) { using namespace mkldnn::impl::memory_format; using namespace mkldnn::impl::data_type; auto md = memory_desc_wrapper(md_); auto bd = md.blocking_desc(); ld.ndims = 0; ld.dt = md.data_type(); auto P = [&ld](int id, int dim, ptrdiff_t stride) { assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0])); ld.id[ld.ndims] = id; ld.dims[ld.ndims] = dim; ld.strides[ld.ndims] = stride; ++ld.ndims; }; /* special cases */ switch (md.format()) { case memory_format::undef: case memory_format::any: case hwio_s8s8: case hwigo_s8s8: case gOIhw4o4i_s8s8: case gOIhw2i8o4i_s8s8: case gOIw4i16o4i_s8s8: case OIw4i16o4i_s8s8: case gOIhw4i16o4i_s8s8: case OIhw4i16o4i_s8s8: case Goihw16g_s8s8: case Goiw16g_s8s8: case wino_fmt: return invalid_arguments; case OIw4i16o4i: case OIhw4i16o4i: P(0, bd.padding_dims[0] / 16, bd.strides[0][0]); P(0, 16, 4); P(1, bd.padding_dims[1] / 16, bd.strides[0][1]); P(1, 4, 16*4); P(1, 4, 1); P(2, bd.padding_dims[2], bd.strides[0][2]); if (md.format() == OIhw4i16o4i) P(3, bd.padding_dims[3], bd.strides[0][3]); return success; case OIw8i16o2i: case OIhw8i16o2i: case IOhw8i16o2i: case OIdhw8i16o2i: P(0, bd.padding_dims[0] / 16, bd.strides[0][0]); P(0, 16, 2); P(1, bd.padding_dims[1] / 16, bd.strides[0][1]); P(1, 8, 16*2); P(1, 2, 1); P(2, bd.padding_dims[2], bd.strides[0][2]); if (utils::one_of(md.format(), OIhw8i16o2i, IOhw8i16o2i) || md.format() == OIdhw8i16o2i) P(3, bd.padding_dims[3], bd.strides[0][3]); if (md.format() == OIdhw8i16o2i) P(4, bd.padding_dims[4], bd.strides[0][4]); return success; case OIw8o16i2o: case IOw8o16i2o: case OIhw8o16i2o: case IOhw8o16i2o: case OIdhw8o16i2o: case IOdhw8o16i2o: P(0, bd.padding_dims[0] / 16, bd.strides[0][0]); P(0, 8, 16*2); P(0, 2, 1); P(1, bd.padding_dims[1] / 16, bd.strides[0][1]); P(1, 16, 2); P(2, bd.padding_dims[2], bd.strides[0][2]); if (utils::one_of(md.format(), OIhw8o16i2o, IOhw8o16i2o, OIdhw8o16i2o, IOdhw8o16i2o)) P(3, bd.padding_dims[3], bd.strides[0][3]); if (utils::one_of(md.format(), OIdhw8o16i2o, IOdhw8o16i2o)) P(4, bd.padding_dims[4], bd.strides[0][4]); return success; case gOIhw2i8o4i: P(0, bd.padding_dims[0], bd.strides[0][0]); P(1, bd.padding_dims[1] / 8, bd.strides[0][1]); P(1, 8, 4); P(2, bd.padding_dims[2] / 8, bd.strides[0][2]); P(2, 2, 8*4); P(2, 4, 1); P(3, bd.padding_dims[3], bd.strides[0][3]); P(4, bd.padding_dims[4], bd.strides[0][4]); return success; case gOIw4i16o4i: case gOIhw4i16o4i: P(0, bd.padding_dims[0], bd.strides[0][0]); P(1, bd.padding_dims[1] / 16, bd.strides[0][1]); P(1, 16, 4); P(2, bd.padding_dims[2] / 16, bd.strides[0][2]); P(2, 4, 16*4); P(2, 4, 1); P(3, bd.padding_dims[3], bd.strides[0][3]); if (md.format() == gOIhw4i16o4i) P(4, bd.padding_dims[4], bd.strides[0][4]); return success; case gOIw8i16o2i: case gOIhw8i16o2i: case gIOhw8i16o2i: case gOIdhw8i16o2i: P(0, bd.padding_dims[0], bd.strides[0][0]); P(1, bd.padding_dims[1] / 16, bd.strides[0][1]); P(1, 16, 2); P(2, bd.padding_dims[2] / 16, bd.strides[0][2]); P(2, 8, 16*2); P(2, 2, 1); P(3, bd.padding_dims[3], bd.strides[0][3]); if (utils::one_of(md.format(), gOIhw8i16o2i, gIOhw8i16o2i) || md.format() == gOIdhw8i16o2i) P(4, bd.padding_dims[4], bd.strides[0][4]); if (md.format() == gOIdhw8i16o2i) P(5, bd.padding_dims[5], bd.strides[0][5]); return success; case gOIw8o16i2o: case gIOw8o16i2o: case gOIhw8o16i2o: case gIOhw8o16i2o: case gOIdhw8o16i2o: case gIOdhw8o16i2o: P(0, bd.padding_dims[0], bd.strides[0][0]); P(1, bd.padding_dims[1] / 16, bd.strides[0][1]); P(1, 8, 16*2); P(1, 2, 1); P(2, bd.padding_dims[2] / 16, bd.strides[0][2]); P(2, 16, 2); P(3, bd.padding_dims[3], bd.strides[0][3]); if (utils::one_of(md.format(), gOIhw8o16i2o, gIOhw8o16i2o, gOIdhw8o16i2o, gIOdhw8o16i2o)) P(4, bd.padding_dims[4], bd.strides[0][4]); if (utils::one_of(md.format(), gOIdhw8o16i2o, gIOdhw8o16i2o)) P(5, bd.padding_dims[5], bd.strides[0][5]); return success; default: break; } /* regular blocked format */ for (int d = 0; d < md.ndims(); ++d) { P(d, bd.padding_dims[d] / bd.block_dims[d], bd.strides[0][d]); if (bd.block_dims[d] != 1) P(d, bd.block_dims[d], bd.strides[1][d]); } return success; } status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, const primitive_attr_t *attr) { auto im_d = memory_desc_wrapper(imd); auto om_d = memory_desc_wrapper(omd); bool ok = true && im_d.is_blocking_desc() && om_d.is_blocking_desc() && !im_d.has_zero_dim() && !om_d.has_zero_dim(); if (!ok) return unimplemented; /* padding_dim consistency check */ for (int d = 0; d < im_d.ndims(); ++d) { const auto pdim = im_d.blocking_desc().padding_dims[d]; bool ok = true && pdim == om_d.blocking_desc().padding_dims[d] && pdim % im_d.blocking_desc().block_dims[d] == 0 && pdim % om_d.blocking_desc().block_dims[d] == 0; if (!ok) return unimplemented; } layout_desc_t ild, old; status_t status = cvt_mem_desc_to_layout_desc(imd, ild); if (status != success) return status; status = cvt_mem_desc_to_layout_desc(omd, old); if (status != success) return status; p.itype = ild.dt; p.otype = old.dt; p.scale_type = attr->output_scales_.has_default_values() ? scale_type_t::NONE : (attr->output_scales_.mask_ == 0 ? scale_type_t::COMMON : scale_type_t::MANY); ptrdiff_t ss[max_ndims] = {0}; if (p.scale_type == scale_type_t::MANY) { ptrdiff_t last_ss = 1; for (int d = old.ndims - 1; d >=0; --d) { assert((d == 0 || old.id[d - 1] <= old.id[d]) && "logical dimensions should be in ascending order"); if (attr->output_scales_.mask_ & (1 << old.id[d])) { ss[d] = last_ss; last_ss *= old.dims[d]; } } } int ndims = 0; int i_pos = 0; /* state for input -- current dimension */ int o_pos = 0; /* state for output -- current dimension */ while (i_pos < ild.ndims && o_pos < old.ndims) { assert(ild.id[i_pos] == old.id[o_pos]); if (ild.id[i_pos] != old.id[o_pos]) return runtime_error; assert(ndims < max_ndims); if (ndims == max_ndims) return runtime_error; if (ild.dims[i_pos] == old.dims[o_pos]) { p.nodes[ndims].n = ild.dims[i_pos]; p.nodes[ndims].is = ild.strides[i_pos]; p.nodes[ndims].os = old.strides[o_pos]; p.nodes[ndims].ss = ss[o_pos]; ++ndims; ++i_pos; ++o_pos; } else if (ild.dims[i_pos] < old.dims[o_pos]) { assert(old.dims[o_pos] % ild.dims[i_pos] == 0); int factor = old.dims[o_pos] / ild.dims[i_pos]; p.nodes[ndims].n = ild.dims[i_pos]; p.nodes[ndims].is = ild.strides[i_pos]; p.nodes[ndims].os = old.strides[o_pos] * factor; p.nodes[ndims].ss = ss[o_pos] * factor; ++ndims; ++i_pos; old.dims[o_pos] = factor; } else if (ild.dims[i_pos] > old.dims[o_pos]) { assert(ild.dims[i_pos] % old.dims[o_pos] == 0); int factor = ild.dims[i_pos] / old.dims[o_pos]; p.nodes[ndims].n = old.dims[o_pos]; p.nodes[ndims].is = ild.strides[i_pos] * factor; p.nodes[ndims].os = old.strides[o_pos]; p.nodes[ndims].ss = ss[o_pos]; ++ndims; ++o_pos; ild.dims[i_pos] = factor; } } p.ndims = ndims; dims_t zero_pos = {0}; p.ioff = memory_desc_wrapper(imd).off_v(zero_pos); p.ooff = memory_desc_wrapper(omd).off_v(zero_pos); const int sum_idx = attr->post_ops_.find(primitive_kind::sum); p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale; return success; } void prb_normalize(prb_t &p) { for (int d = 0; d < p.ndims; ++d) { int min_pos = d; for (int j = d + 1; j < p.ndims; ++j) { bool new_min = false || p.nodes[j].os < p.nodes[min_pos].os || (true && p.nodes[j].os == p.nodes[min_pos].os && p.nodes[j].n < p.nodes[min_pos].n); if (new_min) min_pos = j; } if (min_pos != d) nstl::swap(p.nodes[d], p.nodes[min_pos]); } } void prb_simplify(prb_t &p) { #if defined(__GNUC__) && __GNUC__ >= 4 /* GCC produces bogus array subscript is above array bounds warning for * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Warray-bounds" #endif for (int d = 0; d < p.ndims - 1; ++d) { auto &this_node = p.nodes[d + 0]; auto &next_node = p.nodes[d + 1]; const bool fold = false || next_node.n == (size_t)1 // trivial case, just drop next node || (true // or real folding if possible && next_node.is == (ptrdiff_t)this_node.n * this_node.is && next_node.os == (ptrdiff_t)this_node.n * this_node.os && next_node.ss == (ptrdiff_t)this_node.n * this_node.ss); if (fold) { this_node.n *= next_node.n; for (int j = d + 2; j < p.ndims; ++j) p.nodes[j - 1] = p.nodes[j]; --p.ndims; --d; // make another try } } #if defined(__GNUC__) && __GNUC__ >= 4 #pragma GCC diagnostic pop #endif } void prb_node_split(prb_t &p, int dim, size_t n1) { assert(dim < p.ndims); assert(p.ndims < max_ndims); assert(p.nodes[dim].n % n1 == 0); p.ndims += 1; for (int d = p.ndims; d > dim + 1; --d) p.nodes[d] = p.nodes[d - 1]; p.nodes[dim + 1].n = p.nodes[dim].n / n1; p.nodes[dim + 1].is = p.nodes[dim].is * n1; p.nodes[dim + 1].os = p.nodes[dim].os * n1; p.nodes[dim + 1].ss = p.nodes[dim].ss * n1; p.nodes[dim].n = n1; } void prb_node_swap(prb_t &p, int d0, int d1) { assert(d0 < p.ndims); assert(d1 < p.ndims); assert(p.ndims < max_ndims); if (d0 == d1) return; nstl::swap(p.nodes[d0], p.nodes[d1]); } void prb_node_move(prb_t &p, int d0, int d1) { assert(d0 < p.ndims); assert(d1 < p.ndims); assert(p.ndims < max_ndims); if (d0 == d1) return; node_t node = p.nodes[d0]; if (d0 < d1) for (int d = d0; d < d1; ++d) p.nodes[d] = p.nodes[d + 1]; else for (int d = d0; d > d1; --d) p.nodes[d] = p.nodes[d - 1]; p.nodes[d1] = node; } void prb_dump(const prb_t &p) { printf("@@@ type:%s:%s ndims:%d ", mkldnn_dt2str(p.itype), mkldnn_dt2str(p.otype), p.ndims); for (int d = 0; d < p.ndims; ++d) printf("[%zu:%td:%td:%td]", p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, p.nodes[d].ss); printf(" off:%zu:%zu\n", p.ioff, p.ooff); } } } } }