1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#include "op_blas.hpp"
#ifdef HAVE_LAPACK
#include "opencv_lapack.h"
#endif
#include <iostream>
namespace cv
{
namespace dnn
{
void gemm(InputArray A, InputArray B, double alpha, InputOutputArray C, double beta, int flags)
{
if (C.isMat())
gemmCPU(A.getMat(), B.getMat(), alpha, C.getMatRef(), beta, flags);
else
{
cv::gemm(A, B, alpha, (beta == 0) ? noArray() : C, beta, C, flags);
}
}
inline void SwapRowCols(const Mat &A, int &rows, int &cols, bool isTrans)
{
CV_DbgAssert(A.dims == 2);
rows = (isTrans) ? A.cols : A.rows;
cols = (isTrans) ? A.rows : A.cols;
}
class GEMMInvoker : public ParallelLoopBody
{
public:
GEMMInvoker(const Mat* _a, const Mat* _b, double _alpha, Mat* _c, double _beta)
{
a = _a;
b = _b;
c = _c;
alpha = _alpha;
beta = _beta;
}
void operator()(const Range& range) const
{
int mmax = a->rows;
int nmax = range.end - range.start;
int kmax = a->cols;
int m, n, k;
AutoBuffer<float> buf(nmax);
float* ptr = buf;
if( mmax %2 != 0 )
memset(ptr, 0, nmax*sizeof(ptr[0]));
for( m = 0; m < mmax; m += 2 )
{
float* dst0 = c->ptr<float>(m) + range.start;
float* dst1 = m+1 < mmax ? c->ptr<float>(m+1) + range.start : ptr;
const float* aptr0 = a->ptr<float>(m);
const float* aptr1 = m+1 < mmax ? a->ptr<float>(m+1) : aptr0;
if( beta != 1 )
{
if( beta == 0 )
for( n = 0; n < nmax; n++ )
{
dst0[n] = 0.f;
dst1[n] = 0.f;
}
else
for( n = 0; n < nmax; n++ )
{
dst0[n] *= (float)beta;
dst1[n] *= (float)beta;
}
}
for( k = 0; k < kmax; k++ )
{
float alpha0 = (float)(alpha*aptr0[k]);
float alpha1 = (float)(alpha*aptr1[k]);
const float* bptr = b->ptr<float>(k) + range.start;
for( n = 0; n < nmax; n++ )
{
float d0 = dst0[n] + alpha0*bptr[n];
float d1 = dst1[n] + alpha1*bptr[n];
dst0[n] = d0;
dst1[n] = d1;
}
}
}
}
const Mat *a, *b;
Mat* c;
double alpha, beta;
};
void gemmCPU(const Mat &A, const Mat &B, double alpha, Mat &C, double beta, int flags /*= 0*/)
{
#ifdef HAVE_LAPACK
bool transA = static_cast<bool>(flags & GEMM_1_T);
bool transB = static_cast<bool>(flags & GEMM_2_T);
bool transC = static_cast<bool>(flags & GEMM_3_T);
int Arows, Acols, Brows, Bcols, Crows, Ccols;
SwapRowCols(A, Arows, Acols, transA);
SwapRowCols(B, Brows, Bcols, transB);
SwapRowCols(C, Crows, Ccols, transC);
CV_Assert(!(flags & GEMM_3_T));
CV_Assert(Acols == Brows && Arows == Crows && Bcols == Ccols);
CV_Assert(A.isContinuous() && B.isContinuous() && C.isContinuous());
CV_Assert(A.type() == B.type() && B.type() == C.type());
CV_Assert(A.data != C.data && B.data != C.data);
if (C.type() == CV_32F)
{
cblas_sgemm(CblasRowMajor, transA ? CblasTrans : CblasNoTrans, transB ? CblasTrans : CblasNoTrans,
Arows, Bcols, Acols,
(float)alpha, A.ptr<float>(), A.cols,
B.ptr<float>(), B.cols,
(float)beta, C.ptr<float>(), C.cols);
}
else if (C.type() == CV_64F)
{
//TODO: Should be tested
cblas_dgemm(CblasRowMajor, transA ? CblasTrans : CblasNoTrans, transB ? CblasTrans : CblasNoTrans,
Arows, Bcols, Acols,
alpha, A.ptr<double>(), A.cols,
B.ptr<double>(), B.cols,
beta, C.ptr<double>(), C.cols);
}
else
{
CV_Error(Error::BadDepth, "Only floating point types are supported");
}
#else
if( C.type() == CV_32F && flags == 0 )
{
GEMMInvoker invoker(&A, &B, alpha, &C, beta);
double granularity = 10000000./((double)A.rows*A.cols);
parallel_for_(Range(0, B.cols), invoker, granularity);
}
else
cv::gemm(A, B, alpha, C, beta, C, flags);
#endif
}
int getBlasThreads()
{
#ifdef OPENBLAS_VERSION
return openblas_get_num_threads();
#else
return 1;
#endif
}
void setBlasThreads(int numThreads)
{
#ifdef OPENBLAS_VERSION
openblas_set_num_threads(numThreads);
goto_set_num_threads(numThreads);
#else
(void)numThreads; //suppress compilers' warning
#endif
}
}
}