gemm.cpp 5.97 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2010-2012, Multicoreware, Inc., all rights reserved.
// Copyright (C) 2010-2012, Advanced Micro Devices, Inc., all rights reserved.
// Third party copyrights are property of their respective owners.
//
// @Authors
//    Peng Xiao, pengxiao@multicorewareinc.com
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors as is and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include <iomanip>
#include "precomp.hpp"

49 50 51
#if !defined HAVE_CLAMDBLAS
void cv::ocl::gemm(const oclMat&, const oclMat&, double,
                   const oclMat&, double, oclMat&, int)
Niko's avatar
Niko committed
52
{
53
    CV_Error(Error::StsNotImplemented, "OpenCL BLAS is not implemented");
Niko's avatar
Niko committed
54
}
55
#else
56
#include "clAmdBlas.h"
57 58
using namespace cv;

59 60 61 62
void cv::ocl::gemm(const oclMat &src1, const oclMat &src2, double alpha,
                   const oclMat &src3, double beta, oclMat &dst, int flags)
{
    CV_Assert(src1.cols == src2.rows &&
63
              (src3.empty() || (src1.rows == src3.rows && src2.cols == src3.cols)));
64 65 66 67 68 69 70 71 72 73 74
    CV_Assert(!(cv::GEMM_3_T & flags)); // cv::GEMM_3_T is not supported
    if(!src3.empty())
    {
        src3.copyTo(dst);
    }
    else
    {
        dst.create(src1.rows, src2.cols, src1.type());
        dst.setTo(Scalar::all(0));
    }
    openCLSafeCall( clAmdBlasSetup() );
75

76 77 78
    const clAmdBlasTranspose transA = (cv::GEMM_1_T & flags) ? clAmdBlasTrans : clAmdBlasNoTrans;
    const clAmdBlasTranspose transB = (cv::GEMM_2_T & flags) ? clAmdBlasTrans : clAmdBlasNoTrans;
    const clAmdBlasOrder     order  = clAmdBlasRowMajor;
79

80 81 82 83 84 85 86 87 88
    const int M = src1.rows;
    const int N = src2.cols;
    const int K = src1.cols;
    int lda     = src1.step;
    int ldb     = src2.step;
    int ldc     = dst.step;
    int offa    = src1.offset;
    int offb    = src2.offset;
    int offc    = dst.offset;
89

90
    cl_command_queue clq = (cl_command_queue)src1.clCxt->oclCommandQueue();
91 92 93 94 95 96 97 98 99
    switch(src1.type())
    {
    case CV_32FC1:
        lda  /= sizeof(float);
        ldb  /= sizeof(float);
        ldc  /= sizeof(float);
        offa /= sizeof(float);
        offb /= sizeof(float);
        offc /= sizeof(float);
100

101 102 103 104
        openCLSafeCall
        (
            clAmdBlasSgemmEx(order, transA, transB, M, N, K,
                             alpha, (const cl_mem)src1.data, offa, lda, (const cl_mem)src2.data, offb, ldb,
105
                             beta, (cl_mem)dst.data, offc, ldc, 1, &clq, 0, NULL, NULL)
106 107 108 109 110 111 112 113 114 115 116 117 118
        );
        break;
    case CV_64FC1:
        lda  /= sizeof(double);
        ldb  /= sizeof(double);
        ldc  /= sizeof(double);
        offa /= sizeof(double);
        offb /= sizeof(double);
        offc /= sizeof(double);
        openCLSafeCall
        (
            clAmdBlasDgemmEx(order, transA, transB, M, N, K,
                             alpha, (const cl_mem)src1.data, offa, lda, (const cl_mem)src2.data, offb, ldb,
119
                             beta, (cl_mem)dst.data, offc, ldc, 1, &clq, 0, NULL, NULL)
120 121 122 123
        );
        break;
    case CV_32FC2:
    {
124 125 126 127 128 129
        lda  /= (2*sizeof(float));
        ldb  /= (2*sizeof(float));
        ldc  /= (2*sizeof(float));
        offa /= (2*sizeof(float));
        offb /= (2*sizeof(float));
        offc /= (2*sizeof(float));
130 131 132 133 134 135
        cl_float2 alpha_2 = {{alpha, 0}};
        cl_float2 beta_2  = {{beta, 0}};
        openCLSafeCall
        (
            clAmdBlasCgemmEx(order, transA, transB, M, N, K,
                             alpha_2, (const cl_mem)src1.data, offa, lda, (const cl_mem)src2.data, offb, ldb,
136
                             beta_2, (cl_mem)dst.data, offc, ldc, 1, &clq, 0, NULL, NULL)
137 138 139 140 141
        );
    }
    break;
    case CV_64FC2:
    {
142 143 144 145 146 147
        lda  /= (2*sizeof(double));
        ldb  /= (2*sizeof(double));
        ldc  /= (2*sizeof(double));
        offa /= (2*sizeof(double));
        offb /= (2*sizeof(double));
        offc /= (2*sizeof(double));
148 149 150 151 152 153
        cl_double2 alpha_2 = {{alpha, 0}};
        cl_double2 beta_2  = {{beta, 0}};
        openCLSafeCall
        (
            clAmdBlasZgemmEx(order, transA, transB, M, N, K,
                             alpha_2, (const cl_mem)src1.data, offa, lda, (const cl_mem)src2.data, offb, ldb,
154
                             beta_2, (cl_mem)dst.data, offc, ldc, 1, &clq, 0, NULL, NULL)
155 156 157 158 159 160
        );
    }
    break;
    }
    clAmdBlasTeardown();
}
161
#endif