Commit 10b60f8d authored by Vadim Pisarevsky's avatar Vadim Pisarevsky

continuing refactoring ml samples; added "max vote" response to ANN_MLP.…

continuing refactoring ml samples; added "max vote" response to ANN_MLP. Probably, should make it in less hacky way
parent 223cdcd0
...@@ -228,9 +228,8 @@ public: ...@@ -228,9 +228,8 @@ public:
int n = inputs.rows, dn0 = n; int n = inputs.rows, dn0 = n;
CV_Assert( (type == CV_32F || type == CV_64F) && inputs.cols == layer_sizes[0] ); CV_Assert( (type == CV_32F || type == CV_64F) && inputs.cols == layer_sizes[0] );
_outputs.create(n, layer_sizes[l_count-1], type); int noutputs = layer_sizes[l_count-1];
Mat outputs;
Mat outputs = _outputs.getMat();
int min_buf_sz = 2*max_lsize; int min_buf_sz = 2*max_lsize;
int buf_sz = n*min_buf_sz; int buf_sz = n*min_buf_sz;
...@@ -242,9 +241,20 @@ public: ...@@ -242,9 +241,20 @@ public:
buf_sz = dn0*min_buf_sz; buf_sz = dn0*min_buf_sz;
} }
cv::AutoBuffer<double> _buf(buf_sz); cv::AutoBuffer<double> _buf(buf_sz+noutputs);
double* buf = _buf; double* buf = _buf;
if( !_outputs.needed() )
{
CV_Assert( n == 1 );
outputs = Mat(n, noutputs, type, buf + buf_sz);
}
else
{
_outputs.create(n, noutputs, type);
outputs = _outputs.getMat();
}
int dn = 0; int dn = 0;
for( int i = 0; i < n; i += dn ) for( int i = 0; i < n; i += dn )
{ {
...@@ -273,6 +283,13 @@ public: ...@@ -273,6 +283,13 @@ public:
scale_output( layer_in, layer_out ); scale_output( layer_in, layer_out );
} }
if( n == 1 )
{
int maxIdx[] = {0, 0};
minMaxIdx(outputs, 0, 0, 0, maxIdx);
return maxIdx[0] + maxIdx[1];
}
return 0.f; return 0.f;
} }
......
This diff is collapsed.
...@@ -229,22 +229,7 @@ static void find_decision_boundary_ANN( const Mat& layer_sizes ) ...@@ -229,22 +229,7 @@ static void find_decision_boundary_ANN( const Mat& layer_sizes )
Ptr<TrainData> tdata = TrainData::create(samples, ROW_SAMPLE, trainClasses); Ptr<TrainData> tdata = TrainData::create(samples, ROW_SAMPLE, trainClasses);
ann->train(tdata); ann->train(tdata);
predict_and_paint(ann, imgDst);
Mat testSample( 1, 2, CV_32FC1 );
Mat outputs;
for( int y = 0; y < img.rows; y += testStep )
{
for( int x = 0; x < img.cols; x += testStep )
{
testSample.at<float>(0) = (float)x;
testSample.at<float>(1) = (float)y;
ann->predict( testSample, outputs );
Point maxLoc;
minMaxLoc( outputs, 0, 0, 0, &maxLoc );
imgDst.at<Vec3b>(y, x) = classColors[maxLoc.x];
}
}
} }
#endif #endif
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment