|
degate 0.1.1
|
/*

This code is from http://terpconnect.umd.edu/~xliu10/research/adaboost.html

Todo: Check licence type

*/

#ifndef ADABOOST_HPP
#define ADABOOST_HPP

#include <cstddef>
#include <vector>
#include <string>
#include <limits>
#include <cmath>
#include <math.h>

// Pulling namespace std into every file that includes this header is bad
// practice, but it is kept because existing includers may rely on it. The code
// below qualifies std names explicitly so it does not depend on it.
using namespace std;

/**
 * Interface of a binary classifier over samples of type T, plus the AdaBoost
 * trainer that combines several such (weak) classifiers.
 *
 * Labels and recognition results are expected to be +1 or -1: the weight
 * update in adaboost() multiplies them together.
 */
template <class T>
class Classifier
{
public:
  // Polymorphic base: deleting a derived classifier through a Classifier<T>*
  // must run the derived destructor.
  virtual ~Classifier() {}

  // This function performs the actual recognition.
  // MUST be implemented by the weak classifier; usually T is the feature vector.
  virtual int recognize(T&) = 0;

  // MUST be implemented by the weak classifier; simply return the name of the
  // weak classifier itself. It is recommended to use this function to keep
  // track of the weak classifiers. You will find this useful if more than 30
  // weak classifiers are trained.
  virtual std::string get_name() const = 0;

  /**
   * The AdaBoost algorithm: trains a strong classifier from weak classifiers.
   *
   * Every weak classifier is first run once against every training sample and
   * the results are cached, so the boosting rounds themselves are fast.
   *
   * @param clsfrs   weak classifiers (non-null pointers, not owned)
   * @param data     training samples (non-null pointers, not owned)
   * @param label    expected label (+1 / -1) for each sample in data
   * @param maxround maximum number of boosting rounds
   * @return one weight per weak classifier, in the order of clsfrs (suitable
   *         for MultiClassifier); empty if clsfrs or label is empty or if data
   *         and label differ in size.
   */
  static std::vector<float> adaboost(const std::vector<Classifier<T>*>& clsfrs,
                                     const std::vector<T*>& data,
                                     const std::vector<int>& label,
                                     const int maxround = 80)
  {
    std::vector<float> alpha;

    if (data.size() != label.size() || clsfrs.empty() || label.empty())
      return alpha;

    const std::size_t n_samples = label.size();
    const std::size_t n_clsfrs = clsfrs.size();

    alpha.resize(n_clsfrs, 0.0f);

    // Sample weights, starting uniform; they are renormalised to sum to 1
    // after every round.
    std::vector<float> d(n_samples, float(1.0) / float(n_samples));

    // rec[j][i] = result of weak classifier j on training sample i.
    std::vector< std::vector<int> > rec(n_clsfrs, std::vector<int>(n_samples));
    for (std::size_t j = 0; j < n_clsfrs; j++)
      for (std::size_t i = 0; i < n_samples; i++)
        rec[j][i] = clsfrs[j]->recognize(*data[i]);

    for (int round = 0; round < maxround; round++)
    {
      // Pick the weak classifier with the lowest weighted error; ties go to
      // the lowest index. The start value is above any possible weighted error.
      float minerr = float(n_samples);
      std::size_t best = 0;
      for (std::size_t j = 0; j < n_clsfrs; j++)
      {
        float err = 0;
        for (std::size_t i = 0; i < n_samples; i++)
          if (rec[j][i] != label[i])
            err += d[i];
        if (err < minerr)
        {
          minerr = err;
          best = j;
        }
      }

      // No classifier is better than chance: boosting cannot make progress.
      if (minerr >= 0.5f)
        break;

      if (minerr < std::numeric_limits<float>::min())
      {
        // Zero (or denormal) error: the usual weight 0.5*log((1-e)/e) would be
        // infinite and the re-weighting below would produce NaN sample
        // weights. Give this classifier a finite weight larger than all other
        // weights combined, so it still decides the vote, and stop: with zero
        // error it would be chosen again in every remaining round.
        float others = 0;
        for (std::size_t j = 0; j < n_clsfrs; j++)
          if (j != best)
            others += alpha[j];
        alpha[best] += others + 1.0f;
        break;
      }

      const float a = std::log((1.0f - minerr) / minerr) / 2;
      alpha[best] += a;

      // Re-weight: shrink the weight of samples the chosen classifier got
      // right, grow it for the ones it got wrong, then renormalise.
      float z = 0;
      for (std::size_t i = 0; i < n_samples; i++)
      {
        d[i] *= std::exp(-a * label[i] * rec[best][i]);
        z += d[i];
      }
      for (std::size_t i = 0; i < n_samples; i++)
        d[i] /= z;
    }
    return alpha;
  }
};

/**
 * The strong classifier: a linear combination (weighted vote) of weak
 * classifiers, typically with weights produced by Classifier<T>::adaboost().
 */
template <class T>
class MultiClassifier : public Classifier<T>
{
private:
  std::vector<float> weights;
  std::vector<Classifier<T>*> clsfrs;  // not owned
public:
  // Weighted vote computed by the most recent call to recognize().
  float score;

  // w[i] is the weight of weak classifier c[i]. The pointers are not owned and
  // must stay valid for as long as recognize() is called.
  MultiClassifier(std::vector<float> w, std::vector<Classifier<T>*> c)
    : weights(w), clsfrs(c), score(0)
  {
  }

  std::string get_name() const { return "MultiClassifier"; }

  // Sums weight * result over the weak classifiers with a positive weight,
  // stores the sum in score and returns +1 if it is >= 0, otherwise -1.
  // Only the first min(weights.size(), clsfrs.size()) pairs are used.
  int recognize(T& obj)
  {
    const std::size_t n =
      weights.size() < clsfrs.size() ? weights.size() : clsfrs.size();
    float res = 0;
    for (std::size_t i = 0; i < n; i++)
      if (weights[i] > 0)
        res += weights[i] * clsfrs[i]->recognize(obj);
    score = res;
    return res >= 0 ? 1 : -1;
  }
};

/**
 * Utility that tests a (strong) classifier over labelled test data.
 *
 * @param cls   classifier under test
 * @param data  test samples (non-null pointers)
 * @param label expected label for each sample: +1 or -1 (other values are
 *              skipped)
 * @param fpos  out: fraction of -1 samples not recognized as -1
 * @param fneg  out: fraction of +1 samples not recognized as +1
 *
 * A rate is 0 when the test data contains no sample of that class. Only the
 * first min(data.size(), label.size()) samples are evaluated.
 */
template <class T>
void testClassifier(Classifier<T>* cls, const std::vector<T*>& data,
                    const std::vector<int>& label, float & fpos, float & fneg)
{
  const std::size_t n = data.size() < label.size() ? data.size() : label.size();
  int pos = 0, neg = 0;
  fpos = fneg = 0;
  for (std::size_t i = 0; i < n; i++)
  {
    const int rec = cls->recognize(*data[i]);
    if (label[i] == 1)
    {
      pos++;
      if (rec != 1)
        fneg += 1;
    }
    else if (label[i] == -1)
    {
      neg++;
      if (rec != -1)
        fpos += 1;
    }
  }
  // Avoid 0/0 (NaN) when one of the classes is absent from the test data.
  if (neg > 0)
    fpos /= neg;
  if (pos > 0)
    fneg /= pos;
}

#endif
1.7.4