svmsgd.cpp 16.2 KB
Newer Older
joao.faro's avatar
joao.faro committed
1 2 3 4 5 6 7 8 9 10 11 12 13
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2016, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"

#include <cfloat>
#include <climits>
#include <cmath>
#include <iostream>
#include <limits>

using std::cout;
using std::endl;

/****************************************************************************************\
*                        Stochastic Gradient Descent SVM Classifier                      *
\****************************************************************************************/

55 56 57 58
namespace cv
{
namespace ml
{
joao.faro's avatar
joao.faro committed
59

60 61
class SVMSGDImpl : public SVMSGD
{
joao.faro's avatar
joao.faro committed
62

63 64
public:
    SVMSGDImpl();
joao.faro's avatar
joao.faro committed
65

66
    virtual ~SVMSGDImpl() {}
joao.faro's avatar
joao.faro committed
67

68
    virtual bool train(const Ptr<TrainData>& data, int);
joao.faro's avatar
joao.faro committed
69

70
    virtual float predict( InputArray samples, OutputArray results=noArray(), int flags = 0 ) const;
joao.faro's avatar
joao.faro committed
71

72
    virtual bool isClassifier() const;
joao.faro's avatar
joao.faro committed
73

74
    virtual bool isTrained() const;
joao.faro's avatar
joao.faro committed
75

76
    virtual void clear();
joao.faro's avatar
joao.faro committed
77

78
    virtual void write(FileStorage &fs) const;
joao.faro's avatar
joao.faro committed
79

80
    virtual void read(const FileNode &fn);
joao.faro's avatar
joao.faro committed
81

82
    virtual Mat getWeights(){ return weights_; }
joao.faro's avatar
joao.faro committed
83

84
    virtual float getShift(){ return shift_; }
joao.faro's avatar
joao.faro committed
85

86
    virtual int getVarCount() const { return weights_.cols; }
joao.faro's avatar
joao.faro committed
87

88
    virtual String getDefaultName() const {return "opencv_ml_svmsgd";}
joao.faro's avatar
joao.faro committed
89

90
    virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN);
joao.faro's avatar
joao.faro committed
91

92 93
    CV_IMPL_PROPERTY(int, SvmsgdType, params.svmsgdType)
    CV_IMPL_PROPERTY(int, MarginType, params.marginType)
94 95 96
    CV_IMPL_PROPERTY(float, MarginRegularization, params.marginRegularization)
    CV_IMPL_PROPERTY(float, InitialStepSize, params.initialStepSize)
    CV_IMPL_PROPERTY(float, StepDecreasingPower, params.stepDecreasingPower)
97
    CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)
joao.faro's avatar
joao.faro committed
98

99
private:
100
    void updateWeights(InputArray sample, bool positive, float stepSize, Mat &weights);
101

102
    void writeParams( FileStorage &fs ) const;
103

104
    void readParams( const FileNode &fn );
105

106
    static inline bool isPositive(float val) { return val > 0; }
107

108
    static void normalizeSamples(Mat &matrix, Mat &average, float &multiplier);
109 110 111

    float calcShift(InputArray _samples, InputArray _responses) const;

112
    static void makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier);
113

114 115 116 117 118 119 120
    // Vector with SVM weights
    Mat weights_;
    float shift_;

    // Parameters for learning
    struct SVMSGDParams
    {
121 122 123
        float marginRegularization;
        float initialStepSize;
        float stepDecreasingPower;
124
        TermCriteria termCrit;
125 126
        int svmsgdType;
        int marginType;
127 128 129 130 131 132
    };

    SVMSGDParams params;
};

Ptr<SVMSGD> SVMSGD::create()
133
{
134
    return makePtr<SVMSGDImpl>();
joao.faro's avatar
joao.faro committed
135 136
}

137 138 139 140 141 142
Ptr<SVMSGD> SVMSGD::load(const String& filepath, const String& nodeName)
{
    return Algorithm::load<SVMSGD>(filepath, nodeName);
}


143
void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier)
144 145 146
{
    int featuresCount = samples.cols;
    int samplesCount = samples.rows;
147

148
    average = Mat(1, featuresCount, samples.type());
149
    CV_Assert(average.type() ==  CV_32FC1);
150
    for (int featureIndex = 0; featureIndex < featuresCount; featureIndex++)
151
    {
152
        average.at<float>(featureIndex) = static_cast<float>(mean(samples.col(featureIndex))[0]);
153
    }
154 155

    for (int sampleIndex = 0; sampleIndex < samplesCount; sampleIndex++)
156
    {
157
        samples.row(sampleIndex) -= average;
158 159
    }

160
    double normValue = norm(samples);
joao.faro's avatar
joao.faro committed
161

berak's avatar
berak committed
162
    multiplier = static_cast<float>(sqrt(static_cast<double>(samples.total())) / normValue);
163 164

    samples *= multiplier;
165
}
166

