Unverified Commit 7ea5029a authored by Vadim Pisarevsky's avatar Vadim Pisarevsky Committed by GitHub

Grabcut with frozen models (#11339)

* model is not learned when grabcut is called with GC_EVAL

* fixed test, was writing to wrong file.

* modified patch by Iwan Paolucci; added GC_EVAL_FREEZE_MODEL in addition to GC_EVAL (which semantics is retained)
parent 64a6b121
...@@ -386,7 +386,9 @@ enum GrabCutModes { ...@@ -386,7 +386,9 @@ enum GrabCutModes {
automatically initialized with GC_BGD .*/ automatically initialized with GC_BGD .*/
GC_INIT_WITH_MASK = 1, GC_INIT_WITH_MASK = 1,
/** The value means that the algorithm should just resume. */ /** The value means that the algorithm should just resume. */
GC_EVAL = 2 GC_EVAL = 2,
/** The value means that the algorithm should just run the grabCut algorithm (a single iteration) with the fixed model */
GC_EVAL_FREEZE_MODEL = 3
}; };
//! distanceTransform algorithm flags //! distanceTransform algorithm flags
......
...@@ -557,7 +557,10 @@ void cv::grabCut( InputArray _img, InputOutputArray _mask, Rect rect, ...@@ -557,7 +557,10 @@ void cv::grabCut( InputArray _img, InputOutputArray _mask, Rect rect,
if( iterCount <= 0) if( iterCount <= 0)
return; return;
if( mode == GC_EVAL ) if( mode == GC_EVAL_FREEZE_MODEL )
iterCount = 1;
if( mode == GC_EVAL || mode == GC_EVAL_FREEZE_MODEL )
checkMask( img, mask ); checkMask( img, mask );
const double gamma = 50; const double gamma = 50;
...@@ -571,7 +574,8 @@ void cv::grabCut( InputArray _img, InputOutputArray _mask, Rect rect, ...@@ -571,7 +574,8 @@ void cv::grabCut( InputArray _img, InputOutputArray _mask, Rect rect,
{ {
GCGraph<double> graph; GCGraph<double> graph;
assignGMMsComponents( img, mask, bgdGMM, fgdGMM, compIdxs ); assignGMMsComponents( img, mask, bgdGMM, fgdGMM, compIdxs );
learnGMMs( img, mask, compIdxs, bgdGMM, fgdGMM ); if( mode != GC_EVAL_FREEZE_MODEL )
learnGMMs( img, mask, compIdxs, bgdGMM, fgdGMM );
constructGCGraph(img, mask, bgdGMM, fgdGMM, lambda, leftW, upleftW, upW, uprightW, graph ); constructGCGraph(img, mask, bgdGMM, fgdGMM, lambda, leftW, upleftW, upW, uprightW, graph );
estimateSegmentation( graph, mask ); estimateSegmentation( graph, mask );
} }
......
...@@ -92,7 +92,9 @@ void CV_GrabcutTest::run( int /* start_from */) ...@@ -92,7 +92,9 @@ void CV_GrabcutTest::run( int /* start_from */)
mask = Scalar(0); mask = Scalar(0);
Mat bgdModel, fgdModel; Mat bgdModel, fgdModel;
grabCut( img, mask, rect, bgdModel, fgdModel, 0, GC_INIT_WITH_RECT ); grabCut( img, mask, rect, bgdModel, fgdModel, 0, GC_INIT_WITH_RECT );
grabCut( img, mask, rect, bgdModel, fgdModel, 2, GC_EVAL ); bgdModel.copyTo(exp_bgdModel);
fgdModel.copyTo(exp_fgdModel);
grabCut( img, mask, rect, bgdModel, fgdModel, 2, GC_EVAL_FREEZE_MODEL );
// Multiply images by 255 for more visuality of test data. // Multiply images by 255 for more visuality of test data.
if( mask_prob.empty() ) if( mask_prob.empty() )
...@@ -105,12 +107,20 @@ void CV_GrabcutTest::run( int /* start_from */) ...@@ -105,12 +107,20 @@ void CV_GrabcutTest::run( int /* start_from */)
exp_mask1 = (mask & 1) * 255; exp_mask1 = (mask & 1) * 255;
imwrite(string(ts->get_data_path()) + "grabcut/exp_mask1.png", exp_mask1); imwrite(string(ts->get_data_path()) + "grabcut/exp_mask1.png", exp_mask1);
} }
if (!verify((mask & 1) * 255, exp_mask1)) if (!verify((mask & 1) * 255, exp_mask1))
{ {
ts->set_failed_test_info(cvtest::TS::FAIL_MISMATCH); ts->set_failed_test_info(cvtest::TS::FAIL_MISMATCH);
return; return;
} }
// The model should not be changed after calling with GC_EVAL_FREEZE_MODEL
double sumBgdModel = cv::sum(cv::abs(bgdModel) - cv::abs(exp_bgdModel))[0];
double sumFgdModel = cv::sum(cv::abs(fgdModel) - cv::abs(exp_fgdModel))[0];
if (sumBgdModel >= 0.1 || sumFgdModel >= 0.1)
{
ts->printf(cvtest::TS::LOG, "sumBgdModel = %f, sumFgdModel = %f\n", sumBgdModel, sumFgdModel);
ts->set_failed_test_info(cvtest::TS::FAIL_MISMATCH);
return;
}
mask = mask_prob; mask = mask_prob;
bgdModel.release(); bgdModel.release();
...@@ -124,7 +134,6 @@ void CV_GrabcutTest::run( int /* start_from */) ...@@ -124,7 +134,6 @@ void CV_GrabcutTest::run( int /* start_from */)
exp_mask2 = (mask & 1) * 255; exp_mask2 = (mask & 1) * 255;
imwrite(string(ts->get_data_path()) + "grabcut/exp_mask2.png", exp_mask2); imwrite(string(ts->get_data_path()) + "grabcut/exp_mask2.png", exp_mask2);
} }
if (!verify((mask & 1) * 255, exp_mask2)) if (!verify((mask & 1) * 255, exp_mask2))
{ {
ts->set_failed_test_info(cvtest::TS::FAIL_MISMATCH); ts->set_failed_test_info(cvtest::TS::FAIL_MISMATCH);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment