#!/usr/bin/env python

'''
MOSSE tracking sample

This sample implements correlation-based tracking approach, described in [1].

Usage:
  mosse.py [--pause] [<video source>]

  --pause  -  Start with playback paused at the first video frame.
              Useful for tracking target selection.

  Draw rectangles around objects with a mouse to track them.

Keys:
  SPACE    - pause video
  c        - clear targets

[1] David S. Bolme et al. "Visual Object Tracking using Adaptive Correlation Filters"
    http://www.cs.colostate.edu/~draper/papers/bolme_cvpr10.pdf
'''

24 25 26 27 28 29 30 31
# Python 2/3 compatibility
from __future__ import print_function
import sys
PY3 = sys.version_info[0] == 3

if PY3:
    xrange = range

32
import numpy as np
33
import cv2 as cv
34 35 36 37 38 39 40 41 42 43 44 45 46
from common import draw_str, RectSelector
import video

def rnd_warp(a):
    '''Return a randomly perturbed affine warp of image `a`.

    A small random rotation plus random noise on the 2x2 linear part is
    applied about the image centre; the output has the same width and
    height as the input and uses reflected borders.
    '''
    height, width = a.shape[:2]
    coef = 0.2

    # The scalar draw must come before the 2x2 draw so that a seeded
    # np.random produces the same sequence of warps.
    angle = (np.random.rand()-0.5)*coef
    cos_a, sin_a = np.cos(angle), np.sin(angle)

    warp = np.zeros((2, 3))
    warp[:2, :2] = [[cos_a, -sin_a], [sin_a, cos_a]]
    warp[:2, :2] += (np.random.rand(2, 2) - 0.5)*coef

    # Choose the translation column so that the image centre maps to itself.
    center = (width/2, height/2)
    warp[:, 2] = center - np.dot(warp[:2, :2], center)

    return cv.warpAffine(a, warp, (width, height), borderMode = cv.BORDER_REFLECT)
48 49 50 51 52 53 54 55 56 57 58 59 60

def divSpec(A, B):
    '''Element-wise complex division A / B of two-channel spectra.

    Both inputs hold the real part in channel 0 and the imaginary part in
    channel 1 of their last axis. The quotient is returned in the same
    layout as a freshly allocated array.
    '''
    numer = A[..., 0] + 1j*A[..., 1]
    denom = B[..., 0] + 1j*B[..., 1]
    quot = numer / denom
    packed = np.dstack([quot.real, quot.imag])
    return packed.copy()

# Small constant added to standard deviations used as divisors (image
# normalisation and PSR computation) so that they never divide by zero.
eps = 1e-5

