Commit a0af8707 authored by Alexander Mordvintsev's avatar Alexander Mordvintsev

added CV_OUT to CvANN_MLP::predict

python cv2 MLP sample done
parent 622bd422
...@@ -1926,7 +1926,7 @@ public: ...@@ -1926,7 +1926,7 @@ public:
CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(), CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
int flags=0 ); int flags=0 );
CV_WRAP virtual float predict( const cv::Mat& inputs, cv::Mat& outputs ) const; CV_WRAP virtual float predict( const cv::Mat& inputs, CV_OUT cv::Mat& outputs ) const;
#endif #endif
CV_WRAP virtual void clear(); CV_WRAP virtual void clear();
......
...@@ -90,6 +90,7 @@ class SVM(LetterStatModel): ...@@ -90,6 +90,7 @@ class SVM(LetterStatModel):
def predict(self, samples): def predict(self, samples):
return np.float32( [self.model.predict(s) for s in samples] ) return np.float32( [self.model.predict(s) for s in samples] )
class MLP(LetterStatModel): class MLP(LetterStatModel):
def __init__(self): def __init__(self):
self.model = cv2.ANN_MLP() self.model = cv2.ANN_MLP()
...@@ -109,10 +110,8 @@ class MLP(LetterStatModel): ...@@ -109,10 +110,8 @@ class MLP(LetterStatModel):
self.model.train(samples, np.float32(new_responses), None, params = params) self.model.train(samples, np.float32(new_responses), None, params = params)
def predict(self, samples): def predict(self, samples):
pass ret, resp = self.model.predict(samples)
#return np.float32( [self.model.predict(s) for s in samples] ) return resp.argmax(-1)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment