Machine Learning: Decision Tree Practice in Python


Decision tree principle: select the best feature from the dataset and use it to split the data, then repeat this division recursively. The algorithm stops when all samples under a branch belong to the same class, or when every feature of the dataset has already been used.

Every time we split the dataset there are many features to choose from. How do we decide which feature to split on? To answer this we need the concepts of information gain and information entropy.

I. Information Gain

The principle of splitting a dataset is to turn unordered data into more ordered data. Information gain is the change in information before and after the split. If we know how to compute information gain, we can compute the gain obtained by splitting on each feature; the feature with the highest information gain is the best choice. First, let's define information: the information of a symbol xi is defined as l(xi) = -log2 p(xi), where p(xi) is the probability of selecting that class. The entropy H of the information source is H = -Σ p(xi) · log2 p(xi). Based on this formula, we can write the following code to compute Shannon entropy.

from math import log

def calcShannonEnt(dataSet):
    NumEntries = len(dataSet)
    labelsCount = {}
    # count how many samples fall into each class (the class label is the last column)
    for featVec in dataSet:
        currentlabel = featVec[-1]
        if currentlabel not in labelsCount:
            labelsCount[currentlabel] = 0
        labelsCount[currentlabel] += 1
    ShannonEnt = 0.0
    for key in labelsCount:
        prob = float(labelsCount[key]) / NumEntries
        ShannonEnt -= prob * log(prob, 2)   # H = -sum(p * log2(p))
    return ShannonEnt

The custom function above relies on the log function, imported with from math import log. We can test it with a simple example.

def createdataSet():
    # dataSet = [['1','1','yes'],['1','0','no'],['0','1','no'],['0','0','no']]
    dataSet = [[1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 0, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
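As a quick check (a minimal sketch; the variable names myDat and labels are only for illustration), we can compute the entropy of this toy dataset:

myDat, labels = createdataSet()
print(calcShannonEnt(myDat))   # about 0.811 for one 'yes' and three 'no' labels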

The entropy here is 0.811. If we make the data more mixed, the entropy rises: for example, changing one label so that the dataset contains the three classes 'yes', 'no', and 'maybe' increases the entropy. In other words, the more disordered the data, the higher the entropy.

The classification algorithm needs to both compute information entropy and split the dataset. In the decision tree algorithm, we compute the entropy of the dataset after splitting on each feature, and then decide which feature gives the best split.

def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedfeatVec = featVec[:axis]
            reducedfeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedfeatVec)
    return retDataSet

Here axis is the index of the feature to split on, and value is the feature value to keep; the matched feature column itself is removed from the returned vectors. Note the difference between the extend and append methods, illustrated below.
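A minimal illustration of the difference (the list values are chosen arbitrarily):

a = [1, 2, 3]
a.append([4, 5])   # a is now [1, 2, 3, [4, 5]]: the whole list is added as one element
b = [1, 2, 3]
b.extend([4, 5])   # b is now [1, 2, 3, 4, 5]: the elements are added one by one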

The following is a test result of the dataset partitioning function:

With axis = 0 and value = 1, the myDat dataset is split according to whether the 0th feature of each vector equals 1.
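A sketch of that test, again using the myDat variable from above:

myDat, labels = createdataSet()
print(splitDataSet(myDat, 0, 1))   # [[1, 'yes'], [0, 'no']]
print(splitDataSet(myDat, 0, 0))   # [[1, 'no'], [0, 'no']]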

Next, we traverse all the features of the dataset, compute the Shannon entropy of each resulting split, and find the best feature to split on.

def choosebestfeatureToSplit(dataSet):
    Numfeatures = len(dataSet[0]) - 1          # the last column is the class label
    BaseShannonEnt = calcShannonEnt(dataSet)   # entropy before any split
    bestInfoGain = 0.0
    bestfeature = -1
    for i in range(Numfeatures):
        featlist = [example[i] for example in dataSet]
        featSet = set(featlist)                # unique values taken by feature i
        newEntropy = 0.0
        for value in featSet:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = float(len(subDataSet)) / len(dataSet)
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = BaseShannonEnt - newEntropy  # information gain of splitting on feature i
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestfeature = i
    return bestfeature

Information gain is the reduction in entropy, i.e. the reduction in data disorder. Finally, we compare the information gain of all features and return the index of the best feature to split on. The function can be tested as follows:
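A quick check on the toy dataset (on this tiny example both features happen to give the same gain, about 0.311, so the first one, index 0, i.e. 'no surfacing', is returned):

myDat, labels = createdataSet()
print(choosebestfeatureToSplit(myDat))   # 0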

Next, we build the decision tree recursively. Before each split we check whether only the label column is left, i.e. whether all attributes have already been used; if so, and the class labels are still not unique, we take a majority vote. The majority-vote function below uses the same counting approach as classify0 in Chapter 2.

import operator

def majorityCnt(classlist):
    # return the class label that occurs most often in classlist
    ClassCount = {}
    for vote in classlist:
        if vote not in ClassCount:
            ClassCount[vote] = 0
        ClassCount[vote] += 1
    sortedClassCount = sorted(ClassCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTrees(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]                    # all samples share one class: stop
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)          # no features left: majority vote
    bestfeature = choosebestfeatureToSplit(dataSet)
    bestfeatureLabel = labels[bestfeature]
    myTree = {bestfeatureLabel: {}}
    del(labels[bestfeature])
    featValue = [example[bestfeature] for example in dataSet]
    uniqueValue = set(featValue)
    for value in uniqueValue:
        subLabels = labels[:]                  # copy so the recursion does not modify this level's labels
        myTree[bestfeatureLabel][value] = createTrees(splitDataSet(dataSet, bestfeature, value), subLabels)
    return myTree

The final result of the decision tree is as follows:
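A sketch of the expected output on the toy dataset (note that createTrees deletes entries from the labels list it receives, so a copy is passed here):

myDat, labels = createdataSet()
myTree = createTrees(myDat, labels[:])
print(myTree)
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}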

The nested-dictionary output above is not very intuitive, so we will use Matplotlib annotations to plot the tree structure. Matplotlib provides an annotation tool, annotate, that can add text annotations to a plot. Let's first test how this annotation tool works.

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='sawtooth', fc='0.8')
arrow_args = dict(arrowstyle='<-')

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType,
                            arrowprops=arrow_args)

def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('test1', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('test2', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
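To try it (assuming an interactive Matplotlib backend is available), simply call:

createPlot()   # draws the 'test1' and 'test2' nodes, each with an arrow from its parent point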

After testing this small example, we can start plotting the annotated tree. Although we have x,y coordinates, placing the tree nodes is not straightforward: we need to know the number of leaf nodes (to lay out the x axis) and the depth of the tree (to lay out the y axis). The following two functions obtain the number of leaf nodes and the depth of the tree. They share the same structure: starting from the first key, they traverse all child nodes and check whether each child is of dictionary type. If it is, the child is a decision (judgment) node, and getNumleafs() calls itself recursively to traverse that subtree and accumulate its leaf count. getTreeDepth() counts the decision nodes encountered along each path; its recursion terminates at a leaf node, and on the way back from each recursive call it adds one to the depth for that level.

def getNumleafs(myTree):
    numLeafs = 0
    key_sorted = sorted(myTree.keys())
    firstStr = key_sorted[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumleafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxdepth = 0
    key_sorted = sorted(myTree.keys())
    firstStr = key_sorted[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thedepth = 1 + getTreeDepth(secondDict[key])
        else:
            thedepth = 1
        if thedepth > maxdepth:
            maxdepth = thedepth
    return maxdepth

The test results are as follows:
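A sketch of that test, reusing the myTree dictionary built above:

print(getNumleafs(myTree))    # 3
print(getTreeDepth(myTree))   # 2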

The final decision tree diagram, produced by the plotting functions below, confirms these results.

We can see that the tree is indeed two levels deep and has three leaf nodes. Next, we give the key functions for drawing the decision tree graph; the result is the decision tree shown in the figure.

def plotMidText(cntrPt, parentPt, txtString):
    # place the edge label midway between a child node and its parent
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumleafs(myTree)
    depth = getTreeDepth(myTree)
    key_sorted = sorted(myTree.keys())
    firstStr = key_sorted[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff -= 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff += 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff += 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumleafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
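Finally, plotting the tree built earlier (again assuming an interactive Matplotlib backend):

createPlot(myTree)   # draws the two-level tree with 'no surfacing' at the root and three leaves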

That is all for this article. I hope it is helpful for your study of decision trees.