167
void SVMSGDImpl::makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier)
168
{
169 170
    Mat normalizedTrainSamples = trainSamples.clone();
    int samplesCount = normalizedTrainSamples.rows;
171

172
    normalizeSamples(normalizedTrainSamples, average, multiplier);
173

174
    Mat onesCol = Mat::ones(samplesCount, 1, CV_32F);
175
    cv::hconcat(normalizedTrainSamples, onesCol, extendedTrainSamples);
joao.faro's avatar
joao.faro committed
176 177
}

178
void SVMSGDImpl::updateWeights(InputArray _sample, bool positive, float stepSize, Mat& weights)
179
{
180
    Mat sample = _sample.getMat();
181

182
    int response = positive ? 1 : -1; // ensure that trainResponses are -1 or 1
joao.faro's avatar
joao.faro committed
183

184 185 186
    if ( sample.dot(weights) * response > 1)
    {
        // Not a support vector, only apply weight decay
187
        weights *= (1.f - stepSize * params.marginRegularization);
188 189 190 191
    }
    else
    {
        // It's a support vector, add it to the weights
192
        weights -= (stepSize * params.marginRegularization) * weights - (stepSize * response) * sample;
joao.faro's avatar
joao.faro committed
193
    }
194
}
joao.faro's avatar
joao.faro committed
195

196 197
float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
{
198
    float margin[2] = { std::numeric_limits<float>::max(), std::numeric_limits<float>::max() };
joao.faro's avatar
joao.faro committed
199

200 201
    Mat trainSamples = _samples.getMat();
    int trainSamplesCount = trainSamples.rows;
joao.faro's avatar
joao.faro committed
202

203 204
    Mat trainResponses = _responses.getMat();

205
    CV_Assert(trainResponses.type() ==  CV_32FC1);
206 207 208
    for (int samplesIndex = 0; samplesIndex < trainSamplesCount; samplesIndex++)
    {
        Mat currentSample = trainSamples.row(samplesIndex);
Marina Noskova's avatar
Marina Noskova committed
209
        float dotProduct = static_cast<float>(currentSample.dot(weights_));
joao.faro's avatar
joao.faro committed
210

211 212 213 214
        bool positive = isPositive(trainResponses.at<float>(samplesIndex));
        int index = positive ? 0 : 1;
        float signToMul = positive ? 1.f : -1.f;
        float curMargin = dotProduct * signToMul;
215

216
        if (curMargin < margin[index])
217
        {
218
            margin[index] = curMargin;
219
        }
joao.faro's avatar
joao.faro committed
220 221
    }

222
    return -(margin[0] - margin[1]) / 2.f;
joao.faro's avatar
joao.faro committed
223 224
}

225 226 227 228 229 230 231 232 233 234
bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
{
    clear();
    CV_Assert( isClassifier() );   //toDo: consider

    Mat trainSamples = data->getTrainSamples();

    int featureCount = trainSamples.cols;
    Mat trainResponses = data->getTrainResponses();        // (trainSamplesCount x 1) matrix

235
    CV_Assert(trainResponses.rows == trainSamples.rows);
236

237
    if (trainResponses.empty())
238 239 240
    {
        return false;
    }
241 242 243 244 245

    int positiveCount = countNonZero(trainResponses >= 0);
    int negativeCount = countNonZero(trainResponses < 0);

    if ( positiveCount <= 0 || negativeCount <= 0 )
246 247
    {
        weights_ = Mat::zeros(1, featureCount, CV_32F);
248
        shift_ = (positiveCount > 0) ? 1.f : -1.f;
249
        return true;
250
    }
251 252

    Mat extendedTrainSamples;
253 254 255
    Mat average;
    float multiplier = 0;
    makeExtendedTrainSamples(trainSamples, extendedTrainSamples, average, multiplier);
256 257 258 259

    int extendedTrainSamplesCount = extendedTrainSamples.rows;
    int extendedFeatureCount = extendedTrainSamples.cols;

260 261 262
    Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
    Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
    Mat averageExtendedWeights;
263 264 265 266 267 268 269
    if (params.svmsgdType == ASGD)
    {
        averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
    }

    RNG rng(0);

270
    CV_Assert (params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS);
271 272 273 274
    int maxCount = (params.termCrit.type & TermCriteria::COUNT) ? params.termCrit.maxCount : INT_MAX;
    double epsilon = (params.termCrit.type & TermCriteria::EPS) ? params.termCrit.epsilon : 0;

    double err = DBL_MAX;
275
    CV_Assert (trainResponses.type() == CV_32FC1);
276 277 278 279 280 281 282
    // Stochastic gradient descent SVM
    for (int iter = 0; (iter < maxCount) && (err > epsilon); iter++)
    {
        int randomNumber = rng.uniform(0, extendedTrainSamplesCount);             //generate sample number

        Mat currentSample = extendedTrainSamples.row(randomNumber);

283
        float stepSize = params.initialStepSize * std::pow((1 + params.marginRegularization * params.initialStepSize * (float)iter), (-params.stepDecreasingPower));    //update stepSize
284

285
        updateWeights( currentSample, isPositive(trainResponses.at<float>(randomNumber)), stepSize, extendedWeights );
286 287 288 289 290 291 292 293 294 295

        //average weights (only for ASGD model)
        if (params.svmsgdType == ASGD)
        {
            averageExtendedWeights = ((float)iter/ (1 + (float)iter)) * averageExtendedWeights  + extendedWeights / (1 + (float) iter);
            err = norm(averageExtendedWeights - previousWeights);
            averageExtendedWeights.copyTo(previousWeights);
        }
        else
        {
296 297
            err = norm(extendedWeights - previousWeights);
            extendedWeights.copyTo(previousWeights);
298 299 300 301 302 303 304 305 306 307
        }
    }

    if (params.svmsgdType == ASGD)
    {
        extendedWeights = averageExtendedWeights;
    }

    Rect roi(0, 0, featureCount, 1);
    weights_ = extendedWeights(roi);
308
    weights_ *= multiplier;
309

310
    CV_Assert((params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN) && (extendedWeights.type() ==  CV_32FC1));
311

312 313
    if (params.marginType == SOFT_MARGIN)
    {
Marina Noskova's avatar
Marina Noskova committed
314
        shift_ = extendedWeights.at<float>(featureCount) - static_cast<float>(weights_.dot(average));
315 316 317 318 319
    }
    else
    {
        shift_ = calcShift(trainSamples, trainResponses);
    }
320 321 322 323

    return true;
}

324 325 326 327 328 329
float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const
{
    float result = 0;
    cv::Mat samples = _samples.getMat();
    int nSamples = samples.rows;
    cv::Mat results;
joao.faro's avatar
joao.faro committed
330

331
    CV_Assert( samples.cols == weights_.cols && samples.type() == CV_32FC1);
332 333 334 335 336 337 338 339 340

    if( _results.needed() )
    {
        _results.create( nSamples, 1, samples.type() );
        results = _results.getMat();
    }
    else
    {
        CV_Assert( nSamples == 1 );
341
        results = Mat(1, 1, CV_32FC1, &result);
joao.faro's avatar
joao.faro committed
342
    }
343 344 345

    for (int sampleIndex = 0; sampleIndex < nSamples; sampleIndex++)
    {
346
        Mat currentSample = samples.row(sampleIndex);
Marina Noskova's avatar
Marina Noskova committed
347 348
        float criterion = static_cast<float>(currentSample.dot(weights_)) + shift_;
        results.at<float>(sampleIndex) = (criterion >= 0) ? 1.f : -1.f;
349 350 351
    }

    return result;
joao.faro's avatar
joao.faro committed
352 353
}

354
bool SVMSGDImpl::isClassifier() const
355
{
356
    return (params.svmsgdType == SGD || params.svmsgdType == ASGD)
357 358
            &&
            (params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN)
359
            &&
360
            (params.marginRegularization > 0) && (params.initialStepSize > 0) && (params.stepDecreasingPower >= 0);
361 362 363 364 365 366 367 368 369 370 371 372
}

bool SVMSGDImpl::isTrained() const
{
    return !weights_.empty();
}

void SVMSGDImpl::write(FileStorage& fs) const
{
    if( !isTrained() )
        CV_Error( CV_StsParseError, "SVMSGD model data is invalid, it hasn't been trained" );

373
    writeFormat(fs);
374 375 376
    writeParams( fs );

    fs << "weights" << weights_;
377
    fs << "shift" << shift_;
378 379 380 381 382 383 384 385 386 387 388 389 390 391 392
}

void SVMSGDImpl::writeParams( FileStorage& fs ) const
{
    String SvmsgdTypeStr;

    switch (params.svmsgdType)
    {
    case SGD:
        SvmsgdTypeStr = "SGD";
        break;
    case ASGD:
        SvmsgdTypeStr = "ASGD";
        break;
    default:
393
        SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType);
394 395 396 397
    }

    fs << "svmsgdType" << SvmsgdTypeStr;

398 399 400 401 402 403 404 405 406 407 408
    String marginTypeStr;

    switch (params.marginType)
    {
    case SOFT_MARGIN:
        marginTypeStr = "SOFT_MARGIN";
        break;
    case HARD_MARGIN:
        marginTypeStr = "HARD_MARGIN";
        break;
    default:
409
        marginTypeStr = format("Unknown_%d", params.marginType);
410 411 412 413
    }

    fs << "marginType" << marginTypeStr;

414 415 416
    fs << "marginRegularization" << params.marginRegularization;
    fs << "initialStepSize" << params.initialStepSize;
    fs << "stepDecreasingPower" << params.stepDecreasingPower;
417 418 419 420 421 422 423 424 425 426 427

    fs << "term_criteria" << "{:";
    if( params.termCrit.type & TermCriteria::EPS )
        fs << "epsilon" << params.termCrit.epsilon;
    if( params.termCrit.type & TermCriteria::COUNT )
        fs << "iterations" << params.termCrit.maxCount;
    fs << "}";
}
void SVMSGDImpl::readParams( const FileNode& fn )
{
    String svmsgdTypeStr = (String)fn["svmsgdType"];
428
    int svmsgdType =
429
            svmsgdTypeStr == "SGD" ? SGD :
430
                                     svmsgdTypeStr == "ASGD" ? ASGD : -1;
431

432
    if( svmsgdType < 0 )
433 434 435 436
        CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );

    params.svmsgdType = svmsgdType;

437
    String marginTypeStr = (String)fn["marginType"];
438
    int marginType =
439
            marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
440
                                             marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1;
441

442
    if( marginType < 0 )
443 444 445 446
        CV_Error( CV_StsParseError, "Missing or invalid margin type" );

    params.marginType = marginType;

447 448
    CV_Assert ( fn["marginRegularization"].isReal() );
    params.marginRegularization = (float)fn["marginRegularization"];
449

450 451
    CV_Assert ( fn["initialStepSize"].isReal() );
    params.initialStepSize = (float)fn["initialStepSize"];
452

453 454
    CV_Assert ( fn["stepDecreasingPower"].isReal() );
    params.stepDecreasingPower = (float)fn["stepDecreasingPower"];
455 456

    FileNode tcnode = fn["term_criteria"];
457 458 459 460 461 462 463
    CV_Assert(!tcnode.empty());
    params.termCrit.epsilon = (double)tcnode["epsilon"];
    params.termCrit.maxCount = (int)tcnode["iterations"];
    params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
            (params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
    CV_Assert ((params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS));
}
464

465 466 467 468 469 470 471 472
void SVMSGDImpl::read(const FileNode& fn)
{
    clear();

    readParams(fn);

    fn["weights"] >> weights_;
    fn["shift"] >> shift_;
joao.faro's avatar
joao.faro committed
473 474
}

475 476 477
void SVMSGDImpl::clear()
{
    weights_.release();
478
    shift_ = 0;
joao.faro's avatar
joao.faro committed
479
}
480 481 482 483 484


SVMSGDImpl::SVMSGDImpl()
{
    clear();
485
    setOptimalParameters();
486 487
}

488
void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
489
{
490
    switch (svmsgdType)
491 492 493
    {
    case SGD:
        params.svmsgdType = SGD;
494
        params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
495
                                                          (marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
496 497 498
        params.marginRegularization = 0.0001f;
        params.initialStepSize = 0.05f;
        params.stepDecreasingPower = 1.f;
499
        params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
500 501 502 503
        break;

    case ASGD:
        params.svmsgdType = ASGD;
504
        params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
505
                                                          (marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
506 507 508
        params.marginRegularization = 0.00001f;
        params.initialStepSize = 0.05f;
        params.stepDecreasingPower = 0.75f;
509
        params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
510 511 512 513 514 515 516 517
        break;

    default:
        CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
    }
}
}   //ml
}   //cv