This section shows how to use a decision tree to predict the type of contact lenses a patient should be prescribed.
1. The general workflow for using a decision tree to predict contact lens types
(1) Collect the data: the provided text file (data from the UCI database)
(2) Prepare the data: parse the tab-delimited data rows
(3) Analyze the data: quickly inspect the data to make sure it was parsed correctly, then use the createPlot() function to draw the final tree diagram
(4) Train the algorithm: the createTree() function (the information-gain criterion it uses is given after this list)
(5) Test the algorithm: write a test function to verify that the decision tree correctly classifies a given data instance
(6) Use the algorithm: store the tree's data structure so it does not have to be reconstructed the next time it is used
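For reference, ID3 chooses the split that maximizes information gain, defined from Shannon entropy (standard definitions):

H(D) = -\sum_{k} p_k \log_2 p_k

\mathrm{Gain}(D, A) = H(D) - \sum_{v \in \mathrm{values}(A)} \frac{|D_v|}{|D|} H(D_v)

where p_k is the fraction of instances in D belonging to class k, and D_v is the subset of D on which feature A takes value v. The functions calcShannonEnt() and chooseBestFeatureToSplit() in trees.py below implement these two formulas.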
trees.py is as follows:
#!/usr/bin/python
# -*- coding: utf-8 -*-
import operator            # needed by majorityCnt(); missing from the original listing
from math import log

# Calculate the Shannon entropy of a dataset
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

# Split the dataset on the given feature value
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

# Choose the best feature to split the dataset on
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)   # original Shannon entropy of the whole dataset
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):            # iterate over every feature in the dataset
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

# Majority vote for the class label of a leaf
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

# Build the tree
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]                 # stop when every class label is identical
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)       # no features left: take the majority class
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]  # every value the best feature takes
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

# Test the algorithm: classify with the decision tree
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

# Use the algorithm: persist the decision tree so it need not be rebuilt
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)
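A quick sanity check of the helpers above, using the small "fish" toy dataset from Machine Learning in Action (the values are illustrative):

>>> import trees
>>> dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> trees.calcShannonEnt(dataSet)             # 2 'yes' vs 3 'no' labels
0.9709505944546686
>>> trees.chooseBestFeatureToSplit(dataSet)   # feature 0 has the larger information gain
0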
treePlotter.py is as follows:
#!/usr/bin/python
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt

# Define the text-box and arrow formats
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

# Draw a node annotation with an arrow from the parent
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType,
                            arrowprops=arrow_args)

# Early demo version of createPlot(); superseded by the full version at the bottom
def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode(u'decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode(u'leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

# Get the number of leaf nodes
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# Get the number of levels in the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

# Hard-coded trees for testing
def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]

# Fill in text between parent and child nodes
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

# Compute the width and height, then plot the tree recursively
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)   # label the midpoint between parent and child
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
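A quick check of the plotting helpers, using one of the hard-coded test trees from retrieveTree():

>>> import treePlotter
>>> myTree = treePlotter.retrieveTree(0)
>>> treePlotter.getNumLeafs(myTree)
3
>>> treePlotter.getTreeDepth(myTree)
2
>>> treePlotter.createPlot(myTree)    # draws the two-level 'no surfacing'/'flippers' tree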
lenses.txt is as follows:
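Assuming the file holds the standard 24-instance UCI lenses data (the same value tokens appear in the tree printed below), each row lists age, prescription, astigmatic, and tear rate, followed by the lens class; the fields are tab-separated in the actual file:

young  myope  no  reduced  no lenses
young  myope  no  normal  soft
young  myope  yes  reduced  no lenses
young  myope  yes  normal  hard
young  hyper  no  reduced  no lenses
young  hyper  no  normal  soft
young  hyper  yes  reduced  no lenses
young  hyper  yes  normal  hard
pre  myope  no  reduced  no lenses
pre  myope  no  normal  soft
pre  myope  yes  reduced  no lenses
pre  myope  yes  normal  hard
pre  hyper  no  reduced  no lenses
pre  hyper  no  normal  soft
pre  hyper  yes  reduced  no lenses
pre  hyper  yes  normal  no lenses
presbyopic  myope  no  reduced  no lenses
presbyopic  myope  no  normal  no lenses
presbyopic  myope  yes  reduced  no lenses
presbyopic  myope  yes  normal  hard
presbyopic  hyper  no  reduced  no lenses
presbyopic  hyper  no  normal  soft
presbyopic  hyper  yes  reduced  no lenses
presbyopic  hyper  yes  normal  no lenses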
A sample run is as follows:
>>> import trees
>>> import treePlotter
>>> fr = open('lenses.txt')
>>> lenses = [inst.strip().split('\t') for inst in fr.readlines()]
>>> lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
>>> lensesTree = trees.createTree(lenses, lensesLabels)
>>> lensesTree
{'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}}}
>>> treePlotter.createPlot(lensesTree)
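Continuing the session, the finished tree can classify a new patient and be saved to disk with the storeTree()/grabTree() helpers (the filename lensesTree.pkl is arbitrary). Note that createTree() deletes entries from the label list it is given, so the list is rebuilt before calling classify():

>>> lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
>>> trees.classify(lensesTree, lensesLabels, ['young', 'myope', 'yes', 'normal'])
'hard'
>>> trees.storeTree(lensesTree, 'lensesTree.pkl')
>>> trees.grabTree('lensesTree.pkl') == lensesTree
True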
The plot shows that the decision tree matches the experimental data very well, but it may contain too many branches. This problem is called overfitting. To reduce overfitting, we can prune the decision tree, removing unnecessary leaf nodes: if a leaf node adds only a little information, it can be deleted and merged into a neighboring leaf node.
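The text gives no pruning code; the sketch below illustrates only the simplest case, collapsing any split whose branches all end in the same class label, since such a split adds no information (the function name collapseIfPure is my own):

def collapseIfPure(tree):
    # Leaves are plain class labels; return them unchanged.
    if not isinstance(tree, dict):
        return tree
    feat = list(tree.keys())[0]
    # Prune bottom-up so that child splits are simplified first.
    branches = {k: collapseIfPure(v) for k, v in tree[feat].items()}
    allLeaves = all(not isinstance(v, dict) for v in branches.values())
    if allLeaves and len(set(branches.values())) == 1:
        return next(iter(branches.values()))  # the split adds no information
    return {feat: branches}

A fuller post-pruning pass would also merge nearly pure splits whenever a validation set shows no loss in accuracy.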
The decision tree ID3 algorithm for predicting contact lens type: a Python implementation