Commit a65b5df5 authored by Alexander Alekhin's avatar Alexander Alekhin

Merge pull request #10416 from fenrus75:avx512

parents 2370c8a0 898ca382
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SSE / SSE2 (always available on 64-bit CPUs) # SSE / SSE2 (always available on 64-bit CPUs)
# SSE3 / SSSE3 # SSE3 / SSSE3
# SSE4_1 / SSE4_2 / POPCNT # SSE4_1 / SSE4_2 / POPCNT
# AVX / AVX2 / AVX512 # AVX / AVX2 / AVX_512F
# FMA3 # FMA3
# CPU_{opt}_SUPPORTED=ON/OFF - compiler support (possibly with additional flag) # CPU_{opt}_SUPPORTED=ON/OFF - compiler support (possibly with additional flag)
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
# #
# CPU_DISPATCH_FLAGS_${opt} - flags for source files compiled separately (<name>.avx2.cpp) # CPU_DISPATCH_FLAGS_${opt} - flags for source files compiled separately (<name>.avx2.cpp)
set(CPU_ALL_OPTIMIZATIONS "SSE;SSE2;SSE3;SSSE3;SSE4_1;SSE4_2;POPCNT;AVX;FP16;AVX2;FMA3") # without AVX512 set(CPU_ALL_OPTIMIZATIONS "SSE;SSE2;SSE3;SSSE3;SSE4_1;SSE4_2;POPCNT;AVX;FP16;AVX2;FMA3;AVX_512F")
list(APPEND CPU_ALL_OPTIMIZATIONS NEON VFPV3 FP16) list(APPEND CPU_ALL_OPTIMIZATIONS NEON VFPV3 FP16)
list(APPEND CPU_ALL_OPTIMIZATIONS VSX) list(APPEND CPU_ALL_OPTIMIZATIONS VSX)
list(REMOVE_DUPLICATES CPU_ALL_OPTIMIZATIONS) list(REMOVE_DUPLICATES CPU_ALL_OPTIMIZATIONS)
...@@ -145,7 +145,7 @@ elseif(" ${CMAKE_CXX_FLAGS} " MATCHES " -march=native | -xHost | /QxHost ") ...@@ -145,7 +145,7 @@ elseif(" ${CMAKE_CXX_FLAGS} " MATCHES " -march=native | -xHost | /QxHost ")
endif() endif()
if(X86 OR X86_64) if(X86 OR X86_64)
ocv_update(CPU_KNOWN_OPTIMIZATIONS "SSE;SSE2;SSE3;SSSE3;SSE4_1;POPCNT;SSE4_2;FP16;FMA3;AVX;AVX2") # without AVX512 ocv_update(CPU_KNOWN_OPTIMIZATIONS "SSE;SSE2;SSE3;SSSE3;SSE4_1;POPCNT;SSE4_2;FP16;FMA3;AVX;AVX2;AVX_512F")
ocv_update(CPU_SSE_TEST_FILE "${OpenCV_SOURCE_DIR}/cmake/checks/cpu_sse.cpp") ocv_update(CPU_SSE_TEST_FILE "${OpenCV_SOURCE_DIR}/cmake/checks/cpu_sse.cpp")
ocv_update(CPU_SSE2_TEST_FILE "${OpenCV_SOURCE_DIR}/cmake/checks/cpu_sse2.cpp") ocv_update(CPU_SSE2_TEST_FILE "${OpenCV_SOURCE_DIR}/cmake/checks/cpu_sse2.cpp")
...@@ -157,11 +157,11 @@ if(X86 OR X86_64) ...@@ -157,11 +157,11 @@ if(X86 OR X86_64)
ocv_update(CPU_AVX_TEST_FILE "${OpenCV_SOURCE_DIR}/cmake/checks/cpu_avx.cpp") ocv_update(CPU_AVX_TEST_FILE "${OpenCV_SOURCE_DIR}/cmake/checks/cpu_avx.cpp")
ocv_update(CPU_AVX2_TEST_FILE "${OpenCV_SOURCE_DIR}/cmake/checks/cpu_avx2.cpp") ocv_update(CPU_AVX2_TEST_FILE "${OpenCV_SOURCE_DIR}/cmake/checks/cpu_avx2.cpp")
ocv_update(CPU_FP16_TEST_FILE "${OpenCV_SOURCE_DIR}/cmake/checks/cpu_fp16.cpp") ocv_update(CPU_FP16_TEST_FILE "${OpenCV_SOURCE_DIR}/cmake/checks/cpu_fp16.cpp")
ocv_update(CPU_AVX512_TEST_FILE "${OpenCV_SOURCE_DIR}/cmake/checks/cpu_avx512.cpp") ocv_update(CPU_AVX_512F_TEST_FILE "${OpenCV_SOURCE_DIR}/cmake/checks/cpu_avx512.cpp")
if(NOT OPENCV_CPU_OPT_IMPLIES_IGNORE) if(NOT OPENCV_CPU_OPT_IMPLIES_IGNORE)
ocv_update(CPU_AVX512_IMPLIES "AVX2") ocv_update(CPU_AVX_512F_IMPLIES "AVX2")
ocv_update(CPU_AVX512_FORCE "") # Don't force other optimizations ocv_update(CPU_AVX_512F_FORCE "") # Don't force other optimizations
ocv_update(CPU_AVX2_IMPLIES "AVX;FMA3;FP16") ocv_update(CPU_AVX2_IMPLIES "AVX;FMA3;FP16")
ocv_update(CPU_FMA3_IMPLIES "AVX2") ocv_update(CPU_FMA3_IMPLIES "AVX2")
ocv_update(CPU_FMA3_FORCE "") # Don't force other optimizations ocv_update(CPU_FMA3_FORCE "") # Don't force other optimizations
...@@ -205,7 +205,7 @@ if(X86 OR X86_64) ...@@ -205,7 +205,7 @@ if(X86 OR X86_64)
if(NOT X86_64) # x64 compiler doesn't support /arch:sse if(NOT X86_64) # x64 compiler doesn't support /arch:sse
ocv_intel_compiler_optimization_option(SSE "-msse" "/arch:SSE") ocv_intel_compiler_optimization_option(SSE "-msse" "/arch:SSE")
endif() endif()
#ocv_intel_compiler_optimization_option(AVX512 "-march=core-avx512") ocv_intel_compiler_optimization_option(AVX_512F "-march=common-avx512" "/arch:COMMON-AVX512")
elseif(CMAKE_COMPILER_IS_GNUCXX) elseif(CMAKE_COMPILER_IS_GNUCXX)
ocv_update(CPU_AVX2_FLAGS_ON "-mavx2") ocv_update(CPU_AVX2_FLAGS_ON "-mavx2")
ocv_update(CPU_FP16_FLAGS_ON "-mf16c") ocv_update(CPU_FP16_FLAGS_ON "-mf16c")
...@@ -219,7 +219,8 @@ if(X86 OR X86_64) ...@@ -219,7 +219,8 @@ if(X86 OR X86_64)
ocv_update(CPU_SSE2_FLAGS_ON "-msse2") ocv_update(CPU_SSE2_FLAGS_ON "-msse2")
ocv_update(CPU_SSE_FLAGS_ON "-msse") ocv_update(CPU_SSE_FLAGS_ON "-msse")
if(NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS "5.0") if(NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS "5.0")
ocv_update(CPU_AVX512_FLAGS_ON "-mavx512f -mavx512pf -mavx512er -mavx512cd -mavx512vl -mavx512bw -mavx512dq -mavx512ifma -mavx512vbmi") # -mavx512f -mavx512pf -mavx512er -mavx512cd -mavx512vl -mavx512bw -mavx512dq -mavx512ifma -mavx512vbmi
ocv_update(CPU_AVX_512F_FLAGS_ON "-mavx512f")
endif() endif()
elseif(MSVC) elseif(MSVC)
ocv_update(CPU_AVX2_FLAGS_ON "/arch:AVX2") ocv_update(CPU_AVX2_FLAGS_ON "/arch:AVX2")
......
...@@ -82,6 +82,10 @@ ...@@ -82,6 +82,10 @@
# include <immintrin.h> # include <immintrin.h>
# define CV_AVX2 1 # define CV_AVX2 1
#endif #endif
#ifdef CV_CPU_COMPILE_AVX_512F
# include <immintrin.h>
# define CV_AVX_512F 1
#endif
#ifdef CV_CPU_COMPILE_FMA3 #ifdef CV_CPU_COMPILE_FMA3
# define CV_FMA3 1 # define CV_FMA3 1
#endif #endif
......
...@@ -165,6 +165,21 @@ ...@@ -165,6 +165,21 @@
#endif #endif
#define __CV_CPU_DISPATCH_CHAIN_FMA3(fn, args, mode, ...) CV_CPU_CALL_FMA3(fn, args); __CV_EXPAND(__CV_CPU_DISPATCH_CHAIN_ ## mode(fn, args, __VA_ARGS__)) #define __CV_CPU_DISPATCH_CHAIN_FMA3(fn, args, mode, ...) CV_CPU_CALL_FMA3(fn, args); __CV_EXPAND(__CV_CPU_DISPATCH_CHAIN_ ## mode(fn, args, __VA_ARGS__))
#if !defined CV_DISABLE_OPTIMIZATION && defined CV_ENABLE_INTRINSICS && defined CV_CPU_COMPILE_AVX_512F
# define CV_TRY_AVX_512F 1
# define CV_CPU_HAS_SUPPORT_AVX_512F 1
# define CV_CPU_CALL_AVX_512F(fn, args) return (opt_AVX_512F::fn args)
#elif !defined CV_DISABLE_OPTIMIZATION && defined CV_ENABLE_INTRINSICS && defined CV_CPU_DISPATCH_COMPILE_AVX_512F
# define CV_TRY_AVX_512F 1
# define CV_CPU_HAS_SUPPORT_AVX_512F (cv::checkHardwareSupport(CV_CPU_AVX_512F))
# define CV_CPU_CALL_AVX_512F(fn, args) if (CV_CPU_HAS_SUPPORT_AVX_512F) return (opt_AVX_512F::fn args)
#else
# define CV_TRY_AVX_512F 0
# define CV_CPU_HAS_SUPPORT_AVX_512F 0
# define CV_CPU_CALL_AVX_512F(fn, args)
#endif
#define __CV_CPU_DISPATCH_CHAIN_AVX_512F(fn, args, mode, ...) CV_CPU_CALL_AVX_512F(fn, args); __CV_EXPAND(__CV_CPU_DISPATCH_CHAIN_ ## mode(fn, args, __VA_ARGS__))
#if !defined CV_DISABLE_OPTIMIZATION && defined CV_ENABLE_INTRINSICS && defined CV_CPU_COMPILE_NEON #if !defined CV_DISABLE_OPTIMIZATION && defined CV_ENABLE_INTRINSICS && defined CV_CPU_COMPILE_NEON
# define CV_TRY_NEON 1 # define CV_TRY_NEON 1
# define CV_CPU_HAS_SUPPORT_NEON 1 # define CV_CPU_HAS_SUPPORT_NEON 1
......
...@@ -13,7 +13,7 @@ endif() ...@@ -13,7 +13,7 @@ endif()
set(the_description "Deep neural network module. It allows to load models from different frameworks and to make forward pass") set(the_description "Deep neural network module. It allows to load models from different frameworks and to make forward pass")
ocv_add_dispatched_file("layers/layers_common" AVX AVX2) ocv_add_dispatched_file("layers/layers_common" AVX AVX2 AVX_512F)
ocv_add_module(dnn opencv_core opencv_imgproc WRAP python matlab java js) ocv_add_module(dnn opencv_core opencv_imgproc WRAP python matlab java js)
ocv_warnings_disable(CMAKE_CXX_FLAGS -Wno-shadow -Wno-parentheses -Wmaybe-uninitialized -Wsign-promo ocv_warnings_disable(CMAKE_CXX_FLAGS -Wno-shadow -Wno-parentheses -Wmaybe-uninitialized -Wsign-promo
......
...@@ -345,10 +345,11 @@ public: ...@@ -345,10 +345,11 @@ public:
bool is1x1_; bool is1x1_;
bool useAVX; bool useAVX;
bool useAVX2; bool useAVX2;
bool useAVX512;
ParallelConv() ParallelConv()
: input_(0), weights_(0), output_(0), ngroups_(0), nstripes_(0), : input_(0), weights_(0), output_(0), ngroups_(0), nstripes_(0),
biasvec_(0), reluslope_(0), activ_(0), is1x1_(false), useAVX(false), useAVX2(false) biasvec_(0), reluslope_(0), activ_(0), is1x1_(false), useAVX(false), useAVX2(false), useAVX512(false)
{} {}
static void run( const Mat& input, Mat& output, const Mat& weights, static void run( const Mat& input, Mat& output, const Mat& weights,
...@@ -383,6 +384,7 @@ public: ...@@ -383,6 +384,7 @@ public:
p.is1x1_ = kernel == Size(0,0) && pad == Size(0, 0); p.is1x1_ = kernel == Size(0,0) && pad == Size(0, 0);
p.useAVX = checkHardwareSupport(CPU_AVX); p.useAVX = checkHardwareSupport(CPU_AVX);
p.useAVX2 = checkHardwareSupport(CPU_AVX2); p.useAVX2 = checkHardwareSupport(CPU_AVX2);
p.useAVX512 = CV_CPU_HAS_SUPPORT_AVX_512F;
int ncn = std::min(inpCn, (int)BLK_SIZE_CN); int ncn = std::min(inpCn, (int)BLK_SIZE_CN);
p.ofstab_.resize(kernel.width*kernel.height*ncn); p.ofstab_.resize(kernel.width*kernel.height*ncn);
...@@ -562,6 +564,13 @@ public: ...@@ -562,6 +564,13 @@ public:
// now compute dot product of the weights // now compute dot product of the weights
// and im2row-transformed part of the tensor // and im2row-transformed part of the tensor
int bsz = ofs1 - ofs0; int bsz = ofs1 - ofs0;
#if CV_TRY_AVX_512F
/* AVX512 convolution requires an alignment of 16, and ROI is only there for larger vector sizes */
if(useAVX512)
opt_AVX_512F::fastConv(wptr, wstep, biasptr, rowbuf0, data_out0 + ofs0,
outShape, bsz, vsz, vsz_a, relu, cn0 == 0);
else
#endif
#if CV_TRY_AVX2 #if CV_TRY_AVX2
if(useAVX2) if(useAVX2)
opt_AVX2::fastConv(wptr, wstep, biasptr, rowbuf0, data_out0 + ofs0, opt_AVX2::fastConv(wptr, wstep, biasptr, rowbuf0, data_out0 + ofs0,
...@@ -1093,6 +1102,7 @@ public: ...@@ -1093,6 +1102,7 @@ public:
nstripes_ = nstripes; nstripes_ = nstripes;
useAVX = checkHardwareSupport(CPU_AVX); useAVX = checkHardwareSupport(CPU_AVX);
useAVX2 = checkHardwareSupport(CPU_AVX2); useAVX2 = checkHardwareSupport(CPU_AVX2);
useAVX512 = CV_CPU_HAS_SUPPORT_AVX_512F;
} }
void operator()(const Range& range_) const void operator()(const Range& range_) const
...@@ -1110,6 +1120,11 @@ public: ...@@ -1110,6 +1120,11 @@ public:
size_t bstep = b_->step1(); size_t bstep = b_->step1();
size_t cstep = c_->step1(); size_t cstep = c_->step1();
#if CV_TRY_AVX_512F
if( useAVX512 )
opt_AVX_512F::fastGEMM( aptr, astep, bptr, bstep, cptr, cstep, mmax, kmax, nmax );
else
#endif
#if CV_TRY_AVX2 #if CV_TRY_AVX2
if( useAVX2 ) if( useAVX2 )
opt_AVX2::fastGEMM( aptr, astep, bptr, bstep, cptr, cstep, mmax, kmax, nmax ); opt_AVX2::fastGEMM( aptr, astep, bptr, bstep, cptr, cstep, mmax, kmax, nmax );
...@@ -1214,6 +1229,7 @@ public: ...@@ -1214,6 +1229,7 @@ public:
int nstripes_; int nstripes_;
bool useAVX; bool useAVX;
bool useAVX2; bool useAVX2;
bool useAVX512;
}; };
class Col2ImInvoker : public cv::ParallelLoopBody class Col2ImInvoker : public cv::ParallelLoopBody
......
...@@ -139,7 +139,7 @@ public: ...@@ -139,7 +139,7 @@ public:
class FullyConnected : public ParallelLoopBody class FullyConnected : public ParallelLoopBody
{ {
public: public:
FullyConnected() : srcMat(0), weights(0), biasMat(0), activ(0), dstMat(0), nstripes(0), useAVX(false), useAVX2(false) {} FullyConnected() : srcMat(0), weights(0), biasMat(0), activ(0), dstMat(0), nstripes(0), useAVX(false), useAVX2(false), useAVX512(false) {}
static void run(const Mat& srcMat, const Mat& weights, const Mat& biasMat, static void run(const Mat& srcMat, const Mat& weights, const Mat& biasMat,
Mat& dstMat, const ActivationLayer* activ, int nstripes) Mat& dstMat, const ActivationLayer* activ, int nstripes)
...@@ -161,6 +161,7 @@ public: ...@@ -161,6 +161,7 @@ public:
p.activ = activ; p.activ = activ;
p.useAVX = checkHardwareSupport(CPU_AVX); p.useAVX = checkHardwareSupport(CPU_AVX);
p.useAVX2 = checkHardwareSupport(CPU_AVX2); p.useAVX2 = checkHardwareSupport(CPU_AVX2);
p.useAVX512 = CV_CPU_HAS_SUPPORT_AVX_512F;
parallel_for_(Range(0, nstripes), p, nstripes); parallel_for_(Range(0, nstripes), p, nstripes);
} }
...@@ -195,6 +196,11 @@ public: ...@@ -195,6 +196,11 @@ public:
memcpy(sptr, sptr_, vecsize*sizeof(sptr[0])); memcpy(sptr, sptr_, vecsize*sizeof(sptr[0]));
#if CV_TRY_AVX_512F
if( useAVX512 )
opt_AVX_512F::fastGEMM1T( sptr, wptr, wstep, biasptr, dptr, nw, vecsize);
else
#endif
#if CV_TRY_AVX2 #if CV_TRY_AVX2
if( useAVX2 ) if( useAVX2 )
opt_AVX2::fastGEMM1T( sptr, wptr, wstep, biasptr, dptr, nw, vecsize); opt_AVX2::fastGEMM1T( sptr, wptr, wstep, biasptr, dptr, nw, vecsize);
...@@ -255,6 +261,7 @@ public: ...@@ -255,6 +261,7 @@ public:
int nstripes; int nstripes;
bool useAVX; bool useAVX;
bool useAVX2; bool useAVX2;
bool useAVX512;
}; };
#ifdef HAVE_OPENCL #ifdef HAVE_OPENCL
......
...@@ -72,7 +72,7 @@ void fastConv( const float* weights, size_t wstep, const float* bias, ...@@ -72,7 +72,7 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
int outCn = outShape[1]; int outCn = outShape[1];
size_t outPlaneSize = outShape[2]*outShape[3]; size_t outPlaneSize = outShape[2]*outShape[3];
float r0 = 1.f, r1 = 1.f, r2 = 1.f; float r0 = 1.f, r1 = 1.f, r2 = 1.f;
__m256 vr0 = _mm256_set1_ps(1.f), vr1 = vr0, vr2 = vr0, z = _mm256_setzero_ps(); __m128 vr0 = _mm_set1_ps(1.f), vr1 = vr0, vr2 = vr0, z = _mm_setzero_ps();
// now compute dot product of the weights // now compute dot product of the weights
// and im2row-transformed part of the tensor // and im2row-transformed part of the tensor
...@@ -104,9 +104,9 @@ void fastConv( const float* weights, size_t wstep, const float* bias, ...@@ -104,9 +104,9 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
r0 = relu[i]; r0 = relu[i];
r1 = relu[i+1]; r1 = relu[i+1];
r2 = relu[i+2]; r2 = relu[i+2];
vr0 = _mm256_set1_ps(r0); vr0 = _mm_set1_ps(r0);
vr1 = _mm256_set1_ps(r1); vr1 = _mm_set1_ps(r1);
vr2 = _mm256_set1_ps(r2); vr2 = _mm_set1_ps(r2);
} }
int j = 0; int j = 0;
...@@ -156,38 +156,38 @@ void fastConv( const float* weights, size_t wstep, const float* bias, ...@@ -156,38 +156,38 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
t1 = _mm256_add_ps(t1, _mm256_permute2f128_ps(t1, t1, 1)); t1 = _mm256_add_ps(t1, _mm256_permute2f128_ps(t1, t1, 1));
t2 = _mm256_add_ps(t2, _mm256_permute2f128_ps(t2, t2, 1)); t2 = _mm256_add_ps(t2, _mm256_permute2f128_ps(t2, t2, 1));
__m256 s0, s1, s2; __m128 s0, s1, s2;
if( initOutput ) if( initOutput )
{ {
s0 = _mm256_set1_ps(bias0); s0 = _mm_set1_ps(bias0);
s1 = _mm256_set1_ps(bias1); s1 = _mm_set1_ps(bias1);
s2 = _mm256_set1_ps(bias2); s2 = _mm_set1_ps(bias2);
} }
else else
{ {
s0 = _mm256_castps128_ps256(_mm_loadu_ps(outptr0 + j)); s0 = _mm_loadu_ps(outptr0 + j);
s1 = _mm256_castps128_ps256(_mm_loadu_ps(outptr1 + j)); s1 = _mm_loadu_ps(outptr1 + j);
s2 = _mm256_castps128_ps256(_mm_loadu_ps(outptr2 + j)); s2 = _mm_loadu_ps(outptr2 + j);
} }
s0 = _mm256_add_ps(s0, t0); s0 = _mm_add_ps(s0, _mm256_castps256_ps128(t0));
s1 = _mm256_add_ps(s1, t1); s1 = _mm_add_ps(s1, _mm256_castps256_ps128(t1));
s2 = _mm256_add_ps(s2, t2); s2 = _mm_add_ps(s2, _mm256_castps256_ps128(t2));
if( relu ) if( relu )
{ {
__m256 m0 = _mm256_cmp_ps(s0, z, _CMP_GT_OS); __m128 m0 = _mm_cmp_ps(s0, z, _CMP_GT_OS);
__m256 m1 = _mm256_cmp_ps(s1, z, _CMP_GT_OS); __m128 m1 = _mm_cmp_ps(s1, z, _CMP_GT_OS);
__m256 m2 = _mm256_cmp_ps(s2, z, _CMP_GT_OS); __m128 m2 = _mm_cmp_ps(s2, z, _CMP_GT_OS);
s0 = _mm256_xor_ps(s0, _mm256_andnot_ps(m0, _mm256_xor_ps(_mm256_mul_ps(s0, vr0), s0))); s0 = _mm_xor_ps(s0, _mm_andnot_ps(m0, _mm_xor_ps(_mm_mul_ps(s0, vr0), s0)));
s1 = _mm256_xor_ps(s1, _mm256_andnot_ps(m1, _mm256_xor_ps(_mm256_mul_ps(s1, vr1), s1))); s1 = _mm_xor_ps(s1, _mm_andnot_ps(m1, _mm_xor_ps(_mm_mul_ps(s1, vr1), s1)));
s2 = _mm256_xor_ps(s2, _mm256_andnot_ps(m2, _mm256_xor_ps(_mm256_mul_ps(s2, vr2), s2))); s2 = _mm_xor_ps(s2, _mm_andnot_ps(m2, _mm_xor_ps(_mm_mul_ps(s2, vr2), s2)));
} }
_mm_storeu_ps(outptr0 + j, _mm256_castps256_ps128(s0)); _mm_storeu_ps(outptr0 + j, s0);
_mm_storeu_ps(outptr1 + j, _mm256_castps256_ps128(s1)); _mm_storeu_ps(outptr1 + j, s1);
_mm_storeu_ps(outptr2 + j, _mm256_castps256_ps128(s2)); _mm_storeu_ps(outptr2 + j, s2);
} }
for( ; j < blockSize; j++ ) for( ; j < blockSize; j++ )
...@@ -294,11 +294,63 @@ void fastGEMM1T( const float* vec, const float* weights, ...@@ -294,11 +294,63 @@ void fastGEMM1T( const float* vec, const float* weights,
_mm256_zeroupper(); _mm256_zeroupper();
} }
void fastGEMM( const float* aptr, size_t astep, const float* bptr, void fastGEMM( const float* aptr, size_t astep, const float* bptr,
size_t bstep, float* cptr, size_t cstep, size_t bstep, float* cptr, size_t cstep,
int ma, int na, int nb ) int ma, int na, int nb )
{ {
int n = 0; int n = 0;
#if CV_AVX_512F
for( ; n <= nb - 32; n += 32 )
{
for( int m = 0; m < ma; m += 4 )
{
const float* aptr0 = aptr + astep*m;
const float* aptr1 = aptr + astep*std::min(m+1, ma-1);
const float* aptr2 = aptr + astep*std::min(m+2, ma-1);
const float* aptr3 = aptr + astep*std::min(m+3, ma-1);
float* cptr0 = cptr + cstep*m;
float* cptr1 = cptr + cstep*std::min(m+1, ma-1);
float* cptr2 = cptr + cstep*std::min(m+2, ma-1);
float* cptr3 = cptr + cstep*std::min(m+3, ma-1);
__m512 d00 = _mm512_setzero_ps(), d01 = _mm512_setzero_ps();
__m512 d10 = _mm512_setzero_ps(), d11 = _mm512_setzero_ps();
__m512 d20 = _mm512_setzero_ps(), d21 = _mm512_setzero_ps();
__m512 d30 = _mm512_setzero_ps(), d31 = _mm512_setzero_ps();
for( int k = 0; k < na; k++ )
{
__m512 a0 = _mm512_set1_ps(aptr0[k]);
__m512 a1 = _mm512_set1_ps(aptr1[k]);
__m512 a2 = _mm512_set1_ps(aptr2[k]);
__m512 a3 = _mm512_set1_ps(aptr3[k]);
__m512 b0 = _mm512_loadu_ps(bptr + k*bstep + n);
__m512 b1 = _mm512_loadu_ps(bptr + k*bstep + n + 16);
d00 = _mm512_fmadd_ps(a0, b0, d00);
d01 = _mm512_fmadd_ps(a0, b1, d01);
d10 = _mm512_fmadd_ps(a1, b0, d10);
d11 = _mm512_fmadd_ps(a1, b1, d11);
d20 = _mm512_fmadd_ps(a2, b0, d20);
d21 = _mm512_fmadd_ps(a2, b1, d21);
d30 = _mm512_fmadd_ps(a3, b0, d30);
d31 = _mm512_fmadd_ps(a3, b1, d31);
}
_mm512_storeu_ps(cptr0 + n, d00);
_mm512_storeu_ps(cptr0 + n + 16, d01);
_mm512_storeu_ps(cptr1 + n, d10);
_mm512_storeu_ps(cptr1 + n + 16, d11);
_mm512_storeu_ps(cptr2 + n, d20);
_mm512_storeu_ps(cptr2 + n + 16, d21);
_mm512_storeu_ps(cptr3 + n, d30);
_mm512_storeu_ps(cptr3 + n + 16, d31);
}
}
#endif
for( ; n <= nb - 16; n += 16 ) for( ; n <= nb - 16; n += 16 )
{ {
for( int m = 0; m < ma; m += 4 ) for( int m = 0; m < ma; m += 4 )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment