
Learning notes for "Machine Learning Practice": Draw a tree chart & use a decision tree to predict the contact lens type,

The previous section implemented the decision tree, but only as a nested dictionary holding the tree-structure information, which is hard to read. Clearly an intuitive tree diagram is needed. Python has no built-in tree-drawing tool, so the plotting functions have to be written by hand on top of the Matplotlib library. This part of the code is fairly involved: it requires two-dimensional coordinate calculations, and although the book's code works, it uses many functions and variables and relies heavily on recursion, so it took me more than a day to work through it.

I. Drawing the tree chart

Matplotlib's annotation tool, annotate(), is used to draw the details of the decision tree: it generates the text boxes at the nodes, adds text annotations, and fills the boxes with color. Before drawing a whole tree, it is best to first master drawing a single tree node. A simple example follows:

# -*- coding: utf-8 -*-"""Created on Fri Sep 04 01:15:01 2015@author: Herbert"""import matplotlib.pyplot as pltnonLeafNodes = dict(boxstyle = "sawtooth", fc = "0.8")leafNodes = dict(boxstyle = "round4", fc = "0.8")line = dict(arrowstyle = "<-")def plotNode(nodeName, targetPt, parentPt, nodeType):    createPlot.ax1.annotate(nodeName, xy = parentPt, xycoords = \                            'axes fraction', xytext = targetPt, \                            textcoords = 'axes fraction', va = \                            "center", ha = "center", bbox = nodeType, \                            arrowprops = line)def createPlot():    fig = plt.figure(1, facecolor = 'white')    fig.clf()    createPlot.ax1 = plt.subplot(111, frameon = False)    plotNode('nonLeafNode', (0.2, 0.1), (0.4, 0.8), nonLeafNodes)    plotNode('LeafNode', (0.8, 0.1), (0.6, 0.8), leafNodes)    plt.show()createPlot()

Output result:

In this example, the plotNode() function draws the arrows and nodes: each call draws one node together with the arrow pointing to it (this function is explained in more detail later). The createPlot() function creates the figure window for the output image, performs some simple setup, and then calls plotNode() to generate a pair of nodes and the arrows pointing to them.
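To convince yourself that each plotNode() call produces exactly one text box plus one arrow, a small variation of createPlot() can draw a third node. This is only a sketch that assumes the definitions above are already loaded; the node names and coordinates are arbitrary:

def createPlotThreeNodes():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)    # plotNode() draws onto createPlot.ax1
    plotNode('decisionNode', (0.2, 0.1), (0.5, 0.9), nonLeafNodes)   # one box + one arrow
    plotNode('leafNodeA', (0.5, 0.1), (0.5, 0.9), leafNodes)         # a second box + arrow
    plotNode('leafNodeB', (0.8, 0.1), (0.5, 0.9), leafNodes)         # a third box + arrow
    plt.show()

createPlotThreeNodes()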

Draw the entire tree

There are many functions and variables in this part. To make later extension easier, the necessary comments are provided:

#-*-Coding: UTF-8-*-"Created on Fri Sep 04 01:15:01 2015 @ author: Herbert" import matplotlib. pyplot as plt # Some Code defines Drawing Graphics. It mainly defines the format of text box and scissors nonLeafNodes = dict (boxstyle = "sawtooth", fc = "0.8 ") leafNodes = dict (boxstyle = "round4", fc = "0.8") line = dict (arrowstyle = "<-") # use recursion to calculate the number of leaf nodes in the tree def getLeafNum (tree): num = 0 firstKey = tree. keys () [0] secondDict = tree [firstKey] for key in secondDict. keys (): if type (secondDict [key]). _ name _ = 'dict ': num + = getLeafNum (secondDict [key]) else: num + = 1 return num # same leaf node computing function, use recursive Calculation of decision tree depth def getTreeDepth (tree): maxDepth = 0 firstKey = tree. keys () [0] secondDict = tree [firstKey] for key in secondDict. keys (): if type (secondDict [key]). _ name _ = 'dict ': depth = getTreeDepth (secondDict [key]) + 1 else: depth = 1 if depth> maxDepth: maxDepth = depth return maxDepth # functions implemented in the previous example are used to draw nodes and arrows def plotNode (nodeName, targetPt, parentPt, nodeType) in the form of Annotations: createPlot. ax1.annotate (nodeName, xy = parentPt, xycoords = \ 'axes fraction', xytext = targetPt, \ textcoords = 'axes fraction', va = \ "center ", ha = "center", bbox = nodeType, \ arrowprops = line) # used to draw the Annotation on the cutter line, involving coordinate calculation, in fact, the def insertText (targetPt, parentPt, info) is added to the center of the two coordinate points: xCoord = (parentPt [0]-targetPt [0]) /2.0 + targetPt [0] yCoord = (parentPt [1]-targetPt [1])/2.0 + targetPt [1] createPlot. ax1.text (xCoord, yCoord, info) # implements the logic and coordinate operations of the entire tree. recursion is used. Important functions # two global variables plotTree are used. xOff and plotTree. yOff # used to track the locations of drawn nodes and place the appropriate locations of the next node def plotTree (tree, parentPt, info ): # Call two functions to calculate the number of leaf nodes and the depth of the tree leafNum = getLeafNum (tree) treeDepth = getTreeDepth (tree) firstKey = tree. keys () [0] # the text label for this node firstPt = (plotTree. xOff + (1.0 + float (leafNum)/2.0/plotTree. totalW, \ plotTree. yOff) insertText (firstPt, parentPt, info) plotNode (firstKey, firstPt, parentPt, nonLeafNodes) secondDict = tree [firstKey] plotTree. yOff = plotTree. yOff-1.0/plotTree. totalD for key in secondDict. keys (): if type (secondDict [key]). _ name _ = 'dict ': plotTree (secondDict [key], firstPt, str (key) else: plotTree. xOff = plotTree. xOff + 1.0/plotTree. totalW plotNode (secondDict [key], (plotTree. xOff, plotTree. yOff), \ firstPt, leafNodes) insertText (plotTree. xOff, plotTree. yOff), firstPt, str (key) plotTree. yOff = plotTree. yOff + 1.0/plotTree. totalD # The following functions perform real drawing operations. The plotTree () function is only some logic and coordinate operations of the tree def createPlot (inTree): fig = plt. figure (1, facecolor = 'white') fig. clf () createPlot. ax1 = plt. subplot (111, frameon = False) #, ** axprops) # global variable plotTree. totalW and plotTree. totalD # used to store the width of the tree and the depth of the tree plotTree. totalW = float (getLeafNum (inTree) plotTree. totalD = float (getTreeDepth (inTree) plotTree. xOff =-0.5/plotTree. totalW plotTree. yOff = 1.0 plotTree (inTree, (0.5, 1.0), '') plt. 
show () # a small test set def retrieveTree (I): listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers ': {0: 'no', 1: 'yes' }}},\{ 'no surfacing': {0: 'no', 1: {'flippers ': {0: {'head': {0: 'no', \ 1: 'yes'}, 1: 'No' }}] return listOfTrees [I] createPlot (retrieveTree (1) # Call a tree in the test set to draw
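As a quick sanity check on getLeafNum() and getTreeDepth(), these are the values they should return for the two hand-built test trees; the expected numbers are worked out by hand from the dictionaries in retrieveTree(), and the print statements are just an assumed way to run the check:

print getLeafNum(retrieveTree(0)), getTreeDepth(retrieveTree(0))   # expected: 3 2
print getLeafNum(retrieveTree(1)), getTreeDepth(retrieveTree(1))   # expected: 4 3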

The retrieveTree() function stores two independent hand-built trees; passing an index returns the corresponding tree. Finally, running createPlot(tree) produces the following drawing:

The book's recursive computation of the tree's leaf count and depth is very simple. When writing the functions that draw the tree chart, the difficulty lies in how certain plot coordinates are chosen in the book and how the computed node coordinates are handled, and the book's explanation of this part is scattered. The blog post http://www.cnblogs.com/fantasy01/p/4595902.html gives a very detailed explanation, including how the coordinates are derived and how the formulas are analyzed; only part of it is extracted below to aid understanding:

Specifically, the coordinates used when drawing are chosen as follows:

The author picks a very clever scheme here: adding or removing nodes, or changing the depth of the tree, does not break the drawing (as long as the tree does not become too dense). The x axis is split evenly into as many cells as the tree has leaf nodes, and the y axis is split evenly into as many layers as the tree is deep. plotTree.xOff records the x coordinate of the most recently drawn leaf node and is advanced each time another leaf is drawn; plotTree.yOff records the depth currently being drawn and is decreased by one layer (one of the evenly split parts just mentioned) on each level of recursion. Non-leaf nodes are then placed by calculation from these two coordinate values, so the two variables together determine every node's position.
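As a concrete illustration of this splitting scheme, take the first test tree from retrieveTree(), which has 3 leaves and depth 2. The little script below is only a hand calculation, not part of the book's code:

# Layout numbers for retrieveTree(0): 3 leaves, depth 2
totalW, totalD = 3.0, 2.0                  # width = number of leaves, depth = tree depth
xOff = -0.5 / totalW                       # start half a cell to the left of the first leaf centre
leafX = [xOff + (k + 1) * 1.0 / totalW for k in range(3)]
print leafX                                # roughly [0.167, 0.5, 0.833], i.e. the cell centres 1/6, 3/6, 5/6
print [1.0 - level / totalD for level in range(3)]   # y layers: 1.0 (root), 0.5, 0.0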

The overall steps of the plotTree function are as follows: first compute the number of leaves and the depth of the subtree being drawn; then compute the position of the current non-leaf node from plotTree.xOff and the leaf count, and draw the branch label and the node itself; decrease plotTree.yOff by one layer and process each child, recursing for child subtrees and drawing child leaves directly while advancing plotTree.xOff; finally restore plotTree.yOff when the subtree is finished.

Below, the code of the two functions plotTree and createPlot is therefore taken out separately:

# Implements the layout logic and coordinate operations of the entire tree (recursive);
# this is the important function. Two global variables, plotTree.xOff and plotTree.yOff,
# track the position of the last node drawn and place the next node appropriately.
def plotTree(tree, parentPt, info):
    # Call the two helper functions to get the number of leaves and the depth of the tree
    leafNum = getLeafNum(tree)
    treeDepth = getTreeDepth(tree)
    firstKey = tree.keys()[0]    # the text label for this node
    firstPt = (plotTree.xOff + (1.0 + float(leafNum)) / 2.0 / plotTree.totalW,
               plotTree.yOff)
    insertText(firstPt, parentPt, info)
    plotNode(firstKey, firstPt, parentPt, nonLeafNodes)
    secondDict = tree[firstKey]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], firstPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     firstPt, leafNodes)
            insertText((plotTree.xOff, plotTree.yOff), firstPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

# The function below performs the real drawing work; plotTree() above is only the
# layout logic and coordinate operations of the tree
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    # The global variables plotTree.totalW and plotTree.totalD store
    # the width (leaf count) and the depth of the tree
    plotTree.totalW = float(getLeafNum(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

First, the code splits the whole drawing area evenly according to the number of leaf nodes and the depth of the tree, with the total length of both the x axis and the y axis taken as 1. For example:

The description is as follows:

1. The squares in the figure are the non-leaf nodes and @ marks the leaf-node positions, so each cell has width 1/plotTree.totalW, and a leaf should be drawn at an @ position (the centre of a cell). plotTree.xOff is therefore initialized to -0.5/plotTree.totalW, i.e. the starting x position is half a cell-width to the left of the first cell. The benefit is that every @ position can then be reached simply by adding an integer multiple of 1/plotTree.totalW to plotTree.xOff.

2. The code in the plotTree function is as follows:

firstPt = (plotTree.xOff + (1.0 + float(leafNum)) / 2.0/ plotTree.totalW, plotTree.yOff)

Here the variable plotTree.xOff is the x-axis coordinate. To place the current node, you only need to know how many leaf nodes it has: the total width occupied by its leaves is float(numLeafs)/plotTree.totalW, so the current node should sit halfway across the span of its leaves, i.e. float(numLeafs)/2.0/plotTree.totalW from the left edge of that span. But since plotTree.xOff does not start from 0 (it is shifted half a cell to the left), another half cell, 1/2/plotTree.totalW, has to be added back, giving the total offset (1.0 + float(numLeafs))/2.0/plotTree.totalW. The node's x position therefore becomes plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW (a worked numeric example follows this list).

3. About the plotTree() function's parameters

plotTree(inTree, (0.5, 1.0), ' ')

The second argument passed to plotTree() is (0.5, 1.0). Because the root node has no incoming branch (and hence no branch label), the "parent" coordinate and the current node's coordinate are made to coincide, and point 2 above places the root exactly at (0.5, 1.0).
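Putting points 1 to 3 together for the first test tree (3 leaves), the root's x coordinate indeed works out to 0.5. This is a hand calculation, not code from the book:

leafNum, totalW = 3, 3.0
xOff = -0.5 / totalW                                   # initial offset from point 1
rootX = xOff + (1.0 + float(leafNum)) / 2.0 / totalW   # formula from point 2
print rootX                                            # 0.5, coinciding with the (0.5, 1.0) passed in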

Summary: by gradually increasing the x coordinate and gradually decreasing the y coordinate in this way, both the number of leaf nodes and the depth of the tree are taken into account, so the logical proportions of the drawing are well determined; even if the image size changes, the scaled tree chart remains readable.

II. Using a decision tree to predict the contact lens type

Here is a real example: using a decision tree to predict which type of contact lens a patient needs to wear. Since building the tree from data is relatively expensive, the trained classifier is first made persistent: the pickle module serializes the nested-dictionary tree to disk so that it can be reloaded later instead of being rebuilt every time it is needed:

# -*- coding: utf-8 -*-"""Created on Sat Sep 05 01:56:04 2015@author: Herbert"""import pickledef storeTree(tree, filename):    fw = open(filename, 'w')    pickle.dump(tree, fw)    fw.close()def getTree(filename):    fr = open(filename)    return pickle.load(fr)

The following code uses the decision tree to predict the contact lens type. The dataset is the contact lens dataset, which records the eye-condition observations for many patients together with the lens type recommended by the doctor; the classes are hard lenses (hard), soft lenses (soft), and not suitable for contact lenses (no lenses). The data comes from the UCI database. At the end, the code calls the createPlot() function written earlier to draw the tree chart.

# -*- coding: utf-8 -*-"""Created on Sat Sep 05 14:21:43 2015@author: Herbert"""import treeimport plotTreeimport saveTreefr = open('lenses.txt')lensesData = [data.strip().split('\t') for data in fr.readlines()]lensesLabel = ['age', 'prescript', 'astigmatic', 'tearRate']lensesTree = tree.buildTree(lensesData, lensesLabel)#print lensesDataprint lensesTreeprint plotTree.createPlot(lensesTree)

With the tree construction and drawing done in the earlier sections, feeding in a different dataset immediately gives an intuitive result: following the different branches of the decision tree, you can read off which contact lens type suits which kind of patient.

III. Summary of the decision trees used in this chapter

Returning to the algorithm level: the code above implements the classic ID3 tree-construction algorithm, which in practice has quite a few shortcomings. In particular, decision trees often run into the problem of "overfitting": too many branches, matching the training data too closely, can actually hurt the decisions made on new data. To reduce overfitting, algorithm designers usually resort to "pruning"; put simply, if a leaf adds only a little information, it can be removed.
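As a toy illustration of the pruning idea (not the pruning used by C4.5 or CART, and not code from the book): a subtree whose leaves all carry the same class label adds no information, so it can be collapsed into a single leaf without changing any prediction. A minimal sketch operating on the nested-dictionary tree representation used above:

def pruneUniformSubtrees(tree):
    # Collapse any subtree whose leaves all predict the same class label.
    firstKey = tree.keys()[0]
    secondDict = tree[firstKey]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            secondDict[key] = pruneUniformSubtrees(secondDict[key])
    values = secondDict.values()
    allLeaves = all(type(v).__name__ != 'dict' for v in values)
    if allLeaves and len(set(values)) == 1:
        return values[0]      # replace the whole subtree by its single label
    return tree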

In addition, there are several other popular decision-tree construction algorithms, such as C4.5, C5.0 and CART, which deserve further study later.

References: http://blog.sina.com.cn/s/blog_7399ad1f01014wec.html

