// svmsgd.cpp
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
14
// Copyright (C) 2016, Itseez Inc, all rights reserved.
joao.faro's avatar
joao.faro committed
15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"

#include <cfloat>
#include <climits>
#include <cmath>
#include <iostream>
#include <limits>
#include <utility>

using std::cout;
using std::endl;
joao.faro's avatar
joao.faro committed
50 51 52 53 54

/****************************************************************************************\
*                        Stochastic Gradient Descent SVM Classifier                      *
\****************************************************************************************/

55 56 57 58
namespace cv
{
namespace ml
{
joao.faro's avatar
joao.faro committed
59

60 61
// Implementation of the SVMSGD interface: a binary linear classifier
// trained with (averaged) stochastic gradient descent.
class SVMSGDImpl : public SVMSGD
{
public:
    SVMSGDImpl();

    virtual ~SVMSGDImpl() {}

    // Trains on data->getTrainSamples()/getTrainResponses(); the int flag is unused.
    virtual bool train(const Ptr<TrainData>& data, int);

    // Returns +1/-1 per sample; when `results` is not needed, exactly one
    // sample is allowed and its label is returned as the function value.
    virtual float predict( InputArray samples, OutputArray results=noArray(), int flags = 0 ) const;

    // True when all parameters form a valid classifier configuration.
    virtual bool isClassifier() const;

    virtual bool isTrained() const;

    virtual void clear();

    virtual void write(FileStorage &fs) const;

    virtual void read(const FileNode &fn);

    // Trained hyperplane normal (1 x varCount, CV_32F).
    virtual Mat getWeights(){ return weights_; }

    // Trained bias term: decision value is weights . x + shift.
    virtual float getShift(){ return shift_; }

    virtual int getVarCount() const { return weights_.cols; }

    virtual String getDefaultName() const {return "opencv_ml_svmsgd";}

    // Loads the recommended hyper-parameters for the chosen variant.
    virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN);

    // Getter/setter boilerplate for the public properties.
    CV_IMPL_PROPERTY(int, SvmsgdType, params.svmsgdType)
    CV_IMPL_PROPERTY(int, MarginType, params.marginType)
    CV_IMPL_PROPERTY(float, MarginRegularization, params.marginRegularization)
    CV_IMPL_PROPERTY(float, InitialStepSize, params.initialStepSize)
    CV_IMPL_PROPERTY(float, StepDecreasingPower, params.stepDecreasingPower)
    CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)

private:
    // One SGD step on a single (extended) sample.
    void updateWeights(InputArray sample, bool isPositive, float stepSize, Mat &weights);

    // (first, second) = (no positive samples, no negative samples).
    std::pair<bool,bool> areClassesEmpty(Mat responses);

    void writeParams( FileStorage &fs ) const;

    void readParams( const FileNode &fn );

    // Labels are encoded by sign: > 0 is the positive class.
    static inline bool isPositive(float val) { return val > 0; }

    // Centers columns and rescales the matrix in place; outputs the
    // per-feature mean and the scale factor applied.
    static void normalizeSamples(Mat &matrix, Mat &average, float &multiplier);

    // Shift for HARD_MARGIN: halfway between the closest sample of each class.
    float calcShift(InputArray _samples, InputArray _responses) const;

    // Normalized samples with a constant-1 column appended (bias as a weight).
    static void makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier);

    // Vector with SVM weights
    Mat weights_;
    float shift_;

    // Parameters for learning
    struct SVMSGDParams
    {
        float marginRegularization;
        float initialStepSize;
        float stepDecreasingPower;
        TermCriteria termCrit;
        int svmsgdType;
        int marginType;
    };

    SVMSGDParams params;
};

// Factory function: callers only ever see the abstract SVMSGD interface.
Ptr<SVMSGD> SVMSGD::create()
{
    return makePtr<SVMSGDImpl>();
}

141
std::pair<bool,bool> SVMSGDImpl::areClassesEmpty(Mat responses)
142
{
143
    CV_Assert(responses.cols == 1 || responses.rows == 1);
144
    std::pair<bool,bool> emptyInClasses(true, true);
145
    int limitIndex = responses.rows;
146

147
    for(int index = 0; index < limitIndex; index++)
148
    {
149
        if (isPositive(responses.at<float>(index)))
150 151 152
            emptyInClasses.first = false;
        else
            emptyInClasses.second = false;
joao.faro's avatar
joao.faro committed
153

154 155 156
        if (!emptyInClasses.first && ! emptyInClasses.second)
            break;
    }
157

158 159
    return emptyInClasses;
}
160

161
// Normalizes the sample matrix in place: subtracts the per-feature mean and
// rescales so the Frobenius norm equals sqrt(total element count).
// Outputs the mean row (`average`) and the applied scale (`multiplier`) so the
// transform can be undone on the learned weights.
void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier)
{
    int featuresCount = samples.cols;
    int samplesCount = samples.rows;

    // Per-feature mean over all samples.
    average = Mat(1, featuresCount, samples.type());
    for (int featureIndex = 0; featureIndex < featuresCount; featureIndex++)
    {
        average.at<float>(featureIndex) = static_cast<float>(mean(samples.col(featureIndex))[0]);
    }

    // Center every sample.
    for (int sampleIndex = 0; sampleIndex < samplesCount; sampleIndex++)
    {
        samples.row(sampleIndex) -= average;
    }

    double normValue = norm(samples);

    // Bug fix: when all samples are identical the centered matrix is all
    // zeros, normValue == 0, and the old code produced an inf/NaN multiplier
    // that poisoned training. Fall back to a no-op scale in that case.
    multiplier = (normValue > 0.) ? static_cast<float>(sqrt(samples.total()) / normValue)
                                  : 1.f;

    samples *= multiplier;
}
183

184
void SVMSGDImpl::makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier)
185
{
186 187
    Mat normalizedTrainSamples = trainSamples.clone();
    int samplesCount = normalizedTrainSamples.rows;
188

189
    normalizeSamples(normalizedTrainSamples, average, multiplier);
190

191
    Mat onesCol = Mat::ones(samplesCount, 1, CV_32F);
192
    cv::hconcat(normalizedTrainSamples, onesCol, extendedTrainSamples);
joao.faro's avatar
joao.faro committed
193 194
}

195
// Performs one stochastic sub-gradient step of the regularized hinge loss
// on a single extended sample, updating `weights` in place.
void SVMSGDImpl::updateWeights(InputArray _sample, bool firstClass, float stepSize, Mat& weights)
{
    Mat sample = _sample.getMat();

    // Map the class flag onto the +/-1 label used by the hinge loss.
    const int label = firstClass ? 1 : -1;
    const float decay = stepSize * params.marginRegularization;

    if (sample.dot(weights) * label > 1)
    {
        // Outside the margin: only the weight-decay (regularization) term applies.
        weights *= (1.f - decay);
    }
    else
    {
        // Margin violation: decay the weights and step toward the sample.
        weights -= decay * weights - (stepSize * label) * sample;
    }
}
joao.faro's avatar
joao.faro committed
212

213 214
// Computes the bias for the HARD_MARGIN case: finds the sample of each class
// closest to the hyperplane defined by weights_ and places the decision
// boundary exactly halfway between them.
float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
{
    Mat samples = _samples.getMat();
    Mat responses = _responses.getMat();
    int sampleCount = samples.rows;

    // Smallest signed distance seen so far, per class
    // (index 0: positive class, index 1: negative class).
    float margin[2] = { std::numeric_limits<float>::max(), std::numeric_limits<float>::max() };

    for (int i = 0; i < sampleCount; i++)
    {
        float dotProduct = static_cast<float>(samples.row(i).dot(weights_));

        bool positive = isPositive(responses.at<float>(i));
        int classIndex = positive ? 0 : 1;
        float curMargin = positive ? dotProduct : -dotProduct;

        if (curMargin < margin[classIndex])
            margin[classIndex] = curMargin;
    }

    // Midpoint between the two closest samples, with the sign convention
    // that the decision value is weights . x + shift.
    return (margin[1] - margin[0]) / 2.f;
}

