Commit 7ee69740 authored by Alexander Alekhin's avatar Alexander Alekhin

ml(test): test different samples layout of TrainData

parent 828cb428
...@@ -721,5 +721,68 @@ void CV_MLBaseTest::load( const char* filename ) ...@@ -721,5 +721,68 @@ void CV_MLBaseTest::load( const char* filename )
CV_Error( CV_StsNotImplemented, "invalid stat model name"); CV_Error( CV_StsNotImplemented, "invalid stat model name");
} }
TEST(TrainDataGet, layout_ROW_SAMPLE) // Details: #12236
{
cv::Mat test = cv::Mat::ones(150, 30, CV_32FC1) * 2;
test.col(3) += Scalar::all(3);
cv::Mat labels = cv::Mat::ones(150, 3, CV_32SC1) * 5;
labels.col(1) += 1;
cv::Ptr<cv::ml::TrainData> train_data = cv::ml::TrainData::create(test, cv::ml::ROW_SAMPLE, labels);
train_data->setTrainTestSplitRatio(0.9);
Mat tidx = train_data->getTestSampleIdx();
EXPECT_EQ((size_t)15, tidx.total());
Mat tresp = train_data->getTestResponses();
EXPECT_EQ(15, tresp.rows);
EXPECT_EQ(labels.cols, tresp.cols);
EXPECT_EQ(5, tresp.at<int>(0, 0)) << tresp;
EXPECT_EQ(6, tresp.at<int>(0, 1)) << tresp;
EXPECT_EQ(6, tresp.at<int>(14, 1)) << tresp;
EXPECT_EQ(5, tresp.at<int>(14, 2)) << tresp;
Mat tsamples = train_data->getTestSamples();
EXPECT_EQ(15, tsamples.rows);
EXPECT_EQ(test.cols, tsamples.cols);
EXPECT_EQ(2, tsamples.at<float>(0, 0)) << tsamples;
EXPECT_EQ(5, tsamples.at<float>(0, 3)) << tsamples;
EXPECT_EQ(2, tsamples.at<float>(14, test.cols - 1)) << tsamples;
EXPECT_EQ(5, tsamples.at<float>(14, 3)) << tsamples;
}
TEST(TrainDataGet, layout_COL_SAMPLE) // Details: #12236
{
cv::Mat test = cv::Mat::ones(30, 150, CV_32FC1) * 3;
test.row(3) += Scalar::all(3);
cv::Mat labels = cv::Mat::ones(3, 150, CV_32SC1) * 5;
labels.row(1) += 1;
cv::Ptr<cv::ml::TrainData> train_data = cv::ml::TrainData::create(test, cv::ml::COL_SAMPLE, labels);
train_data->setTrainTestSplitRatio(0.9);
Mat tidx = train_data->getTestSampleIdx();
EXPECT_EQ((size_t)15, tidx.total());
Mat tresp = train_data->getTestResponses(); // always row-based, transposed
EXPECT_EQ(15, tresp.rows);
EXPECT_EQ(labels.rows, tresp.cols);
EXPECT_EQ(5, tresp.at<int>(0, 0)) << tresp;
EXPECT_EQ(6, tresp.at<int>(0, 1)) << tresp;
EXPECT_EQ(6, tresp.at<int>(14, 1)) << tresp;
EXPECT_EQ(5, tresp.at<int>(14, 2)) << tresp;
Mat tsamples = train_data->getTestSamples();
EXPECT_EQ(15, tsamples.cols);
EXPECT_EQ(test.rows, tsamples.rows);
EXPECT_EQ(3, tsamples.at<float>(0, 0)) << tsamples;
EXPECT_EQ(6, tsamples.at<float>(3, 0)) << tsamples;
EXPECT_EQ(6, tsamples.at<float>(3, 14)) << tsamples;
EXPECT_EQ(3, tsamples.at<float>(test.rows - 1, 14)) << tsamples;
}
} // namespace } // namespace
/* End of file. */ /* End of file. */
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment