Commit a718f2e6 authored by berak's avatar berak

ml/python: fix digits samples(3.4)

parent c2096771
...@@ -70,13 +70,8 @@ def deskew(img): ...@@ -70,13 +70,8 @@ def deskew(img):
img = cv.warpAffine(img, M, (SZ, SZ), flags=cv.WARP_INVERSE_MAP | cv.INTER_LINEAR) img = cv.warpAffine(img, M, (SZ, SZ), flags=cv.WARP_INVERSE_MAP | cv.INTER_LINEAR)
return img return img
class StatModel(object):
def load(self, fn):
self.model.load(fn) # Known bug: https://github.com/opencv/opencv/issues/4969
def save(self, fn):
self.model.save(fn)
class KNearest(StatModel): class KNearest(object):
def __init__(self, k = 3): def __init__(self, k = 3):
self.k = k self.k = k
self.model = cv.ml.KNearest_create() self.model = cv.ml.KNearest_create()
...@@ -88,7 +83,13 @@ class KNearest(StatModel): ...@@ -88,7 +83,13 @@ class KNearest(StatModel):
_retval, results, _neigh_resp, _dists = self.model.findNearest(samples, self.k) _retval, results, _neigh_resp, _dists = self.model.findNearest(samples, self.k)
return results.ravel() return results.ravel()
class SVM(StatModel): def load(self, fn):
self.model = cv.ml.KNearest_load(fn)
def save(self, fn):
self.model.save(fn)
class SVM(object):
def __init__(self, C = 1, gamma = 0.5): def __init__(self, C = 1, gamma = 0.5):
self.model = cv.ml.SVM_create() self.model = cv.ml.SVM_create()
self.model.setGamma(gamma) self.model.setGamma(gamma)
...@@ -102,6 +103,11 @@ class SVM(StatModel): ...@@ -102,6 +103,11 @@ class SVM(StatModel):
def predict(self, samples): def predict(self, samples):
return self.model.predict(samples)[1].ravel() return self.model.predict(samples)[1].ravel()
def load(self, fn):
self.model = cv.ml.SVM_load(fn)
def save(self, fn):
self.model.save(fn)
def evaluate_model(model, digits, samples, labels): def evaluate_model(model, digits, samples, labels):
resp = model.predict(samples) resp = model.predict(samples)
......
#!/usr/bin/env python #!/usr/bin/env python
'''
Digit recognition from video.
Run digits.py before, to train and save the SVM.
Usage:
digits_video.py [{camera_id|video_file}]
'''
# Python 2/3 compatibility # Python 2/3 compatibility
from __future__ import print_function from __future__ import print_function
...@@ -28,11 +36,7 @@ def main(): ...@@ -28,11 +36,7 @@ def main():
print('"%s" not found, run digits.py first' % classifier_fn) print('"%s" not found, run digits.py first' % classifier_fn)
return return
if True: model = cv.ml.SVM_load(classifier_fn)
model = cv.ml.SVM_load(classifier_fn)
else:
model = cv.ml.SVM_create()
model.load_(classifier_fn) #Known bug: https://github.com/opencv/opencv/issues/4969
while True: while True:
_ret, frame = cap.read() _ret, frame = cap.read()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment