// gpc_train.cpp
#include "opencv2/optflow.hpp"
#include <iostream>

/* This tool trains the forest for the Global Patch Collider and stores the result
 * in "forest.yml.gz" (or in the path given with -f/--forest).
 */

using namespace cv;

// Option table consumed by cv::CommandLineParser.
// Every entry has the form "{<names> | <default value> | <help text>}";
// space-separated names (e.g. "f forest") are aliases of the same option.
const String keys =
    "{help h ?       |             | print this message}"
    "{max-tree-depth |             | Maximum tree depth to stop partitioning}"
    "{min-samples    |             | Minimum number of samples in the node to stop partitioning}"
    "{descriptor-type|0            | Descriptor type. Set to 0 for quality, 1 for speed.}"
    "{print-progress |             | Set to 0 to enable quiet mode, set to 1 to print progress}"
    "{f forest       |forest.yml.gz| Path where to store resulting forest. It is recommended to use .yml.gz extension.}";

// Number of trees in the trained forest. GPCForest receives it as a template
// argument, so it must be a compile-time constant.
static const int nTrees = 5;

18 19
static void fillInputImagesFromCommandLine( std::vector< String > &img1, std::vector< String > &img2, std::vector< String > &gt, int argc,
                                            const char **argv )
20
{
21
  for ( int i = 1, j = 0; i < argc; ++i )
22
  {
23 24 25 26 27 28 29 30 31
    if ( argv[i][0] == '-' )
      continue;
    if ( j % 3 == 0 )
      img1.push_back( argv[i] );
    if ( j % 3 == 1 )
      img2.push_back( argv[i] );
    if ( j % 3 == 2 )
      gt.push_back( argv[i] );
    ++j;
32
  }
33
}
34

35 36 37 38 39 40 41
int main( int argc, const char **argv )
{
  CommandLineParser parser( argc, argv, keys );
  parser.about( "Global Patch Collider training tool" );

  std::vector< String > img1, img2, gt;
  optflow::GPCTrainingParams params;
42

43 44 45 46 47 48 49 50 51 52 53 54
  if ( parser.has( "max-tree-depth" ) )
    params.maxTreeDepth = parser.get< unsigned >( "max-tree-depth" );
  if ( parser.has( "min-samples" ) )
    params.minNumberOfSamples = parser.get< unsigned >( "min-samples" );
  if ( parser.has( "descriptor-type" ) )
    params.descriptorType = parser.get< int >( "descriptor-type" );
  if ( parser.has( "print-progress" ) )
    params.printProgress = parser.get< unsigned >( "print-progress" ) != 0;

  fillInputImagesFromCommandLine( img1, img2, gt, argc, argv );

  if ( parser.has( "help" ) || img1.size() != img2.size() || img1.size() != gt.size() || img1.size() == 0 )
55
  {
56 57 58
    std::cerr << "\nUsage: " << argv[0] << " [params] ImageFrom1 ImageTo1 GroundTruth1 ... ImageFromN ImageToN GroundTruthN\n" << std::endl;
    parser.printMessage();
    return 1;
59 60
  }

61 62 63
  Ptr< optflow::GPCForest< nTrees > > forest = optflow::GPCForest< nTrees >::create();
  forest->train( img1, img2, gt, params );
  forest->save( parser.get< String >( "forest" ) );
64 65 66

  return 0;
}