#!/usr/bin/env python
"""
   Tracking of rotating point.
   Rotation speed is constant.
   The state vector is 2D (the point angle and its angular speed),
   the measurement vector is 1D (the point angle).
   Measurement is the real point angle + gaussian noise.
   The real and the estimated points are connected with yellow line segment,
   the real and the measured points are connected with red line segment.
   (if Kalman filter works correctly,
    the yellow segment should be shorter than the red one).
   Pressing any key (except ESC, 'q' or 'Q') will reset the tracking
   with a different speed.
   Pressing ESC, 'q' or 'Q' will stop the program.
"""
import sys
from math import cos, sin, sqrt

import cv2 as cv
import numpy as np

# Python 2/3 compatibility
PY3 = sys.version_info[0] == 3

if PY3:
    long = int

IMG_HEIGHT = 500
IMG_WIDTH = 500


def calc_point(angle, width=IMG_WIDTH, height=IMG_HEIGHT):
    """Return the pixel position of the point at ``angle`` on the demo circle.

    The circle is centred in a ``width`` x ``height`` image and has radius
    ``width / 3``.

    Args:
        angle: point angle in radians.
        width: image width in pixels.
        height: image height in pixels.

    Returns:
        ``(x, y)`` as a tuple of plain Python ints.
    """
    # Float literals keep true division on Python 2 as well.
    radius = width / 3.0
    # Round BOTH coordinates to the nearest pixel; rounding to one decimal
    # and casting to int would truncate instead.
    x = int(np.around(width / 2.0 + radius * cos(angle)))
    # Image rows grow downwards, hence the minus sign.
    y = int(np.around(height / 2.0 - radius * sin(angle)))
    return x, y


def draw_cross(img, center, color, d):
    """Draw an X-shaped marker on ``img``.

    Args:
        img: image to draw on (modified in place).
        center: ``(x, y)`` position of the marker centre.
        color: colour tuple passed to ``cv.line``.
        d: half-size of the marker in pixels.
    """
    x, y = int(center[0]), int(center[1])
    cv.line(img, (x - d, y - d), (x + d, y + d), color, 1, cv.LINE_AA, 0)
    cv.line(img, (x + d, y - d), (x - d, y + d), color, 1, cv.LINE_AA, 0)


def main():
    """Run the interactive demo until ESC, 'q' or 'Q' is pressed."""
    # 2 state variables (angle, angular speed), 1 measured variable (angle).
    kalman = cv.KalmanFilter(2, 1, 0)

    cv.namedWindow("Kalman")

    while True:
        # Simulated true state: random initial angle and angular speed.
        state = 0.1 * np.random.randn(2, 1)

        # Constant-speed model: angle += speed, speed unchanged.
        kalman.transitionMatrix = np.array([[1., 1.], [0., 1.]])
        # Only the angle is observed, i.e. H = [1, 0].
        kalman.measurementMatrix = np.eye(1, 2)
        kalman.processNoiseCov = 1e-5 * np.eye(2)
        kalman.measurementNoiseCov = 1e-1 * np.ones((1, 1))
        # Start from an uncorrelated, unit-variance uncertainty.
        kalman.errorCovPost = np.eye(2)
        kalman.statePost = 0.1 * np.random.randn(2, 1)

        while True:
            state_pt = calc_point(state[0, 0])

            prediction = kalman.predict()
            predict_pt = calc_point(prediction[0, 0])

            # generate measurement: true angle plus noise whose standard
            # deviation is the square root of the configured covariance
            # (same convention as the process noise below).
            measurement_std = sqrt(kalman.measurementNoiseCov[0, 0])
            measurement_noise = measurement_std * np.random.randn(1, 1)
            measurement = (np.dot(kalman.measurementMatrix, state)
                           + measurement_noise)
            measurement_pt = calc_point(measurement[0, 0])

            # plot points
            img = np.zeros((IMG_HEIGHT, IMG_WIDTH, 3), np.uint8)
            draw_cross(img, state_pt, (255, 255, 255), 3)
            draw_cross(img, measurement_pt, (0, 0, 255), 3)
            draw_cross(img, predict_pt, (0, 255, 0), 3)

            # red: real -> measured, yellow: real -> predicted
            cv.line(img, state_pt, measurement_pt, (0, 0, 255), 3,
                    cv.LINE_AA, 0)
            cv.line(img, state_pt, predict_pt, (0, 255, 255), 3,
                    cv.LINE_AA, 0)

            kalman.correct(measurement)

            # advance the simulated true state by one step
            process_std = sqrt(kalman.processNoiseCov[0, 0])
            process_noise = process_std * np.random.randn(2, 1)
            state = np.dot(kalman.transitionMatrix, state) + process_noise

            cv.imshow("Kalman", img)

            # any key press leaves the inner loop (restart or quit)
            code = cv.waitKey(100)
            if code != -1:
                break

        if code in [27, ord('q'), ord('Q')]:
            break

    cv.destroyWindow("Kalman")


if __name__ == "__main__":
    main()