class MOSSE:
    '''Correlation-filter tracker for a single rectangular target.

    The filter is kept in the frequency domain as a numerator `H1` and a
    denominator `H2`; `H` is derived from them by `update_kernel()`.
    Images passed in are expected to be single-channel (the visualisation
    unpacks a 2-D shape and the demo application passes grayscale frames).

    Attributes set by the tracker:
      pos       - (x, y) centre of the tracked window, in frame coordinates
      size      - (w, h) of the tracked window (DFT-optimal sizes)
      good      - True when the last PSR exceeded the acceptance threshold
      psr       - peak-to-sidelobe ratio of the last correlation response
      last_img  - last patch extracted from the frame
      last_resp - last correlation response
    '''

    def __init__(self, frame, rect):
        '''Initialise the filter from `frame` and `rect` = (x1, y1, x2, y2).'''
        x1, y1, x2, y2 = rect
        # Grow the window to sizes that are efficient for the DFT, keeping
        # it centred on the selected rectangle.
        w, h = map(cv.getOptimalDFTSize, [x2-x1, y2-y1])
        x1, y1 = (x1+x2-w)//2, (y1+y2-h)//2
        self.pos = x, y = x1+0.5*(w-1), y1+0.5*(h-1)
        self.size = w, h
        img = cv.getRectSubPix(frame, (w, h), (x, y))

        self.win = cv.createHanningWindow((w, h), cv.CV_32F)
        # Desired response: a Gaussian peak at the window centre, scaled
        # to a maximum of 1.
        g = np.zeros((h, w), np.float32)
        g[h//2, w//2] = 1
        g = cv.GaussianBlur(g, (-1, -1), 2.0)
        g /= g.max()

        self.G = cv.dft(g, flags=cv.DFT_COMPLEX_OUTPUT)
        self.H1 = np.zeros_like(self.G)
        self.H2 = np.zeros_like(self.G)
        # Accumulate numerator/denominator over randomly warped copies of
        # the initial patch.
        for _i in xrange(128):
            a = self.preprocess(rnd_warp(img))
            A = cv.dft(a, flags=cv.DFT_COMPLEX_OUTPUT)
            self.H1 += cv.mulSpectrums(self.G, A, 0, conjB=True)
            self.H2 += cv.mulSpectrums(     A, A, 0, conjB=True)
        self.update_kernel()
        self.update(frame)

    def update(self, frame, rate = 0.125):
        '''Locate the target in `frame` and adapt the filter.

        `rate` is the weight given to the new frame in the running average
        of `H1` and `H2`. If the PSR of the response is not above 8.0 the
        tracker is marked as not good and neither the position nor the
        filter is changed.
        '''
        (x, y), (w, h) = self.pos, self.size
        self.last_img = img = cv.getRectSubPix(frame, (w, h), (x, y))
        img = self.preprocess(img)
        self.last_resp, (dx, dy), self.psr = self.correlate(img)
        self.good = self.psr > 8.0
        if not self.good:
            return

        # Re-extract the patch at the corrected position before learning.
        self.pos = x+dx, y+dy
        self.last_img = img = cv.getRectSubPix(frame, (w, h), self.pos)
        img = self.preprocess(img)

        A = cv.dft(img, flags=cv.DFT_COMPLEX_OUTPUT)
        H1 = cv.mulSpectrums(self.G, A, 0, conjB=True)
        H2 = cv.mulSpectrums(     A, A, 0, conjB=True)
        self.H1 = self.H1 * (1.0-rate) + H1 * rate
        self.H2 = self.H2 * (1.0-rate) + H2 * rate
        self.update_kernel()

    @property
    def state_vis(self):
        '''Return a side-by-side image of the last patch, the spatial filter
        kernel and the last response, the latter two scaled to uint8.'''
        f = cv.idft(self.H, flags=cv.DFT_SCALE | cv.DFT_REAL_OUTPUT )
        h, w = f.shape
        f = np.roll(f, -h//2, 0)
        f = np.roll(f, -w//2, 1)
        # Use the np.ptp() function: the ndarray.ptp() method was removed in
        # NumPy 2.0 and would raise AttributeError there.
        span = np.ptp(f)
        if span == 0:
            # Constant kernel: avoid 0/0 and render it as all zeros.
            span = 1.0
        kernel = np.uint8( (f-f.min()) / span*255 )
        resp = self.last_resp
        resp = np.uint8(np.clip(resp/resp.max(), 0, 1)*255)
        vis = np.hstack([self.last_img, kernel, resp])
        return vis

    def draw_state(self, vis):
        '''Draw the tracking window and PSR value onto image `vis` in place.

        A filled dot marks the centre while tracking is good; otherwise the
        window is crossed out.
        '''
        (x, y), (w, h) = self.pos, self.size
        x1, y1, x2, y2 = int(x-0.5*w), int(y-0.5*h), int(x+0.5*w), int(y+0.5*h)
        cv.rectangle(vis, (x1, y1), (x2, y2), (0, 0, 255))
        if self.good:
            cv.circle(vis, (int(x), int(y)), 2, (0, 0, 255), -1)
        else:
            cv.line(vis, (x1, y1), (x2, y2), (0, 0, 255))
            cv.line(vis, (x2, y1), (x1, y2), (0, 0, 255))
        draw_str(vis, (x1, y2+16), 'PSR: %.2f' % self.psr)

    def preprocess(self, img):
        '''Log-transform `img`, normalise it to zero mean and (near) unit
        standard deviation, and apply the Hanning window.'''
        img = np.log(np.float32(img)+1.0)
        img = (img-img.mean()) / (img.std()+eps)
        return img*self.win

    def correlate(self, img):
        '''Correlate a preprocessed patch with the current filter.

        Returns `(resp, (dx, dy), psr)`: the response map, the offset of its
        maximum from the map centre, and the peak-to-sidelobe ratio.
        '''
        C = cv.mulSpectrums(cv.dft(img, flags=cv.DFT_COMPLEX_OUTPUT), self.H, 0, conjB=True)
        resp = cv.idft(C, flags=cv.DFT_SCALE | cv.DFT_REAL_OUTPUT)
        h, w = resp.shape
        _, mval, _, (mx, my) = cv.minMaxLoc(resp)
        # Zero out the area around the peak; the statistics below are taken
        # over the whole map including that zeroed area.
        side_resp = resp.copy()
        cv.rectangle(side_resp, (mx-5, my-5), (mx+5, my+5), 0, -1)
        smean, sstd = side_resp.mean(), side_resp.std()
        psr = (mval-smean) / (sstd+eps)
        return resp, (mx-w//2, my-h//2), psr

    def update_kernel(self):
        '''Recompute `H` as `H1 / H2` with its imaginary part negated.'''
        self.H = divSpec(self.H1, self.H2)
        self.H[...,1] *= -1

class App:
    '''Interactive demo: rectangles drawn in the 'frame' window each start
    a new MOSSE tracker, which is then updated on every captured frame.'''

    def __init__(self, video_src, paused = False):
        '''Open `video_src`, show its first frame and install the selector.'''
        self.cap = video.create_capture(video_src)
        _ret, first_frame = self.cap.read()
        self.frame = first_frame
        cv.imshow('frame', self.frame)
        self.rect_sel = RectSelector('frame', self.onrect)
        self.trackers = []
        self.paused = paused

    def onrect(self, rect):
        '''Selection callback: start tracking `rect` on the current frame.'''
        gray = cv.cvtColor(self.frame, cv.COLOR_BGR2GRAY)
        self.trackers.append(MOSSE(gray, rect))

    def run(self):
        '''Main loop; returns when the capture ends or ESC is pressed.'''
        while True:
            if not self.paused:
                grabbed, self.frame = self.cap.read()
                if not grabbed:
                    break
                gray = cv.cvtColor(self.frame, cv.COLOR_BGR2GRAY)
                for tracker in self.trackers:
                    tracker.update(gray)

            canvas = self.frame.copy()
            for tracker in self.trackers:
                tracker.draw_state(canvas)
            if self.trackers:
                # Only the most recently added tracker is visualised.
                newest = self.trackers[-1]
                cv.imshow('tracker state', newest.state_vis)
            self.rect_sel.draw(canvas)

            cv.imshow('frame', canvas)
            key = cv.waitKey(10)
            if key == 27:
                break
            elif key == ord(' '):
                self.paused = not self.paused
            elif key == ord('c'):
                self.trackers = []

187

188
if __name__ == '__main__':
    # Print the usage text from the module docstring.
    print(__doc__)
    import getopt
    import sys

    opts, args = getopt.getopt(sys.argv[1:], '', ['pause'])
    opts = dict(opts)
    # Fall back to source '0' only when no positional argument was given.
    # A bare `except:` here would also swallow KeyboardInterrupt/SystemExit.
    try:
        video_src = args[0]
    except IndexError:
        video_src = '0'

    App(video_src, paused = '--pause' in opts).run()