241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259
// Trains the linear SVM with (averaged) stochastic gradient descent.
// Returns false when the training set contains no samples at all; when only
// one class is present, produces a constant classifier (zero weights, +/-1
// shift) and returns true.
bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
{
    clear();
    CV_Assert( isClassifier() );   //toDo: consider

    Mat trainSamples = data->getTrainSamples();

    int featureCount = trainSamples.cols;
    Mat trainResponses = data->getTrainResponses();        // (trainSamplesCount x 1) matrix

    std::pair<bool,bool> areEmpty = areClassesEmpty(trainResponses);

    // Both classes absent: nothing to learn from.
    if ( areEmpty.first && areEmpty.second )
    {
        return false;
    }
    // Exactly one class present: any zero-weight hyperplane with the right
    // constant sign classifies the data perfectly.
    if ( areEmpty.first || areEmpty.second )
    {
        weights_ = Mat::zeros(1, featureCount, CV_32F);
        shift_ = areEmpty.first ? -1.f : 1.f;
        return true;
    }

    // Center/scale the samples and append a constant-1 column so the bias is
    // learned as the last extended weight.
    Mat extendedTrainSamples;
    Mat average;
    float multiplier = 0;
    makeExtendedTrainSamples(trainSamples, extendedTrainSamples, average, multiplier);

    int extendedTrainSamplesCount = extendedTrainSamples.rows;
    int extendedFeatureCount = extendedTrainSamples.cols;

    Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
    Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
    Mat averageExtendedWeights;
    if (params.svmsgdType == ASGD)
    {
        averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
    }

    // Fixed seed: training is deterministic for a given data set.
    RNG rng(0);

    CV_Assert (params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS);
    int maxCount = (params.termCrit.type & TermCriteria::COUNT) ? params.termCrit.maxCount : INT_MAX;
    double epsilon = (params.termCrit.type & TermCriteria::EPS) ? params.termCrit.epsilon : 0;

    double err = DBL_MAX;
    // Stochastic gradient descent SVM
    for (int iter = 0; (iter < maxCount) && (err > epsilon); iter++)
    {
        int randomNumber = rng.uniform(0, extendedTrainSamplesCount);             //generate sample number

        Mat currentSample = extendedTrainSamples.row(randomNumber);

        float stepSize = params.initialStepSize * std::pow((1 + params.marginRegularization * params.initialStepSize * (float)iter), (-params.stepDecreasingPower));    //update stepSize

        updateWeights( currentSample, isPositive(trainResponses.at<float>(randomNumber)), stepSize, extendedWeights );

        //average weights (only for ASGD model)
        if (params.svmsgdType == ASGD)
        {
            // Running mean of the iterates; convergence is measured on it.
            averageExtendedWeights = ((float)iter/ (1 + (float)iter)) * averageExtendedWeights  + extendedWeights / (1 + (float) iter);
            err = norm(averageExtendedWeights - previousWeights);
            averageExtendedWeights.copyTo(previousWeights);
        }
        else
        {
             err = norm(extendedWeights - previousWeights);
             extendedWeights.copyTo(previousWeights);
        }
    }

    // For ASGD the final model is the averaged iterate.
    if (params.svmsgdType == ASGD)
    {
        extendedWeights = averageExtendedWeights;
    }

    // Strip the bias column; rescale weights back into the original
    // (un-normalized) feature space.
    Rect roi(0, 0, featureCount, 1);
    weights_ = extendedWeights(roi);
    weights_ *= multiplier;

    CV_Assert(params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN);

    if (params.marginType == SOFT_MARGIN)
    {
        // Recover the bias: learned extended bias minus the mean-centering
        // contribution of the rescaled weights.
        shift_ = extendedWeights.at<float>(featureCount) - static_cast<float>(weights_.dot(average));
    }
    else
    {
        shift_ = calcShift(trainSamples, trainResponses);
    }

    return true;
}


336 337 338 339 340 341
// Classifies each row of `_samples` as +1 or -1 using the trained hyperplane.
// When `_results` is not needed, exactly one sample is allowed and its label
// is returned as the function value.
float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const
{
    float result = 0;
    cv::Mat samples = _samples.getMat();
    int nSamples = samples.rows;
    cv::Mat results;

    CV_Assert( samples.cols == weights_.cols && samples.type() == CV_32F );

    if( _results.needed() )
    {
        _results.create( nSamples, 1, samples.type() );
        results = _results.getMat();
    }
    else
    {
        CV_Assert( nSamples == 1 );
        // No output array requested: back a 1x1 Mat with the local `result`
        // so the loop below writes straight into the return value.
        results = Mat(1, 1, CV_32F, &result);
    }

    for (int sampleIndex = 0; sampleIndex < nSamples; sampleIndex++)
    {
        Mat currentSample = samples.row(sampleIndex);
        // Decision value: weights . x + shift; only its sign matters.
        float criterion = static_cast<float>(currentSample.dot(weights_)) + shift_;
        results.at<float>(sampleIndex) = (criterion >= 0) ? 1.f : -1.f;
    }

    return result;
}

366
bool SVMSGDImpl::isClassifier() const
367
{
368
    return (params.svmsgdType == SGD || params.svmsgdType == ASGD)
369 370
            &&
            (params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN)
371
            &&
372
            (params.marginRegularization > 0) && (params.initialStepSize > 0) && (params.stepDecreasingPower >= 0);
373 374 375 376 377 378 379 380 381 382 383 384 385 386 387
}

// A successful train() always fills weights_, so emptiness marks "untrained".
bool SVMSGDImpl::isTrained() const
{
    return !weights_.empty();
}

// Serializes parameters and the trained state; refuses to write an
// untrained model.
void SVMSGDImpl::write(FileStorage& fs) const
{
    if (!isTrained())
    {
        CV_Error( CV_StsParseError, "SVMSGD model data is invalid, it hasn't been trained" );
    }

    writeParams(fs);

    fs << "weights" << weights_
       << "shift" << shift_;
}

// Writes the learning parameters; enum values are stored as readable strings
// (unknown values are preserved as "Unknown_<n>").
void SVMSGDImpl::writeParams( FileStorage& fs ) const
{
    String typeStr =
        (params.svmsgdType == SGD)  ? String("SGD") :
        (params.svmsgdType == ASGD) ? String("ASGD") :
                                      format("Unknown_%d", params.svmsgdType);
    fs << "svmsgdType" << typeStr;

    String marginStr =
        (params.marginType == SOFT_MARGIN) ? String("SOFT_MARGIN") :
        (params.marginType == HARD_MARGIN) ? String("HARD_MARGIN") :
                                             format("Unknown_%d", params.marginType);
    fs << "marginType" << marginStr;

    fs << "marginRegularization" << params.marginRegularization;
    fs << "initialStepSize" << params.initialStepSize;
    fs << "stepDecreasingPower" << params.stepDecreasingPower;

    // Only the criteria actually enabled by the type flags are written.
    fs << "term_criteria" << "{:";
    if (params.termCrit.type & TermCriteria::EPS)
        fs << "epsilon" << params.termCrit.epsilon;
    if (params.termCrit.type & TermCriteria::COUNT)
        fs << "iterations" << params.termCrit.maxCount;
    fs << "}";
}

void SVMSGDImpl::read(const FileNode& fn)
{
    clear();

    readParams(fn);

    fn["weights"] >> weights_;
444
    fn["shift"] >> shift_;
445 446 447 448 449
}

void SVMSGDImpl::readParams( const FileNode& fn )
{
    String svmsgdTypeStr = (String)fn["svmsgdType"];
450
    int svmsgdType =
451
            svmsgdTypeStr == "SGD" ? SGD :
452
                                     svmsgdTypeStr == "ASGD" ? ASGD : -1;
453

454
    if( svmsgdType < 0 )
455 456 457 458
        CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );

    params.svmsgdType = svmsgdType;

459
    String marginTypeStr = (String)fn["marginType"];
460
    int marginType =
461
            marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
462
                                     marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1;
463

464
    if( marginType < 0 )
465 466 467 468
        CV_Error( CV_StsParseError, "Missing or invalid margin type" );

    params.marginType = marginType;

469 470
    CV_Assert ( fn["marginRegularization"].isReal() );
    params.marginRegularization = (float)fn["marginRegularization"];
471

472 473
    CV_Assert ( fn["initialStepSize"].isReal() );
    params.initialStepSize = (float)fn["initialStepSize"];
474

475 476
    CV_Assert ( fn["stepDecreasingPower"].isReal() );
    params.stepDecreasingPower = (float)fn["stepDecreasingPower"];
477 478 479 480 481 482 483 484

    FileNode tcnode = fn["term_criteria"];
    if( !tcnode.empty() )
    {
        params.termCrit.epsilon = (double)tcnode["epsilon"];
        params.termCrit.maxCount = (int)tcnode["iterations"];
        params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
                (params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
joao.faro's avatar
joao.faro committed
485
    }
486
    else
487
        params.termCrit = TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 100000, FLT_EPSILON );
488

joao.faro's avatar
joao.faro committed
489 490
}

491 492 493
void SVMSGDImpl::clear()
{
    weights_.release();
494
    shift_ = 0;
joao.faro's avatar
joao.faro committed
495
}
496 497 498 499 500 501


// Constructs an unconfigured model: no variant or margin is selected and all
// learning parameters are zero, so isClassifier() is false until
// setOptimalParameters() or the individual setters are used.
SVMSGDImpl::SVMSGDImpl()
{
    clear();

    params.svmsgdType = -1;
    params.marginType = -1;

    // Learning parameters start at zero (invalid on purpose).
    params.marginRegularization = 0;
    params.initialStepSize = 0;
    params.stepDecreasingPower = 0;

    params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 0, 0);
}

514
void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
515
{
516
    switch (svmsgdType)
517 518 519
    {
    case SGD:
        params.svmsgdType = SGD;
520
        params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
521
                            (marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
522 523 524
        params.marginRegularization = 0.0001f;
        params.initialStepSize = 0.05f;
        params.stepDecreasingPower = 1.f;
525
        params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
526 527 528 529
        break;

    case ASGD:
        params.svmsgdType = ASGD;
530
        params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
531
                            (marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
532 533 534
        params.marginRegularization = 0.00001f;
        params.initialStepSize = 0.05f;
        params.stepDecreasingPower = 0.75f;
535
        params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
536 537 538 539 540 541 542 543
        break;

    default:
        CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
    }
}
}   //ml
}   //cv