    [GSoC] Implementation of the Global Patch Collider and demo for PCAFlow (#752) · ac62d70f
    Vladislav Samsonov authored
    * Minor fixes
    
    * Start adding correspondence finding
    
    * Added finding of correspondences using GPC
    
    * New evaluation tool for GPC
    
    * Changed default parameters
    
    * Display ground truth in the evaluation tool
    
    * Added training tool for MPI Sintel dataset
    
    * Added the training tool for Middlebury dataset
    
    * Added some OpenCL optimization
    
    * Added explanatory notes
    
    * Minor improvements: time measurements and a small OpenCL optimization
    
    * Added demos
    
    * Fixed warnings
    
    * Make parameter struct assignable
    
    * Fix warning
    
    * Proper command line argument usage
    
    * Prettified training tool, added parameters
    
    * Fixed VS warning
    
    * Fixed VS warning
    
    * Use compressed forest.yml.gz files by default to save space
    
    * Added OpenCL flag to the evaluation tool
    
    * Updated documentation
    
    * Major speed and memory improvements:
    1) Added new (optional) type of patch descriptors which are much faster. Retraining with option --descriptor-type=1 is required.
    2) Removed the hash table for descriptors, reducing memory usage.
    
    * Fixed various floating-point precision errors.
    SIMD for the dot product; forest traversal is slightly faster now.
    
    * Tolerant floating point comparison
    
    * Triplets
    
    * Added comment
    
    * Choosing negative sample among nearest neighbors
    
    * Fix warning
    
    * Use parallel_for_() in critical places. Performance improvements.
    
    * Simulated annealing heuristic
    
    * Moved OpenCL kernel to separate file
    
    * Moved implementation to source file
    
    * Added basic accuracy tests for GPC and PCAFlow
    
    * Fixing warnings
    
    * Test accuracy constraints were too strict
    
    * Test accuracy constraints were too strict
    
    * Make tests more lightweight
gpc_train_sintel.py
import argparse
import glob
import os
import subprocess

# Stride between training samples: pairs start at frames 0, FRAME_DIST, 2 * FRAME_DIST, ...
FRAME_DIST = 2

assert FRAME_DIST >= 1


def execute(cmd):
    # Run the command and echo its output line by line.
    # stderr is merged into stdout so that draining one pipe cannot deadlock
    # on a full stderr buffer; universal_newlines=True yields str lines, so
    # the '' sentinel correctly detects EOF under Python 3 as well.
    popen = subprocess.Popen(cmd,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.STDOUT,
                             universal_newlines=True)
    for line in iter(popen.stdout.readline, ''):
        print(line.rstrip())
    popen.stdout.close()
    return_code = popen.wait()
    if return_code != 0:
        raise subprocess.CalledProcessError(return_code, cmd)


def main():
    parser = argparse.ArgumentParser(
        description='Train Global Patch Collider using MPI Sintel dataset')
    parser.add_argument(
        '--bin_path',
        help='Path to the training executable (example_optflow_gpc_train)',
        required=True)
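    parser.add_argument('--dataset_path',
                        help='Path to the directory containing one subdirectory '
                             'of frames per sequence (e.g. training/clean)',
                        required=True)
    parser.add_argument('--gt_path',
                        help='Path to the directory with ground truth flow, '
                             'one subdirectory per sequence (e.g. training/flow)',
                        required=True)
    parser.add_argument('--descriptor_type',
                        help='Patch descriptor type passed to the trainer as '
                             '--descriptor-type (0 by default; 1 selects the '
                             'faster descriptors)',
                        type=int,
                        default=0)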
    args = parser.parse_args()
    seq = glob.glob(os.path.join(args.dataset_path, '*'))
    seq.sort()
    input_files = []
    for s in seq:
        seq_name = os.path.basename(s)
        frames = glob.glob(os.path.join(s, 'frame*.png'))
        frames.sort()
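        # Each sample is (frame i, frame i + 1, ground truth flow from i to i + 1);
        # a new sample starts every FRAME_DIST frames.
        for i in range(0, len(frames) - 1, FRAME_DIST):
            frame_stem = os.path.splitext(os.path.basename(frames[i]))[0]
            gt_flow = os.path.join(args.gt_path, seq_name, frame_stem + '.flo')
            if not os.path.isfile(gt_flow):
                raise IOError('Ground truth flow not found: %s' % gt_flow)
            input_files += [frames[i], frames[i + 1], gt_flow]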
    execute([args.bin_path, '--descriptor-type=%d' % args.descriptor_type] + input_files)


if __name__ == '__main__':
    main()
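The sample-selection loop in `main()` can be checked in isolation. The sketch below factors it into a hypothetical pure function, `build_training_args` (not part of the script), that takes a list of frame paths instead of globbing the filesystem and skips the existence check. It assumes the layout the script expects, where `<gt_path>/<sequence>/frame_XXXX.flo` holds the flow from `frame_XXXX` to the next frame.

```python
import os

FRAME_DIST = 2  # same meaning as in the script: stride between training samples


def build_training_args(frames, gt_path, seq_name, frame_dist=FRAME_DIST):
    """Return the flat [frame_a, frame_b, flow, ...] argument list for one sequence.

    Hypothetical helper mirroring the loop in main(); it never touches the
    filesystem, so it can be run on made-up paths.
    """
    frames = sorted(frames)
    args = []
    # Samples start at frames 0, frame_dist, 2 * frame_dist, ...; each one is
    # (frame i, frame i + 1) plus the ground truth flow named after frame i.
    for i in range(0, len(frames) - 1, frame_dist):
        stem = os.path.splitext(os.path.basename(frames[i]))[0]
        gt_flow = os.path.join(gt_path, seq_name, stem + '.flo')
        args += [frames[i], frames[i + 1], gt_flow]
    return args


if __name__ == '__main__':
    frames = ['alley_1/frame_%04d.png' % n for n in range(1, 6)]
    args = build_training_args(frames, 'flow', 'alley_1')
    # Print one (frame_a, frame_b, flow) triplet per line.
    for j in range(0, len(args), 3):
        print(*args[j:j + 3])
```

With `FRAME_DIST = 2`, a five-frame sequence yields two samples, (frame_0001, frame_0002) and (frame_0003, frame_0004), and frame_0005 is never used. Setting the stride to 1 would use every consecutive pair, doubling the number of samples at the cost of more redundant training data.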
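The script only checks that each ground truth file exists. To also catch truncated or corrupt files before a long training run, the following is a minimal sketch of a validator for the Middlebury `.flo` format used by MPI Sintel for its ground truth flow. `check_flo` is a hypothetical helper, not part of OpenCV or of this script. It relies on the format's layout: a float32 tag equal to 202021.25 (the ASCII bytes `PIEH`), an int32 width and an int32 height, then width × height pairs of float32 (u, v) values, all little-endian.

```python
import os
import struct

FLO_TAG = 202021.25  # float32 value of the ASCII bytes 'PIEH'


def check_flo(path):
    """Return (width, height) of a .flo file, raising ValueError if it is malformed."""
    with open(path, 'rb') as f:
        header = f.read(12)
    if len(header) != 12:
        raise ValueError('%s: file too short for a .flo header' % path)
    # Little-endian: float32 tag, int32 width, int32 height.
    tag, width, height = struct.unpack('<fii', header)
    if tag != FLO_TAG:
        raise ValueError('%s: bad tag %r' % (path, tag))
    if width <= 0 or height <= 0:
        raise ValueError('%s: bad size %dx%d' % (path, width, height))
    # The header is followed by width * height (u, v) pairs, 8 bytes each.
    expected = 12 + width * height * 8
    actual = os.path.getsize(path)
    if actual != expected:
        raise ValueError('%s: expected %d bytes, found %d' % (path, expected, actual))
    return width, height
```

Calling `check_flo(gt_flow)` next to the existence check in `main()` would reject a damaged file before the trainer starts, instead of failing partway through training.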