Random decision forests — an example of CvRTrees in OpenCV

Source: Internet
Author: User

This article introduces the use of random forests (the CvRTrees "random trees" classifier) in OpenCV's Machine Learning Library (MLL).

References:

1. Breiman, Leo (2001). "Random Forests". Machine Learning 45 (1): 5–32.
 

2. Random Forests website

If you are not familiar with the MLL, refer to this article: OpenCV Machine Learning Library (MLL).

Opencv machine learning algorithms are relatively simple: train --> predict

// CvRTrees — the random forest classifier from the OpenCV 2.x C/C++ ML API.
// Declaration reproduced here (from opencv2/ml/ml.hpp) for reference; usage
// follows the usual OpenCV ML pattern: construct -> train() -> predict().
class CV_EXPORTS_W CvRTrees : public CvStatModel
{
public:
    CV_WRAP CvRTrees();
    virtual ~CvRTrees();

    // Training / prediction on the legacy CvMat containers.
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvRTParams params=CvRTParams() );
    virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
    virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
    virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;

    // cv::Mat overloads; CV_WRAP exposes them to the script-language wrappers.
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvRTParams params=CvRTParams() );
    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual cv::Mat getVarImportance();
    CV_WRAP virtual void clear();

    // Introspection of the trained forest.
    virtual const CvMat* get_var_importance();
    virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
        const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;

    virtual float calc_error( CvMLData* data, int type , std::vector<float>* resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
    virtual float get_train_error();

    // (De)serialization via the legacy CvFileStorage mechanism.
    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name ) const;

    CvMat* get_active_var_mask();
    CvRNG* get_rng();

    int get_tree_count() const;
    CvForestTree* get_tree(int i) const;

protected:
    virtual std::string getName() const;

    virtual bool grow_forest( const CvTermCriteria term_crit );

    // array of the trees of the forest
    CvForestTree** trees;
    CvDTreeTrainData* data;
    int ntrees;
    int nclasses;
    double oob_error;
    CvMat* var_importance;
    int nsamples;

    cv::RNG* rng;
    CvMat* active_var_mask;
};

Using the CvRTrees class to classify handwritten digit data

// Example: Random forest (tree) Learning // Author: Toby breckon, toby.breckon@cranfield.ac.uk // copyright (c) 2011 School of Engineering, Cranfield University // license: lgpl-http://www.gnu.org/licenses/lgpl.html#include <cv. h> // opencv General include file # include <ml. h> // opencv machine learning include file # include <stdio. h> using namespace CV; // opencv API is in the C ++ "CV" namespace /* **************************************** *************************************/// Global definitions (for speed and usage of use) // handwritten number recognition # define number_of_training_samples 3823 # define attributes_per_sample 64 # define number_of_testing_samples 1797 # define number_of_classes 10 // n. b. classes are integer Handwritten digits in range 0-9 /***************************** ************************************** * ********* // Loads the sample database from file (which is a CSV text file) int read_data_from_csv (const char * filename, mat data, mat classes, int n_samples) {float TMP; // if we can't read the input file then return 0 file * f = fopen (filename, "R"); If (! F) {printf ("error: cannot read file % s \ n", filename); Return 0; // all not OK} // for each sample in the file for (INT line = 0; line <n_samples; line ++) {// For each attribute on the line in the file for (INT attribute = 0; Attribute <(attributes_per_sample + 1); Attribute ++) {If (attribute <64) {// first 64 elements (0-63) in each line are the attributes fscanf (F, "% F,", & TMP); data. at <flo At> (line, attribute) = TMP; // printf ("% F,", Data. at <float> (line, attribute);} else if (attribute = 64) {// attribute 65 is the class label {0... 9} fscanf (F, "% F,", & TMP); classes. at <float> (line, 0) = TMP; // printf ("% F \ n", classes. 
at <float> (line, 0) ;}} fclose (f); return 1; // All OK }/********************************** **************************************** * ***/INT main (INT argc, char ** Argv) {for (INT I = 0; I <argc; I ++) STD: cout <argv [I] <STD: Endl; // lets just check the version firstprintf ("opencv version % s (% d. % d. % d) \ n ", cv_version, cv_major_version, cv_minor_version, cv_subminor_version); // defines the training data and label matrix mat training_data = MAT (partition, attributes_per_sample, cv_32fc1 ); mat training_classifications = MAT (number_of_training_samples, 1, cv_32fc1); // define test data Matrix and label mat testing_data = MAT (number_of_testing_samples, attributes_per_sample, cv_32fc1); MAT testing_classifications = MAT (number_of_testing_samples, 1, cv_32fc1 ); // define all the attributes as numerical // alternatives are cv_var_categorical or cv_var_ordered (= optional) // that can be assigned on a per attribute basis mat var_type = MAT (attributes_per_sample + 1, 1, cv_8u); var_type.s Etto (scalar (cv_var_numerical); // all inputs are numerical // This is a classification problem (I. e. 
predict a discrete number of class // outputs) So reset the last (+ 1) Output var_type element to cv_var_categorical var_type.at <uchar> (attributes_per_sample, 0) = cv_var_categorical; double result; // value returned from a prediction // load the training dataset and test dataset if (read_data_from_csv (argv [1], training_data, Training_classifications, number_of_training_samples) & read_data_from_csv (argv [2], testing_data, testing_classifications, number_of_testing_samples )) {/********************************* Step 1: define the parameter *******************************/float priors [] = {1, 1, 1, 1, 1, 1, 1}; // weights of each classification for Classes cvrtparams Params = cvrtparams (25, // max depth 5, // min sample count 0, // regression accuracy: N/A here false, // compute surrogate split, no missing data 15, // Max number of categories (Use sub-optimal algorithm for larger numbers) priors, // the array of priors false, // calculate variable importance 4, // number of variables randomly selected at node and used to find the best split (s ). 
100, // Max number of trees in the forest 0.01f, // Forrest accuracy cv_term Crit_iter | cv_termcrit_eps // termination cirteria ); /*************************** Step 2: Train random demo-forest (RDF) classifier *******************/printf ("\ nusing training database: % s \ n ", argv [1]); cvrtrees * rtree = new cvrtrees; rtree-> train (training_data, cv_row_sample, training_classifications, MAT (), MAT (), var_type, MAT (), params); // perform classifier testing and report results mat test_sam Ple; int correct_class = 0; int wrong_class = 0; int false_positives [number_of_classes] = {0, 0, 0, 0, 0, 0, 0, 0}; printf ("\ nusing testing database: % s \ n ", argv [2]); For (INT tsample = 0; tsample <number_of_testing_samples; tsample ++) {// extract a row from the testing matrix test_sample = testing_data.row (tsample ); /********************************* Step 3: prediction *************************************** * *****/Result = rtree-> predict (test_sample, MAT (); printf ("Testing Sample % I-> class result (digit % d) \ n ", tsample, (INT) result); // If the prediction and the (true) testing classification are the same // (N. b. opencv uses a floating point demo-tree implementation !) 
If (FABS (result-testing_classifications.at <float> (tsample, 0)> = flt_epsilon) {// if they differ more than floating point error => wrong class wrong_class ++; false_positives [(INT) result] ++;} else {// otherwise correct correct_class ++ ;}} printf ("\ nresults on the testing database: % s \ n "" \ tcorrect classification: % d (% G %) \ n "" \ twrong classifications: % d (% G %) \ n ", argv [2], correct_class, (double) correct_class * 100/number_of_testing_samples, wrong_class, (double) wrong_class * 100/second); For (INT I = 0; I <number_of_classes; I ++) {printf ("\ tclass (digit % d) False postives % d (% G %) \ n", I, false_positives [I], (double) false_positives [I] * 100/number_of_testing_samples);}/All matrix memory free by Destructors/All OK: Main returns 0 return 0;} // not OK: main returns-1 return-1 ;} /*************************************** ***************************************/

========================================================== ============================================

Handwritten data:

Split the dataset into training and test sets:

Accuracy on the test Dataset:

Contact Us

The content of this page is sourced from the Internet and doesn't represent Alibaba Cloud's opinion; products and services mentioned on this page have no relationship with Alibaba Cloud. If the content of this page is confusing, please write us an email and we will handle the problem within 5 days of receiving it.

If you find any instances of plagiarism from the community, please send an email to: info-contact@alibabacloud.com and provide relevant evidence. A staff member will contact you within 5 working days.

A Free Trial That Lets You Build Big!

Start building with 50+ products and up to 12 months usage for Elastic Compute Service

  • Sales Support

    1 on 1 presale consultation

  • After-Sales Support

    24/7 Technical Support 6 Free Tickets per Quarter Faster Response

  • Alibaba Cloud offers highly flexible support services tailored to meet your exact needs.