// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
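/*********************************************************************
* This is the implementation of the paper
*    GMS: Grid-based Motion Statistics for Fast, Ultra-robust Feature Correspondence.
*    JiaWang Bian, Wen-Yan Lin, Yasuyuki Matsushita, Sai-Kit Yeung, Tan Dat Nguyen, Ming-Ming Cheng
*    IEEE CVPR, 2017
*    ProjectPage: http://jwbian.net/gms
*********************************************************************/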

#include "precomp.hpp"
#include <algorithm>

using namespace std;

namespace cv
namespace xfeatures2d
// 8 possible rotations of the 3 x 3 cell neighbourhood, one per row (indexed by rotationType - 1).
// Each row lists 1-based cell numbers in row-major order. Row r keeps the centre cell (5) fixed
// and cyclically shifts the 8 surrounding cells by r positions, so row 0 is the identity layout.
const int mRotationPatterns[8][9] = {
    { 1, 2, 3,
      4, 5, 6,
      7, 8, 9 },
    { 4, 1, 2,
      7, 5, 3,
      8, 9, 6 },
    { 7, 4, 1,
      8, 5, 2,
      9, 6, 3 },
    { 8, 7, 4,
      9, 5, 1,
      6, 3, 2 },
    { 9, 8, 7,
      6, 5, 4,
      3, 2, 1 },
    { 6, 9, 8,
      3, 5, 7,
      2, 1, 4 },
    { 3, 6, 9,
      2, 5, 8,
      1, 4, 7 },
    { 2, 3, 6,
      1, 5, 9,
      4, 7, 8 }
};

// 5 level scales: ratio of the right grid size to the left grid size (used by setScale)
const double mScaleRatios[5] = { 1.0, 1.0 / 2, 1.0 / std::sqrt(2.0), std::sqrt(2.0), 2.0 };

class GMSMatcher
    // OpenCV Keypoints & Correspond Image Size & Nearest Neighbor Matches
    GMSMatcher(const vector<KeyPoint>& vkp1, const Size& size1, const vector<KeyPoint>& vkp2, const Size& size2,
               const vector<DMatch>& vDMatches, const double thresholdFactor) : mThresholdFactor(thresholdFactor)
        // Input initialize
        normalizePoints(vkp1, size1, mvP1);
        normalizePoints(vkp2, size2, mvP2);
        mNumberMatches = vDMatches.size();
        convertMatches(vDMatches, mvMatches);

        // Grid initialize
        mGridSizeLeft = Size(20, 20);
        mGridNumberLeft = mGridSizeLeft.width * mGridSizeLeft.height;

        // Initialize the neighbor of left grid
        mGridNeighborLeft = Mat::zeros(mGridNumberLeft, 9, CV_32SC1);
        initalizeNeighbors(mGridNeighborLeft, mGridSizeLeft);

    ~GMSMatcher() {}

    // Get Inlier Mask
    // Return number of inliers
    int getInlierMask(vector<bool> &vbInliers, const bool withRotation = false, const bool withScale = false);

    // Normalized Points
    vector<Point2f> mvP1, mvP2;

    // Matches
    vector<pair<int, int> > mvMatches;

    // Number of Matches
    size_t mNumberMatches;

    // Grid Size
    Size mGridSizeLeft, mGridSizeRight;
    int mGridNumberLeft;
    int mGridNumberRight;

    // x      : left grid idx
    // y      : right grid idx
    // value  : how many matches from idx_left to idx_right
    Mat mMotionStatistics;

    vector<int> mNumberPointsInPerCellLeft;

    // Inldex  : grid_idx_left
    // Value   : grid_idx_right
    vector<int> mCellPairs;

    // Every Matches has a cell-pair
    // first  : grid_idx_left
    // second : grid_idx_right
    vector<pair<int, int> > mvMatchPairs;

    // Inlier Mask for output
    vector<bool> mvbInlierMask;

    Mat mGridNeighborLeft;
    Mat mGridNeighborRight;

    double mThresholdFactor;

    // Assign Matches to Cell Pairs
    void assignMatchPairs(const int GridType);

    void convertMatches(const vector<DMatch> &vDMatches, vector<pair<int, int> > &vMatches);

    int getGridIndexLeft(const Point2f &pt, const int type);

    int getGridIndexRight(const Point2f &pt);

    vector<int> getNB9(const int idx, const Size& GridSize);

    void initalizeNeighbors(Mat &neighbor, const Size& GridSize);

    void normalizePoints(const vector<KeyPoint> &kp, const Size &size, vector<Point2f> &npts);

    // Run
    int run(const int rotationType);

    void setScale(const int scale);

    // Verify Cell Pairs
    void verifyCellPairs(const int rotationType);

void GMSMatcher::assignMatchPairs(const int gridType)
    for (size_t i = 0; i < mNumberMatches; i++)
        Point2f &lp = mvP1[mvMatches[i].first];
        Point2f &rp = mvP2[mvMatches[i].second];

        int lgidx = mvMatchPairs[i].first = getGridIndexLeft(lp, gridType);
        int rgidx = -1;

        if (gridType == 1)
            rgidx = mvMatchPairs[i].second = getGridIndexRight(rp);
            rgidx = mvMatchPairs[i].second;

        if (lgidx < 0 || rgidx < 0) continue;

        mMotionStatistics.at<int>(lgidx, rgidx)++;

// Convert OpenCV DMatch to Match (pair<int, int>)
void GMSMatcher::convertMatches(const vector<DMatch> &vDMatches, vector<pair<int, int> > &vMatches)
    for (size_t i = 0; i < mNumberMatches; i++)
        vMatches[i] = pair<int, int>(vDMatches[i].queryIdx, vDMatches[i].trainIdx);

int GMSMatcher::getGridIndexLeft(const Point2f &pt, const int type)
    int x = 0, y = 0;

    if (type == 1) {
        x = cvFloor(pt.x * mGridSizeLeft.width);
        y = cvFloor(pt.y * mGridSizeLeft.height);

    if (type == 2) {
        x = cvFloor(pt.x * mGridSizeLeft.width + 0.5);
        y = cvFloor(pt.y * mGridSizeLeft.height);

    if (type == 3) {
        x = cvFloor(pt.x * mGridSizeLeft.width);
        y = cvFloor(pt.y * mGridSizeLeft.height + 0.5);

    if (type == 4) {
        x = cvFloor(pt.x * mGridSizeLeft.width + 0.5);
        y = cvFloor(pt.y * mGridSizeLeft.height + 0.5);

    if (x >= mGridSizeLeft.width || y >= mGridSizeLeft.height)
        return -1;

    return x + y * mGridSizeLeft.width;

int GMSMatcher::getGridIndexRight(const Point2f &pt)
    int x = cvFloor(pt.x * mGridSizeRight.width);
    int y = cvFloor(pt.y * mGridSizeRight.height);

    return x + y * mGridSizeRight.width;

// Runs GMS over the requested search space and returns the largest inlier count found.
// - Without rotation/scale, only rotation pattern 1 at scale index 0 is tried.
// - withScale tries all 5 scale ratios; withRotation tries all 8 rotation patterns.
// vbInliers receives the mask (one entry per input match) of the best configuration. It is
// all-false when no configuration produced an inlier.
int GMSMatcher::getInlierMask(vector<bool> &vbInliers, const bool withRotation, const bool withScale)
{
    // Sizes of mScaleRatios and mRotationPatterns
    const int numScales = withScale ? 5 : 1;
    const int numRotations = withRotation ? 8 : 1;

    int max_inlier = 0;
    vbInliers.assign(mNumberMatches, false);

    for (int scale = 0; scale < numScales; scale++)
    {
        // The right grid must match the scale before run() bins any match into it
        setScale(scale);

        for (int rotationType = 1; rotationType <= numRotations; rotationType++)
        {
            const int num_inlier = run(rotationType);

            // Strict '>' keeps the first configuration on ties
            if (num_inlier > max_inlier)
            {
                vbInliers = mvbInlierMask;
                max_inlier = num_inlier;
            }
        }
    }

    return max_inlier;
}

// Get Neighbor 9
// Returns the indices of the 3 x 3 neighbourhood around cell idx in a grid of gridSize cells.
// Indices are in row-major order: top-left first, idx itself at position 4.
// Neighbours outside the grid are reported as -1.
vector<int> GMSMatcher::getNB9(const int idx, const Size& gridSize)
{
    vector<int> NB9(9, -1);

    const int idx_x = idx % gridSize.width;
    const int idx_y = idx / gridSize.width;

    for (int yi = -1; yi <= 1; yi++)
    {
        for (int xi = -1; xi <= 1; xi++)
        {
            const int idx_xx = idx_x + xi;
            const int idx_yy = idx_y + yi;

            // Leave -1 for neighbours that fall outside the grid
            if (idx_xx < 0 || idx_xx >= gridSize.width || idx_yy < 0 || idx_yy >= gridSize.height)
                continue;

            // (yi + 1) * 3 + (xi + 1): row-major slot inside the 3 x 3 window
            NB9[xi + 4 + yi * 3] = idx_xx + idx_yy * gridSize.width;
        }
    }
    return NB9;
}

// Fills row i of neighbor with the 3 x 3 neighbourhood of cell i (see getNB9()).
// neighbor must be a CV_32SC1 matrix with one row per grid cell and 9 columns.
void GMSMatcher::initalizeNeighbors(Mat &neighbor, const Size& gridSize)
{
    CV_Assert(neighbor.type() == CV_32SC1 && neighbor.cols == 9);

    for (int i = 0; i < neighbor.rows; i++)
    {
        const vector<int> NB9 = getNB9(i, gridSize);
        std::copy(NB9.begin(), NB9.end(), neighbor.ptr<int>(i));
    }
}

// Normalize Key Points to Range(0 - 1)
void GMSMatcher::normalizePoints(const vector<KeyPoint> &kp, const Size &size, vector<Point2f> &npts)
    const size_t numP = kp.size();
    const int width   = size.width;
    const int height  = size.height;

    for (size_t i = 0; i < numP; i++)
        npts[i].x = kp[i].pt.x / width;
        npts[i].y = kp[i].pt.y / height;

int GMSMatcher::run(const int rotationType)
    mvbInlierMask.assign(mNumberMatches, false);

    // Initialize Motion Statisctics
    mMotionStatistics = Mat::zeros(mGridNumberLeft, mGridNumberRight, CV_32SC1);
    mvMatchPairs.assign(mNumberMatches, pair<int, int>(0, 0));

    for (int gridType = 1; gridType <= 4; gridType++)
        // initialize
        mCellPairs.assign(mGridNumberLeft, -1);
        mNumberPointsInPerCellLeft.assign(mGridNumberLeft, 0);


        // Mark inliers
        for (size_t i = 0; i < mNumberMatches; i++)
            if (mCellPairs[mvMatchPairs[i].first] == mvMatchPairs[i].second)
                mvbInlierMask[i] = true;

    return (int) count(mvbInlierMask.begin(), mvbInlierMask.end(), true); //number of inliers

void GMSMatcher::setScale(const int scale)
    // Set Scale
    mGridSizeRight.width = cvRound(mGridSizeLeft.width  * mScaleRatios[scale]);
    mGridSizeRight.height = cvRound(mGridSizeLeft.height * mScaleRatios[scale]);
    mGridNumberRight = mGridSizeRight.width * mGridSizeRight.height;

    // Initialize the neighbor of right grid
    mGridNeighborRight = Mat::zeros(mGridNumberRight, 9, CV_32SC1);
    initalizeNeighbors(mGridNeighborRight, mGridSizeRight);

void GMSMatcher::verifyCellPairs(const int rotationType)
    const int *CurrentRP = mRotationPatterns[rotationType - 1];

    for (int i = 0; i < mGridNumberLeft; i++)
        if (sum(mMotionStatistics.row(i))[0] == 0)
            mCellPairs[i] = -1;

        int max_number = 0;
        for (int j = 0; j < mGridNumberRight; j++)
            int *value = mMotionStatistics.ptr<int>(i);
            if (value[j] > max_number)
                mCellPairs[i] = j;
                max_number = value[j];

        int idx_grid_rt = mCellPairs[i];

        const int *NB9_lt = mGridNeighborLeft.ptr<int>(i);
        const int *NB9_rt = mGridNeighborRight.ptr<int>(idx_grid_rt);

        int score = 0;
        double thresh = 0;
        int numpair = 0;

        for (size_t j = 0; j < 9; j++)
            int ll = NB9_lt[j];
            int rr = NB9_rt[CurrentRP[j] - 1];
            if (ll == -1 || rr == -1)

            score += mMotionStatistics.at<int>(ll, rr);
            thresh += mNumberPointsInPerCellLeft[ll];

        thresh = mThresholdFactor * std::sqrt(thresh / numpair);

        if (score < thresh)
            mCellPairs[i] = -2;

void matchGMS( const Size& size1, const Size& size2, const vector<KeyPoint>& keypoints1, const vector<KeyPoint>& keypoints2,
               const vector<DMatch>& matches1to2, vector<DMatch>& matchesGMS, const bool withRotation, const bool withScale,
               const double thresholdFactor )
    GMSMatcher gms(keypoints1, size1, keypoints2, size2, matches1to2, thresholdFactor);
    vector<bool> inlierMask;
    gms.getInlierMask(inlierMask, withRotation, withScale);

    for (size_t i = 0; i < inlierMask.size(); i++) {
        if (inlierMask[i])

} //namespace xfeatures2d
} //namespace cv