gemm.cpp 6.87 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2010-2012, Multicoreware, Inc., all rights reserved.
// Copyright (C) 2010-2012, Advanced Micro Devices, Inc., all rights reserved.
// Third party copyrights are property of their respective owners.
//
// @Authors
//    Peng Xiao, pengxiao@multicorewareinc.com
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors as is and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"

48 49 50 51 52 53 54 55 56
namespace cv { namespace ocl {

// used for clAmdBlas library to avoid redundant setup/teardown
void clBlasSetup();
void clBlasTeardown();

}} /* namespace cv { namespace ocl */


57 58 59
#if !defined HAVE_CLAMDBLAS
void cv::ocl::gemm(const oclMat&, const oclMat&, double,
                   const oclMat&, double, oclMat&, int)
Niko's avatar
Niko committed
60
{
61
    CV_Error(CV_OpenCLNoAMDBlasFft, "OpenCL BLAS is not implemented");
Niko's avatar
Niko committed
62
}
63 64 65

void cv::ocl::clBlasSetup()
{
66
    CV_Error(CV_OpenCLNoAMDBlasFft, "OpenCL BLAS is not implemented");
67 68 69 70
}

// Fallback compiled when OpenCV is built without HAVE_CLAMDBLAS.
// Nothing was ever set up, so there is nothing to release. Unlike
// clBlasSetup(), this does not raise, so teardown paths stay safe to call.
void cv::ocl::clBlasTeardown()
{
    //intentionally do nothing
}

#else
75
#include "opencv2/ocl/cl_runtime/clamdblas_runtime.hpp"
76 77
using namespace cv;

// True once clAmdBlasSetup() has completed through clBlasSetup(), and reset by
// clBlasTeardown(). Both functions below write it only while holding
// getInitializationMutex().
static bool clBlasInitialized = false;

void cv::ocl::clBlasSetup()
{
    if(!clBlasInitialized)
    {
84
        AutoLock lock(getInitializationMutex());
peng xiao's avatar
peng xiao committed
85 86 87 88 89
        if(!clBlasInitialized)
        {
            openCLSafeCall(clAmdBlasSetup());
            clBlasInitialized = true;
        }
90 91 92 93 94
    }
}

void cv::ocl::clBlasTeardown()
{
95
    AutoLock lock(getInitializationMutex());
96 97 98 99 100 101 102
    if(clBlasInitialized)
    {
        clAmdBlasTeardown();
        clBlasInitialized = false;
    }
}

103 104 105 106
void cv::ocl::gemm(const oclMat &src1, const oclMat &src2, double alpha,
                   const oclMat &src3, double beta, oclMat &dst, int flags)
{
    CV_Assert(src1.cols == src2.rows &&
107
              (src3.empty() || (src1.rows == src3.rows && src2.cols == src3.cols)));
108 109 110 111 112 113 114 115 116 117
    CV_Assert(!(cv::GEMM_3_T & flags)); // cv::GEMM_3_T is not supported
    if(!src3.empty())
    {
        src3.copyTo(dst);
    }
    else
    {
        dst.create(src1.rows, src2.cols, src1.type());
        dst.setTo(Scalar::all(0));
    }
118 119

    clBlasSetup();
120

121 122 123
    const clAmdBlasTranspose transA = (cv::GEMM_1_T & flags) ? clAmdBlasTrans : clAmdBlasNoTrans;
    const clAmdBlasTranspose transB = (cv::GEMM_2_T & flags) ? clAmdBlasTrans : clAmdBlasNoTrans;
    const clAmdBlasOrder     order  = clAmdBlasRowMajor;
124

125 126 127 128 129 130 131 132 133
    const int M = src1.rows;
    const int N = src2.cols;
    const int K = src1.cols;
    int lda     = src1.step;
    int ldb     = src2.step;
    int ldc     = dst.step;
    int offa    = src1.offset;
    int offb    = src2.offset;
    int offc    = dst.offset;
134

135
    cl_command_queue clq = *(cl_command_queue*)src1.clCxt->getOpenCLCommandQueuePtr();
136 137 138 139 140 141 142 143 144
    switch(src1.type())
    {
    case CV_32FC1:
        lda  /= sizeof(float);
        ldb  /= sizeof(float);
        ldc  /= sizeof(float);
        offa /= sizeof(float);
        offb /= sizeof(float);
        offc /= sizeof(float);
145

146 147 148 149
        openCLSafeCall
        (
            clAmdBlasSgemmEx(order, transA, transB, M, N, K,
                             alpha, (const cl_mem)src1.data, offa, lda, (const cl_mem)src2.data, offb, ldb,
150
                             beta, (cl_mem)dst.data, offc, ldc, 1, &clq, 0, NULL, NULL)
151 152 153 154 155 156 157 158 159 160 161 162 163
        );
        break;
    case CV_64FC1:
        lda  /= sizeof(double);
        ldb  /= sizeof(double);
        ldc  /= sizeof(double);
        offa /= sizeof(double);
        offb /= sizeof(double);
        offc /= sizeof(double);
        openCLSafeCall
        (
            clAmdBlasDgemmEx(order, transA, transB, M, N, K,
                             alpha, (const cl_mem)src1.data, offa, lda, (const cl_mem)src2.data, offb, ldb,
164
                             beta, (cl_mem)dst.data, offc, ldc, 1, &clq, 0, NULL, NULL)
165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
        );
        break;
    case CV_32FC2:
    {
        lda  /= sizeof(std::complex<float>);
        ldb  /= sizeof(std::complex<float>);
        ldc  /= sizeof(std::complex<float>);
        offa /= sizeof(std::complex<float>);
        offb /= sizeof(std::complex<float>);
        offc /= sizeof(std::complex<float>);
        cl_float2 alpha_2 = {{alpha, 0}};
        cl_float2 beta_2  = {{beta, 0}};
        openCLSafeCall
        (
            clAmdBlasCgemmEx(order, transA, transB, M, N, K,
                             alpha_2, (const cl_mem)src1.data, offa, lda, (const cl_mem)src2.data, offb, ldb,
181
                             beta_2, (cl_mem)dst.data, offc, ldc, 1, &clq, 0, NULL, NULL)
182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198
        );
    }
    break;
    case CV_64FC2:
    {
        lda  /= sizeof(std::complex<double>);
        ldb  /= sizeof(std::complex<double>);
        ldc  /= sizeof(std::complex<double>);
        offa /= sizeof(std::complex<double>);
        offb /= sizeof(std::complex<double>);
        offc /= sizeof(std::complex<double>);
        cl_double2 alpha_2 = {{alpha, 0}};
        cl_double2 beta_2  = {{beta, 0}};
        openCLSafeCall
        (
            clAmdBlasZgemmEx(order, transA, transB, M, N, K,
                             alpha_2, (const cl_mem)src1.data, offa, lda, (const cl_mem)src2.data, offb, ldb,
199
                             beta_2, (cl_mem)dst.data, offc, ldc, 1, &clq, 0, NULL, NULL)
200 201 202 203 204
        );
    }
    break;
    }
}
205
#endif