Commit a98d6b62 authored by Alexander Mordvintsev's avatar Alexander Mordvintsev

exposed parallelized SVM prediction to python (predict_all)

parent e4d9d529
......@@ -488,7 +488,7 @@ public:
bool balanced=false );
virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
virtual float predict( const CvMat* samples, CvMat* results ) const;
virtual float predict( const CvMat* samples, CV_OUT CvMat* results ) const;
#ifndef SWIG
CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
......@@ -510,6 +510,7 @@ public:
CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
bool balanced=false);
CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
CV_WRAP_AS(predict_all) virtual void predict( cv::InputArray samples, cv::OutputArray results ) const;
#endif
CV_WRAP virtual int get_support_vector_count() const;
......
......@@ -2124,6 +2124,12 @@ float CvSVM::predict(const CvMat* samples, CV_OUT CvMat* results) const
return result;
}
void CvSVM::predict( cv::InputArray _samples, cv::OutputArray _results ) const
{
_results.create(_samples.size().height, 1, CV_32F);
CvMat samples = _samples.getMat(), results = _results.getMat();
predict(&samples, &results);
}
CvSVM::CvSVM( const Mat& _train_data, const Mat& _responses,
const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params )
......
......@@ -88,7 +88,7 @@ class SVM(LetterStatModel):
self.model.train(samples, responses, params = params)
def predict(self, samples):
return np.float32( [self.model.predict(s) for s in samples] )
return self.model.predict_all(samples).ravel()
class MLP(LetterStatModel):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment