#!/usr/bin/env python

'''
Digit recognition adjustment.
Grid search is used to find the best parameters for SVM and KNearest classifiers.
SVM adjustment follows the guidelines given in
http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf

Usage:
  digits_adjust.py [--model {svm|knearest}]

  --model {svm|knearest}   - select the classifier (SVM is the default)

'''

# Python 2/3 compatibility
from __future__ import print_function
import sys
# True on Python 3 and on any later major version.
PY3 = sys.version_info[0] >= 3

if PY3:
    # Python 3 removed xrange; alias it so code written for Python 2 keeps working.
    xrange = range

import numpy as np
import cv2
from multiprocessing.pool import ThreadPool

from digits import *

def cross_validate(model_class, params, samples, labels, kfold = 3, pool = None):
    """Return the mean k-fold cross-validation error of ``model_class(**params)``.

    The indices ``0 .. len(samples)-1`` are split into ``kfold`` contiguous
    folds (no shuffling is done here). For each fold a fresh model is
    trained on the remaining folds and evaluated on the held-out one.

    Args:
        model_class: callable building a model that has ``train(samples,
            labels)`` and ``predict(samples)`` methods.
        params: dict of keyword arguments passed to ``model_class``.
        samples, labels: arrays that can be indexed with an integer index
            array (e.g. NumPy arrays).
        kfold: number of folds; must be at least 2 so every fold has
            training data.
        pool: optional object with a ``map(func, iterable)`` method (such as
            a ``ThreadPool``) used to evaluate the folds; when None the folds
            are evaluated sequentially.

    Returns:
        The misclassification rate averaged over the folds (a float in
        [0, 1]; NaN if some fold is empty, i.e. ``kfold > len(samples)``).

    Raises:
        ValueError: if ``kfold < 2``.
    """
    if kfold < 2:
        raise ValueError('kfold must be >= 2, got %r' % (kfold,))
    n = len(samples)
    folds = np.array_split(np.arange(n), kfold)

    def eval_fold(i):
        # Hold out fold i; train on the concatenation of all the other folds.
        test_idx = folds[i]
        train_idx = np.hstack(folds[:i] + folds[i + 1:])
        model = model_class(**params)
        model.train(samples[train_idx], labels[train_idx])
        resp = model.predict(samples[test_idx])
        # NOTE(review): assumes predict() returns a 1-D array shaped like
        # labels[test_idx]; a column vector would broadcast and skew the rate.
        err = (resp != labels[test_idx]).mean()
        print(".", end='')  # one progress dot per evaluated fold
        return err

    if pool is None:
        errors = [eval_fold(i) for i in range(kfold)]
    else:
        errors = pool.map(eval_fold, range(kfold))
    return np.mean(errors)


