Commit d636e112 authored by Alexander Mordvintsev

removed ANN digits recognition

added deskew for SVM and KNearest recognition sample
parent f2e78eed
'''
Neural network digit recognition sample.
SVM and KNearest digit recognition.
Sample loads a dataset of handwritten digits from 'digits.png'.
Then it trains SVM and KNearest classifiers on it and evaluates
their accuracy. Moment-based image deskew is used to improve
the recognition accuracy.
Usage:
digits.py
Sample loads a dataset of handwritten digits from 'digits.png'.
Then it trains a neural network classifier on it and evaluates
its classification accuracy.
'''
import numpy as np
import cv2
from common import mosaic
def unroll_responses(responses, class_n):
    '''Expand integer class labels into one-hot rows.

    [1, 0, 2, ...] -> [[0, 1, 0], [1, 0, 0], [0, 0, 1], ...]

    responses -- sequence of integer class indices, one per sample
    class_n   -- number of columns in the result (one per class)

    Returns a float32 array of shape (len(responses), class_n).
    '''
    row_count = len(responses)
    one_hot = np.zeros((row_count, class_n), dtype=np.float32)
    # Pair each row number with that row's class index and set the cell.
    row_idx = np.arange(row_count)
    one_hot[row_idx, responses] = 1
    return one_hot
from multiprocessing.pool import ThreadPool
from common import clock, mosaic
SZ = 20 # size of each digit is SZ x SZ
CLASS_N = 10 # number of label values generated below (0 .. CLASS_N-1)
# Load the digit sheet; its shape is unpacked below as a 2-D (h, w) array.
digits_img = cv2.imread('digits.png', 0)
# prepare dataset: cut the sheet into SZ x SZ cells and flatten every cell
# into one float32 row of SZ*SZ values
h, w = digits_img.shape
digits = [np.hsplit(row, w/SZ) for row in np.vsplit(digits_img, h/SZ)]
digits = np.float32(digits).reshape(-1, SZ*SZ)
N = len(digits)
# Labels are generated rather than read: this assumes the sheet stores the
# cells grouped by class in ascending order, N/CLASS_N cells per class.
labels = np.repeat(np.arange(CLASS_N), N/CLASS_N)
# split into train (first 90% after a random shuffle) and test subsets;
# the same permutation is applied to samples and labels to keep them aligned
shuffle = np.random.permutation(N)
train_n = int(0.9*N)
digits_train, digits_test = np.split(digits[shuffle], [train_n])
labels_train, labels_test = np.split(labels[shuffle], [train_n])
# train model
model = cv2.ANN_MLP()
# Layer sizes: SZ*SZ, 25, CLASS_N (presumably input / hidden / output).
layer_sizes = np.int32([SZ*SZ, 25, CLASS_N])
model.create(layer_sizes)
params = dict( term_crit = (cv2.TERM_CRITERIA_COUNT, 100, 0.01),
               train_method = cv2.ANN_MLP_TRAIN_PARAMS_BACKPROP,
               bp_dw_scale = 0.001,
               bp_moment_scale = 0.0 )
print 'training...'
# The network is trained against one-hot rows, one column per class.
labels_train_unrolled = unroll_responses(labels_train, CLASS_N)
model.train(digits_train, labels_train_unrolled, None, params=params)
# Write the trained network to 'dig_nn.dat' and read it straight back.
model.save('dig_nn.dat')
model.load('dig_nn.dat')
def evaluate(model, samples, labels):
    '''Evaluate classifier performance on a labeled sample set.

    model   -- object whose predict(samples) returns a (retval, outputs) pair
    samples -- data handed to model.predict unchanged
    labels  -- expected class index for every sample

    The predicted class of a sample is the position of the largest value
    along the last axis of the model outputs.

    Returns (accuracy, mask), where mask is True for every sample that was
    classified correctly and accuracy is the mean of that mask.
    '''
    _, outputs = model.predict(samples)
    predicted = outputs.argmax(-1)
    hits = (predicted == labels)
    return hits.mean(), hits
# evaluate model on the data it was trained on and on the held-out part
train_accuracy, _ = evaluate(model, digits_train, labels_train)
print 'train accuracy: ', train_accuracy
test_accuracy, test_error_mask = evaluate(model, digits_test, labels_test)
print 'test accuracy: ', test_accuracy
# visualize test results
# NOTE: despite its name, test_error_mask is True for CORRECT predictions
# (evaluate() builds it as resp == labels).
vis = []
for img, flag in zip(digits_test, test_error_mask):
    # turn the flat float row back into an SZ x SZ 8-bit image
    img = np.uint8(img).reshape(SZ, SZ)
    img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    if not flag:
        # misclassified: zero the first two (B, G) channels, leaving only red
        img[...,:2] = 0
    vis.append(img)
# mosaic() comes from the samples' `common` module; 25 is presumably the
# number of tiles per row -- TODO confirm.
vis = mosaic(25, vis)
cv2.imshow('test', vis)
cv2.waitKey()
def load_digits(fn):
    '''Load a sheet of digit images and generate their labels.

    fn -- path of the image file; it is read with cv2.imread(fn, 0)

    The sheet is cut into SZ x SZ cells. Labels are generated, not read:
    the cells are assumed to be grouped by class in ascending order with
    the same number of cells per class.

    Returns (digits, labels): digits has shape (count, SZ, SZ), labels holds
    one integer in 0 .. CLASS_N-1 per digit.

    Raises IOError when the image cannot be read.
    '''
    print('loading "%s" ...' % fn)
    digits_img = cv2.imread(fn, 0)
    if digits_img is None:
        # imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than on `.shape` below.
        raise IOError('cannot read image "%s"' % fn)
    h, w = digits_img.shape
    # `//` keeps the section counts integral on Python 2 and Python 3 alike.
    rows = np.vsplit(digits_img, h // SZ)
    cells = [np.hsplit(row, w // SZ) for row in rows]
    digits = np.array(cells).reshape(-1, SZ, SZ)
    labels = np.repeat(np.arange(CLASS_N), len(digits) // CLASS_N)
    return digits, labels
def deskew(img):
    '''Straighten a slanted digit image using its image moments.

    img -- single SZ x SZ digit image

    The skew is estimated as mu11 / mu02 and undone with a horizontal shear.
    When mu02 is (almost) zero the ratio is meaningless, so an unchanged
    copy of the input is returned instead.

    Returns a new SZ x SZ image; the input is never modified.
    '''
    moments = cv2.moments(img)
    mu02 = moments['mu02']
    if abs(mu02) < 1e-2:
        return img.copy()
    skew = moments['mu11'] / mu02
    # Shear in x proportional to y; the third column shifts x by
    # -0.5*SZ*skew to compensate for the shear.
    shear = np.float32([[1, skew, -0.5*SZ*skew],
                        [0, 1, 0]])
    warp_flags = cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR
    return cv2.warpAffine(img, shear, (SZ, SZ), flags=warp_flags)
class StatModel(object):
    '''Common base for the classifier wrappers.

    Subclasses store the wrapped model in self.model; this class only
    forwards persistence calls to it.
    '''

    def load(self, fn):
        '''Ask the wrapped model to load its state from file fn.'''
        wrapped = self.model
        wrapped.load(fn)

    def save(self, fn):
        '''Ask the wrapped model to save its state to file fn.'''
        wrapped = self.model
        wrapped.save(fn)
class KNearest(StatModel):
    '''Classifier wrapper around cv2.KNearest.'''

    def __init__(self, k = 3):
        '''Remember k, the neighbour count passed to find_nearest.'''
        self.k = k
        self.model = cv2.KNearest()

    def train(self, samples, responses):
        '''Replace the wrapped model with a fresh one and train it.'''
        self.model = cv2.KNearest()
        self.model.train(samples, responses)

    def predict(self, samples):
        '''Return the find_nearest results for samples as a flat array.'''
        # find_nearest returns four values; only the results are used here.
        _, results, _, _ = self.model.find_nearest(samples, self.k)
        return results.ravel()
class SVM(StatModel):
    '''Classifier wrapper around cv2.SVM (C-SVC with an RBF kernel).'''

    def __init__(self, C = 1, gamma = 0.5):
        '''Store the training parameters built from C and gamma.'''
        self.params = {
            'kernel_type': cv2.SVM_RBF,
            'svm_type': cv2.SVM_C_SVC,
            'C': C,
            'gamma': gamma,
        }
        self.model = cv2.SVM()

    def train(self, samples, responses):
        '''Replace the wrapped model with a fresh one and train it.'''
        self.model = cv2.SVM()
        self.model.train(samples, responses, params=self.params)

    def predict(self, samples):
        '''Return the predict_all output for samples as a flat array.'''
        predictions = self.model.predict_all(samples)
        return predictions.ravel()
def evaluate_model(model, digits, samples, labels):
    '''Report a model's error on a labeled set and build a result image.

    model   -- object whose predict(samples) returns one class per sample
    digits  -- the digit images matching samples, used only for display
    samples -- feature rows handed to model.predict
    labels  -- expected class for every sample

    Prints the error rate and the confusion matrix (rows: expected class,
    columns: predicted class).

    Returns the mosaic of all digit images in which misclassified digits
    keep only their red channel.
    '''
    resp = model.predict(samples)
    err = (labels != resp).mean()
    print('error: %.2f %%' % (err*100))

    # Sized by CLASS_N instead of a literal 10 so it follows the label range.
    confusion = np.zeros((CLASS_N, CLASS_N), np.int32)
    for i, j in zip(labels, resp):
        # Predictions may arrive as floats; array indices must be integers.
        confusion[int(i), int(j)] += 1
    print('confusion matrix:')
    print(confusion)
    print('')

    vis = []
    for img, flag in zip(digits, resp == labels):
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        if not flag:
            # misclassified: zero the first two (B, G) channels
            img[...,:2] = 0
        vis.append(img)
    return mosaic(25, vis)
if __name__ == '__main__':
    print __doc__
    digits, labels = load_digits('digits.png')
    print 'preprocessing...'
    # shuffle digits; the fixed seed makes the train/test split repeatable
    rand = np.random.RandomState(12345)
    shuffle = rand.permutation(len(digits))
    digits, labels = digits[shuffle], labels[shuffle]
    # deskew every digit, then flatten each one into a row of SZ*SZ floats
    # divided by 255.0
    digits2 = map(deskew, digits)
    samples = np.float32(digits2).reshape(-1, SZ*SZ) / 255.0
    # first 90% of the shuffled data is used for training, the rest for testing
    train_n = int(0.9*len(samples))
    # show the test digits as loaded (before deskewing)
    cv2.imshow('test set', mosaic(25, digits[train_n:]))
    # digits_test holds the deskewed images; evaluate_model uses them for display
    digits_train, digits_test = np.split(digits2, [train_n])
    samples_train, samples_test = np.split(samples, [train_n])
    labels_train, labels_test = np.split(labels, [train_n])
    print 'training KNearest...'
    model = KNearest(k=1)
    model.train(samples_train, labels_train)
    vis = evaluate_model(model, digits_test, samples_test, labels_test)
    cv2.imshow('KNearest test', vis)
    print 'training SVM...'
    model = SVM(C=4.66, gamma=0.08)
    model.train(samples_train, labels_train)
    vis = evaluate_model(model, digits_test, samples_test, labels_test)
    cv2.imshow('SVM test', vis)
    # wait for a key press so the result windows stay visible
    cv2.waitKey(0)
import numpy as np
import cv2
from multiprocessing.pool import ThreadPool
SZ = 20 # size of each digit is SZ x SZ
CLASS_N = 10 # number of label values generated by load_base (0 .. CLASS_N-1)
def load_base(fn):
    '''Load a sheet of digit images as flat, scaled feature rows.

    fn -- path of the image file; it is read with cv2.imread(fn, 0)

    The sheet is cut into SZ x SZ cells. Labels are generated, not read:
    the cells are assumed to be grouped by class in ascending order with
    the same number of cells per class.

    Returns (digits, labels): digits is a float32 array of shape
    (count, SZ*SZ) with every pixel divided by 255.0, labels holds one
    integer in 0 .. CLASS_N-1 per row.

    Raises IOError when the image cannot be read.
    '''
    print('loading "%s" ...' % fn)
    digits_img = cv2.imread(fn, 0)
    if digits_img is None:
        # imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than on `.shape` below.
        raise IOError('cannot read image "%s"' % fn)
    h, w = digits_img.shape
    # `//` keeps the section counts integral on Python 2 and Python 3 alike.
    rows = np.vsplit(digits_img, h // SZ)
    cells = [np.hsplit(row, w // SZ) for row in rows]
    # One reshape is enough: each cell becomes a single row of SZ*SZ values.
    digits = np.float32(cells).reshape(-1, SZ*SZ) / 255.0
    labels = np.repeat(np.arange(CLASS_N), len(digits) // CLASS_N)
    return digits, labels
def cross_validate(model_class, params, samples, labels, kfold = 4, pool = None):
    '''Estimate a model's error rate with k-fold cross-validation.

    model_class -- class instantiated as model_class(**params) for every
                   fold; instances must provide train(samples, labels) and
                   predict(samples)
    params      -- keyword arguments for model_class
    samples     -- array of samples, indexable with an index array
    labels      -- array of labels aligned with samples
    kfold       -- number of folds; the sample indices are split into
                   kfold consecutive chunks with np.array_split
    pool        -- optional object with a map(func, iterable) method used
                   to evaluate the folds (e.g. a ThreadPool); the folds run
                   sequentially when it is None

    Writes one progress dot per finished fold to stdout.

    Returns the mean, over all folds, of the fraction of misclassified
    test samples.
    '''
    import sys  # local import: only the progress marker needs it

    n = len(samples)
    folds = np.array_split(np.arange(n), kfold)

    def f(i):
        model = model_class(**params)
        test_idx = folds[i]
        # train on every fold except the i-th
        train_idx = np.hstack(folds[:i] + folds[i+1:])
        train_samples, train_labels = samples[train_idx], labels[train_idx]
        test_samples, test_labels = samples[test_idx], labels[test_idx]
        model.train(train_samples, train_labels)
        resp = model.predict(test_samples)
        score = (resp != test_labels).mean()
        # Same bytes as the Python 2 statement `print ".",` followed by
        # further prints, without needing the print statement.
        sys.stdout.write('. ')
        return score

    fold_ids = range(kfold)
    if pool is None:
        # A real list (not a lazy map object) so np.mean works on Python 3 too.
        scores = [f(i) for i in fold_ids]
    else:
        scores = pool.map(f, fold_ids)
    return np.mean(scores)
class StatModel(object):
    '''Shared base of the classifier wrappers.

    The wrapped model is expected in self.model (set by subclasses);
    persistence requests are passed straight through to it.
    '''

    def load(self, fn):
        '''Load the wrapped model's state from file fn.'''
        backend = self.model
        backend.load(fn)

    def save(self, fn):
        '''Save the wrapped model's state to file fn.'''
        backend = self.model
        backend.save(fn)
class KNearest(StatModel):
    '''Classifier wrapper around cv2.KNearest with a search for k.'''

    def __init__(self, k = 3):
        '''Remember k, the neighbour count passed to find_nearest.'''
        self.k = k
        # Create the wrapped model up front so the inherited load()/save()
        # have an object to work on even before train() is called.
        self.model = cv2.KNearest()

    @staticmethod
    def adjust(samples, labels):
        '''Pick the k in 1..10 with the lowest cross-validated error.

        Prints the error for every k and the winning parameters.

        Returns a dict usable as KNearest(**result).
        '''
        print('adjusting KNearest ...')
        best_err, best_k = np.inf, -1
        for k in range(1, 11):
            err = cross_validate(KNearest, dict(k=k), samples, labels)
            # strict '<' keeps the smallest k among equally good candidates
            if err < best_err:
                best_err, best_k = err, k
            print('k = %d, error: %.2f %%' % (k, err*100))
        best_params = dict(k=best_k)
        print('best params: %s' % (best_params,))
        return best_params

    def train(self, samples, responses):
        '''Replace the wrapped model with a fresh one and train it.'''
        self.model = cv2.KNearest()
        self.model.train(samples, responses)

    def predict(self, samples):
        '''Return the find_nearest results for samples as a flat array.'''
        retval, results, neigh_resp, dists = self.model.find_nearest(samples, self.k)
        return results.ravel()
class SVM(StatModel):
    '''Classifier wrapper around cv2.SVM (C-SVC, RBF kernel) with a grid
    search over C and gamma.'''

    def __init__(self, C = 1, gamma = 0.5):
        '''Store the training parameters built from C and gamma.'''
        self.params = dict( kernel_type = cv2.SVM_RBF,
                            svm_type = cv2.SVM_C_SVC,
                            C = C,
                            gamma = gamma )
        # Create the wrapped model up front so the inherited load()/save()
        # have an object to work on even before train() is called.
        self.model = cv2.SVM()

    @staticmethod
    def adjust(samples, labels):
        '''Grid-search C and gamma by cross-validated error.

        Tries 10 values of C (2**0 .. 2**5) against 10 values of gamma
        (2**-7 .. 2**-2), both log-spaced, evaluating the grid cells on a
        thread pool. Prints progress, the score grid and the best result.

        Returns a dict usable as SVM(**result).
        '''
        Cs = np.logspace(0, 5, 10, base=2)
        gammas = np.logspace(-7, -2, 10, base=2)
        # NaN marks cells that have not been evaluated yet.
        scores = np.zeros((len(Cs), len(gammas)))
        scores[:] = np.nan

        print('adjusting SVM (may take a long time) ...')
        def f(job):
            i, j = job
            params = dict(C = Cs[i], gamma=gammas[j])
            score = cross_validate(SVM, params, samples, labels)
            scores[i, j] = score
            nready = np.isfinite(scores).sum()
            print('%d / %d (best error: %.2f %%, last: %.2f %%)' % (nready, scores.size, np.nanmin(scores)*100, score*100))

        pool = ThreadPool(processes=cv2.getNumberOfCPUs())
        try:
            pool.map(f, np.ndindex(*scores.shape))
        finally:
            # Release the worker threads even when a job raised.
            pool.close()
            pool.join()
        print(scores)

        i, j = np.unravel_index(scores.argmin(), scores.shape)
        best_params = dict(C = Cs[i], gamma=gammas[j])
        print('best params: %s' % (best_params,))
        print('best error: %.2f %%' % (scores.min()*100))
        return best_params

    def train(self, samples, responses):
        '''Replace the wrapped model with a fresh one and train it.'''
        self.model = cv2.SVM()
        self.model.train(samples, responses, params = self.params)

    def predict(self, samples):
        '''Return the predict_all output for samples as a flat array.'''
        return self.model.predict_all(samples).ravel()
def main_adjustSVM(samples, labels):
    '''Search for the best SVM parameters, then train and save a model.

    samples -- feature rows
    labels  -- class of every sample

    The final model is trained on all samples with the parameters found
    by SVM.adjust and written to "digits_svm.dat".
    '''
    params = SVM.adjust(samples, labels)
    print('training SVM on all samples ...')
    # Was `SVN(**params)`: a misspelled class name that raised NameError
    # right after the (long) parameter search had finished.
    model = SVM(**params)
    model.train(samples, labels)
    print('saving "digits_svm.dat" ...')
    model.save('digits_svm.dat')
def main_adjustKNearest(samples, labels):
    '''Run the KNearest parameter search on the given data.

    The parameters found are not used further here; nothing is returned.
    '''
    KNearest.adjust(samples, labels)
def main_showSVM(samples, labels):
    '''Train an SVM on 90% of the data and show how it does on the rest.

    samples -- feature rows, SZ*SZ values each; multiplied by 255 for display
    labels  -- class of every sample

    Prints train and test error rates, then displays the test digits in a
    window where misclassified digits keep only their red channel, and
    waits for a key press.
    '''
    from common import mosaic

    # Draw the permutation here instead of reading a module-level `shuffle`:
    # that name only exists when the file runs as a script, and a local
    # permutation always matches the length of the data passed in.
    order = np.random.permutation(len(samples))
    train_n = int(0.9*len(samples))
    digits_train, digits_test = np.split(samples[order], [train_n])
    labels_train, labels_test = np.split(labels[order], [train_n])

    print('training SVM ...')
    model = SVM(C=2.16, gamma=0.0536)
    model.train(digits_train, labels_train)
    train_err = (model.predict(digits_train) != labels_train).mean()
    resp_test = model.predict(digits_test)
    test_err = (resp_test != labels_test).mean()
    print('train errors: %.2f %%' % (train_err*100))
    print('test errors: %.2f %%' % (test_err*100))

    # visualize test results
    vis = []
    for img, flag in zip(digits_test, resp_test == labels_test):
        # scale the flat row back to 8-bit and restore the SZ x SZ shape
        img = np.uint8(img*255).reshape(SZ, SZ)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        if not flag:
            # misclassified: zero the first two (B, G) channels
            img[...,:2] = 0
        vis.append(img)
    vis = mosaic(25, vis)
    cv2.imshow('test', vis)
    cv2.waitKey()
if __name__ == '__main__':
    samples, labels = load_base('digits.png')
    # Shuffle samples and labels with one shared permutation so they stay aligned.
    shuffle = np.random.permutation(len(samples))
    samples, labels = samples[shuffle], labels[shuffle]
    # The two parameter searches are disabled; only the SVM demo runs.
    #main_adjustSVM(samples, labels)
    #main_adjustKNearest(samples, labels)
    main_showSVM(samples, labels)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment