1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Copyright (C) 2016, Intel, Inc., all rights reserved.
// Third party copyrights are property of their respective owners.
// Compile-time element count of the __local scratch storage used by tmm, and
// the width at which its tree reduction starts (first stride is half of this).
// NOTE(review): this implies tmm must be enqueued with a local size of 64 in
// dimension 0 - confirm against the host-side launch code.
#define LOCAL_SIZE_X 64
// Number of consecutive float4 vectors each work-item reads from a row of the
// input per iteration of tmm's main loop.
#define BLOCK_SIZE_X 3
/*
 * tmm: scaled products of the rows of A with each other.
 *
 * For every pair of rows (matI, matJ) with matI >= matJ the kernel computes
 *     alpha * sum_k A[matI * n + k] * A[matJ * n + k]
 * and stores it to D[matI * m + matJ] and, for off-diagonal pairs, also to the
 * mirrored position D[matJ * m + matI].
 *
 * One work-group handles one (matI, matJ) pair: group id 1 selects matI and
 * group id 0 selects matJ. Inside the group each work-item multiplies
 * BLOCK_SIZE_X float4 vectors per loop iteration, the per-work-item partial
 * sums are tree-reduced through local memory, and work-item 0 accumulates the
 * reduced value and writes the final result.
 *
 * Parameters:
 *   A      input; row r is read starting at A[r * n].
 *   m      row stride of D, in floats.
 *   n      row length of A, in floats. Rows are consumed as n / 4 float4
 *          vectors, so any n % 4 trailing elements are not read.
 *   alpha  scale factor applied to every result.
 *   D      output, written at [matI * m + matJ] and its mirror.
 *
 * Launch assumptions (not checked by the kernel):
 *   - the local size in dimension 0 equals LOCAL_SIZE_X: b[] has that many
 *     slots and the reduction always starts at LOCAL_SIZE_X / 2;
 *   - every row start is suitably aligned for a float4 load, because rows are
 *     read through a float4 pointer cast.
 */
__kernel void tmm(__global float *A, int m, int n, float alpha, __global float *D)
{
    int lidX = get_local_id(0);
    uint lsizeX = get_local_size(0);
    uint matI = get_group_id(1);
    uint matJ = get_group_id(0);

    // Only the lower triangle (diagonal included) is computed; the mirrored
    // entry is written at the end. matI/matJ are per-group values, so the
    // whole work-group leaves together and no work-item is left waiting at a
    // barrier below.
    if (matI < matJ)
        return;

    __local float4 b[LOCAL_SIZE_X];   // one partial sum per work-item
    __local uint cnt;                 // number of completed loop iterations
    const int row_vec4 = n / 4;       // float4 vectors per row
    float4 result = (float4)(0.0f);   // running total; meaningful in work-item 0 only

    cnt = 0;
    barrier(CLK_LOCAL_MEM_FENCE);

    do {
        // Index (in float4 units) of the first vector handled by this
        // work-item in the current iteration.
        int global_block_base = (lidX + cnt * lsizeX) * BLOCK_SIZE_X;
        float4 pa[BLOCK_SIZE_X], pb[BLOCK_SIZE_X];

        // Load this work-item's slice of both rows into private memory.
        // Each slot is bounds-checked on its own: a slot past the end of the
        // row is never read from global memory, and both operands are set to
        // zero so that it contributes exactly 0 to the sum.
        #pragma unroll
        for (int j = 0; j < BLOCK_SIZE_X; j++) {
            if (global_block_base + j < row_vec4) {
                pa[j] = *(__global float4*)&A[matI * n + (global_block_base + j) * 4];
                if (matI != matJ)
                    pb[j] = *(__global float4*)&A[matJ * n + (global_block_base + j) * 4];
                else
                    pb[j] = pa[j];   // diagonal: reuse the row instead of loading it twice
            } else {
                pa[j] = (float4)(0.0f);
                pb[j] = (float4)(0.0f);
            }
        }

        // Component-wise multiply-accumulate over the slice.
        float4 partial = pb[0] * pa[0];
        for (int j = 1; j < BLOCK_SIZE_X; j++)
            partial = fma(pb[j], pa[j], partial);

        b[lidX] = partial;
        barrier(CLK_LOCAL_MEM_FENCE);

        // Tree reduction of the per-work-item partial sums into b[0].
        for (int offset = LOCAL_SIZE_X / 2; offset > 0; offset >>= 1) {
            if (lidX < offset)
                b[lidX] += b[lidX + offset];
            barrier(CLK_LOCAL_MEM_FENCE);
        }

        if (lidX == 0) {
            result += b[0];
            cnt++;
        }
        // Publish the updated cnt to the whole group before it is read by the
        // loop condition and by the next iteration.
        barrier(CLK_LOCAL_MEM_FENCE);
    } while (cnt * BLOCK_SIZE_X * lsizeX < row_vec4);

    if (lidX == 0) {
        // Fold the four vector lanes into the scalar result and scale it.
        float ret = (result.s0 + result.s1 + result.s2 + result.s3) * alpha;
        D[matI * m + matJ] = ret;
        if (matI != matJ)
            D[matJ * m + matI] = ret;
    }
}