class App(object):
    """Grid-search driver for the SVM and KNearest digit classifiers.

    The dataset is loaded and featurized once in ``__init__``; each
    ``adjust_*`` method then cross-validates a range of hyper-parameters on
    a thread pool and returns the best parameter dict it found.
    """

    def __init__(self):
        # Preprocess once; every grid-search job reuses the same arrays.
        self._samples, self._labels = self.preprocess()

    def preprocess(self):
        """Load, shuffle and deskew the digits, then compute HOG features.

        ``load_digits``, ``DIGITS_FN``, ``deskew`` and ``preprocess_hog`` are
        not defined in this file; presumably they come from
        ``from digits import *`` -- confirm.

        Returns:
            ``(samples, labels)`` in the same shuffled order.
        """
        digits, labels = load_digits(DIGITS_FN)
        # cross_validate() builds contiguous folds, so shuffle here to keep
        # the folds from following the on-disk ordering of the data.
        shuffle = np.random.permutation(len(digits))
        digits, labels = digits[shuffle], labels[shuffle]
        digits2 = list(map(deskew, digits))
        samples = preprocess_hog(digits2)
        return samples, labels

    def get_dataset(self):
        """Return the cached ``(samples, labels)`` pair built by ``preprocess``."""
        return self._samples, self._labels

    def run_jobs(self, f, jobs):
        """Apply ``f`` to every item of ``jobs`` on a thread pool.

        Results are yielded in completion order, not submission order. The
        pool is terminated when iteration finishes, when the generator is
        closed early, or when an exception from ``f`` propagates, so worker
        threads are not leaked between calls.
        """
        pool = ThreadPool(processes=cv2.getNumberOfCPUs())
        try:
            for res in pool.imap_unordered(f, jobs):
                yield res
        finally:
            pool.terminate()

    def adjust_SVM(self):
        """Grid-search ``C`` and ``gamma`` for the SVM classifier.

        ``C`` spans 2**0 .. 2**10 and ``gamma`` spans 2**-7 .. 2**4, with 15
        log-spaced values each. Every grid cell is cross-validated and the
        full error table is saved to ``svm_scores.npz``.

        Returns:
            ``dict(C=..., gamma=...)`` with the lowest cross-validation error.
        """
        Cs = np.logspace(0, 10, 15, base=2)
        gammas = np.logspace(-7, 4, 15, base=2)
        # NaN marks grid cells whose job has not reported back yet.
        scores = np.zeros((len(Cs), len(gammas)))
        scores[:] = np.nan

        print('adjusting SVM (may take a long time) ...')
        def f(job):
            i, j = job
            samples, labels = self.get_dataset()
            params = dict(C = Cs[i], gamma=gammas[j])
            score = cross_validate(SVM, params, samples, labels)
            return i, j, score

        ires = self.run_jobs(f, np.ndindex(*scores.shape))
        for count, (i, j, score) in enumerate(ires):
            scores[i, j] = score
            print('%d / %d (best error: %.2f %%, last: %.2f %%)' %
                  (count+1, scores.size, np.nanmin(scores)*100, score*100))
        print(scores)

        print('writing score table to "svm_scores.npz"')
        np.savez('svm_scores.npz', scores=scores, Cs=Cs, gammas=gammas)

        # Use the NaN-aware reductions: plain argmin/min propagate NaN (e.g.
        # from an empty CV fold) and would report it as the best cell.
        i, j = np.unravel_index(np.nanargmin(scores), scores.shape)
        best_params = dict(C = Cs[i], gamma=gammas[j])
        print('best params:', best_params)
        print('best error: %.2f %%' % (np.nanmin(scores)*100))
        return best_params

    def adjust_KNearest(self):
        """Cross-validate KNearest for k = 1..8 and return the best ``dict(k=...)``."""
        print('adjusting KNearest ...')
        def f(k):
            samples, labels = self.get_dataset()
            err = cross_validate(KNearest, dict(k=k), samples, labels)
            return k, err
        best_err, best_k = np.inf, -1
        for k, err in self.run_jobs(f, range(1, 9)):
            # Results arrive in completion order; break ties toward the
            # smaller k so the winner does not depend on thread timing.
            if err < best_err or (err == best_err and k < best_k):
                best_err, best_k = err, k
            print('k = %d, error: %.2f %%' % (k, err*100))
        best_params = dict(k=best_k)
        print('best params:', best_params, 'err: %.2f' % (best_err*100))
        return best_params


if __name__ == '__main__':
    import getopt
    import sys

    print(__doc__)

    # Report an unrecognized option the same way as an unknown model name
    # (message + exit status 1) instead of dying with a traceback.
    try:
        args, _ = getopt.getopt(sys.argv[1:], '', ['model='])
    except getopt.GetoptError as e:
        print(e)
        sys.exit(1)
    args = dict(args)
    args.setdefault('--model', 'svm')
    if args['--model'] not in ['svm', 'knearest']:
        print('unknown model "%s"' % args['--model'])
        sys.exit(1)

    # `clock` is not defined in this file; presumably it is provided by
    # `from digits import *` -- confirm.
    t = clock()
    app = App()
    if args['--model'] == 'knearest':
        app.adjust_KNearest()
    else:
        app.adjust_SVM()
    print('work time: %f s' % (clock() - t))