Machine Learning in Action study notes: drawing a tree diagram & predicting contact lens types using decision trees

Source: Internet
Author: User

The previous section implemented the decision tree, but only as nested dictionaries that hold the tree structure. That representation is hard to read, so an intuitive tree diagram is clearly needed. Python has no built-in tool for drawing trees, so we have to write our own functions and use the Matplotlib library to draw the diagram. This part involves a lot of fairly complex code with two-dimensional coordinate calculations. The code in the book works, but it has many functions and variables, feels messy, and relies heavily on recursion, so I had to go over it repeatedly. It took me more than a day of going back and forth to roughly understand it, which is why I am writing these notes.

I. Drawing the tree diagram

Here Matplotlib's annotation tool is used to draw the details of the decision tree, including generating the text box at each node, adding text annotations, and coloring the text. Before drawing a whole tree, it is best to master drawing individual tree nodes first. A simple example:

# -*- coding: utf-8 -*-
"""
Created on Fri Sep 01:15:01 2015
@author: herbert
"""
import matplotlib.pyplot as plt

# Text-box styles for the two node types, and the arrow style
nonLeafNodes = dict(boxstyle="sawtooth", fc="0.8")
leafNodes = dict(boxstyle="round4", fc="0.8")
line = dict(arrowstyle="<-")

def plotNode(nodeName, targetPt, parentPt, nodeType):
    # Draw the node text at targetPt with an arrow coming from parentPt
    createPlot.ax1.annotate(nodeName, xy=parentPt, xycoords='axes fraction',
                            xytext=targetPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=line)

def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('nonLeafNode', (0.2, 0.1), (0.4, 0.8), nonLeafNodes)
    plotNode('leafNode', (0.8, 0.1), (0.6, 0.8), leafNodes)
    plt.show()

createPlot()

Output Result:

In this example, the plotNode() function draws arrows and nodes: each call draws one node and one arrow. The function is explained in more detail later. The createPlot() function creates the figure window for the output image, applies some simple settings, and calls plotNode() twice, producing two nodes, each with an arrow pointing to it.

Draw an entire tree

This part has many functions and variables, so the code is annotated where necessary to make it easier to extend later:

# -*- coding: utf-8 -*-
"""
Created on Fri Sep 01:15:01 2015
@author: herbert
"""
import matplotlib.pyplot as plt

# Drawing definitions: text-box styles and the arrow style
nonLeafNodes = dict(boxstyle="sawtooth", fc="0.8")
leafNodes = dict(boxstyle="round4", fc="0.8")
line = dict(arrowstyle="<-")

# Recursively count the leaf nodes of the tree
def getLeafNum(tree):
    num = 0
    firstKey = list(tree.keys())[0]   # list() so this also runs on Python 3
    secondDict = tree[firstKey]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            num += getLeafNum(secondDict[key])
        else:
            num += 1
    return num

# Same structure as the leaf count: recursively compute the depth of the tree
def getTreeDepth(tree):
    maxDepth = 0
    firstKey = list(tree.keys())[0]
    secondDict = tree[firstKey]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            depth = getTreeDepth(secondDict[key]) + 1
        else:
            depth = 1
        if depth > maxDepth:
            maxDepth = depth
    return maxDepth

# From the previous example: draw a node and its arrow as an annotation
def plotNode(nodeName, targetPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeName, xy=parentPt, xycoords='axes fraction',
                            xytext=targetPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=line)

# Label an arrow: place the text at the midpoint between the two points
def insertText(targetPt, parentPt, info):
    xCoord = (parentPt[0] - targetPt[0]) / 2.0 + targetPt[0]
    yCoord = (parentPt[1] - targetPt[1]) / 2.0 + targetPt[1]
    createPlot.ax1.text(xCoord, yCoord, info)

# The important function: drawing logic and coordinate calculations for the
# whole tree, done recursively. The two global variables plotTree.xOff and
# plotTree.yOff track the position of the last drawn node so that the next
# node lands in the right place.
def plotTree(tree, parentPt, info):
    # Number of leaf nodes and depth of this (sub)tree
    leafNum = getLeafNum(tree)
    treeDepth = getTreeDepth(tree)
    firstKey = list(tree.keys())[0]   # the text label for this node
    firstPt = (plotTree.xOff + (1.0 + float(leafNum)) / 2.0 / plotTree.totalW,
               plotTree.yOff)
    insertText(firstPt, parentPt, info)
    plotNode(firstKey, firstPt, parentPt, nonLeafNodes)
    secondDict = tree[firstKey]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], firstPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     firstPt, leafNodes)
            insertText((plotTree.xOff, plotTree.yOff), firstPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

# This function performs the actual drawing; plotTree() only handles the
# tree logic and the coordinate calculations.
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)  #, **axprops)
    # Global variables plotTree.totalW and plotTree.totalD store the width
    # and the depth of the tree
    plotTree.totalW = float(getLeafNum(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

# A small set of test trees
def retrieveTree(i):
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return listOfTrees[i]

createPlot(retrieveTree(1))  # draw one of the test trees

The retrieveTree() function holds two separate trees and returns the one selected by its input parameter. Calling createPlot(tree) on the returned tree produces the final drawing.
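As a quick sanity check before any drawing, the two counting functions can be run on the two test trees. This is a minimal sketch: the functions are restated in compact form so that the snippet runs on its own, and `next(iter(tree))` is used here as a Python 3 friendly way to get the first key.

```python
def getLeafNum(tree):
    """Count the leaves: every value that is not a dict is a leaf."""
    firstKey = next(iter(tree))
    return sum(getLeafNum(child) if isinstance(child, dict) else 1
               for child in tree[firstKey].values())

def getTreeDepth(tree):
    """Depth counted in decision nodes along the longest branch."""
    firstKey = next(iter(tree))
    return max(getTreeDepth(child) + 1 if isinstance(child, dict) else 1
               for child in tree[firstKey].values())

# The two test trees returned by retrieveTree(0) and retrieveTree(1)
trees = [
    {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
    {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
]

for t in trees:
    print(getLeafNum(t), getTreeDepth(t))
# prints:
# 3 2
# 4 3
```

These two numbers are exactly what createPlot() stores in plotTree.totalW and plotTree.totalD to scale the drawing.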

The part of the book on recursively computing the number of leaf nodes and the depth of the tree is very simple. The difficulty in writing the tree-drawing functions lies in the values the book chooses for some drawing coordinates and in how the node coordinates are computed; the book's explanation of this is quite scattered. The blog post at http://www.cnblogs.com/fantasy01/p/4595902.html gives a very detailed explanation, including how the coordinates are solved and an analysis of the formulas. Only a portion is excerpted below as an aid to understanding:

The drawing itself uses a custom coordinate scheme, described as follows:

For the drawing, the author chose a very clever approach: the picture does not break as tree nodes are added or removed or as the depth changes, and it does not become too dense. The whole length of the x axis is divided evenly by the number of leaf nodes of the tree, and the length of the y axis is divided evenly by the depth of the tree. plotTree.xOff holds the x coordinate of the most recently drawn leaf node and changes only when another leaf node is drawn. plotTree.yOff holds the depth currently being drawn; it is reduced by one share (of the even split above) at each level of recursion. The coordinates of non-leaf nodes are computed from these two values, which together determine a point; that point is fixed at the moment the node is drawn.

Overall, the plotTree function works in the following three steps:

    1. Draw the current node itself

    2. If a child node is not a leaf node, recurse into it

    3. If a child node is a leaf node, draw that node
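The three steps above can be sketched without Matplotlib by recording coordinates instead of drawing. This is only an illustration under stated assumptions: `layoutTree` is a hypothetical helper that is not in the book or in these notes, and it keeps xOff and yOff in a small dictionary instead of function attributes.

```python
def getLeafNum(tree):
    firstKey = next(iter(tree))
    return sum(getLeafNum(c) if isinstance(c, dict) else 1
               for c in tree[firstKey].values())

def getTreeDepth(tree):
    firstKey = next(iter(tree))
    return max(getTreeDepth(c) + 1 if isinstance(c, dict) else 1
               for c in tree[firstKey].values())

def layoutTree(tree):
    """Hypothetical helper: return (label, (x, y), kind) tuples in drawing order."""
    totalW = float(getLeafNum(tree))
    totalD = float(getTreeDepth(tree))
    state = {'xOff': -0.5 / totalW, 'yOff': 1.0}
    placed = []

    def visit(subtree):
        firstKey = next(iter(subtree))
        leafNum = getLeafNum(subtree)
        # Step 1: place the current decision node, centered over its leaves
        x = state['xOff'] + (1.0 + leafNum) / 2.0 / totalW
        placed.append((firstKey, (x, state['yOff']), 'node'))
        state['yOff'] -= 1.0 / totalD            # go one level down
        for child in subtree[firstKey].values():
            if isinstance(child, dict):
                visit(child)                      # Step 2: recurse
            else:
                state['xOff'] += 1.0 / totalW     # Step 3: next leaf slot
                placed.append((child, (state['xOff'], state['yOff']), 'leaf'))
        state['yOff'] += 1.0 / totalD            # back up one level

    visit(tree)
    return placed

demoTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
placed = layoutTree(demoTree)
for label, (x, y), kind in placed:
    print(kind, label, round(x, 3), round(y, 3))
```

For this tree the root lands at x = 0.5 and the four leaves at x = 0.125, 0.375, 0.625 and 0.875, that is, at the centers of four equal cells.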

The following is a detailed analysis of the plotTree and createPlot functions, so the code of these two functions is shown again separately:

# The important function: drawing logic and coordinate calculations for the
# whole tree, done recursively. The two global variables plotTree.xOff and
# plotTree.yOff track the position of the last drawn node so that the next
# node lands in the right place.
def plotTree(tree, parentPt, info):
    # Number of leaf nodes and depth of this (sub)tree
    leafNum = getLeafNum(tree)
    treeDepth = getTreeDepth(tree)
    firstKey = list(tree.keys())[0]   # the text label for this node
    firstPt = (plotTree.xOff + (1.0 + float(leafNum)) / 2.0 / plotTree.totalW,
               plotTree.yOff)
    insertText(firstPt, parentPt, info)
    plotNode(firstKey, firstPt, parentPt, nonLeafNodes)
    secondDict = tree[firstKey]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], firstPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     firstPt, leafNodes)
            insertText((plotTree.xOff, plotTree.yOff), firstPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

# This function performs the actual drawing; plotTree() only handles the
# tree logic and the coordinate calculations.
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)  #, **axprops)
    # Global variables plotTree.totalW and plotTree.totalD store the width
    # and the depth of the tree
    plotTree.totalW = float(getLeafNum(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

First, the code divides the drawing area evenly by the number of leaves and by the depth of the tree. The total length of both the x axis and the y axis is 1, as follows:

The explanation is as follows:

1. In the figure, the squares mark the positions of non-leaf nodes and the @ signs mark the positions of leaf nodes. The width of one cell is therefore 1/plotTree.totalW, but a leaf node should sit at the @ position, the center of its cell. That is why plotTree.xOff is initialized to -0.5/plotTree.totalW: the starting x position is half a cell width to the left of the first cell. The benefit is that every later @ position is reached simply by adding an integer multiple of 1/plotTree.totalW.

2. One line of code in the plotTree function reads as follows:

firstPt = (plotTree.xOff + (1.0 + float(leafNum)) / 2.0 / plotTree.totalW, plotTree.yOff)

Here plotTree.xOff is the x coordinate of the most recently drawn leaf node. Once the number of leaf nodes under the current node is known, the total width they occupy is float(leafNum)/plotTree.totalW, so the current node belongs in the middle of that span, at an offset of float(leafNum)/2.0/plotTree.totalW. However, plotTree.xOff does not sit at the left edge of that span: it sits at the center of the previous leaf's cell (initially half a cell to the left of 0), which is half a cell further left. So another half cell, 1/2/plotTree.totalW, has to be added. Together the offset is (1.0 + float(leafNum))/2.0/plotTree.totalW, and the x position of the node becomes plotTree.xOff + (1.0 + float(leafNum))/2.0/plotTree.totalW.

3. About the arguments of the plotTree() function

plotTree(inTree, (0.5, 1.0), '')

The second parameter of plotTree() is set to (0.5, 1.0) because no line is drawn to the root node, so the parent position and the position of the current node must coincide. Applying the formula from item 2 to the root gives exactly (0.5, 1.0) for the current node.
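The three points above can be checked with a few lines of arithmetic. This is a small sketch for a hypothetical tree with four leaves; `nodeX` is a helper name introduced only for this illustration and simply wraps the formula from item 2.

```python
totalW = 4.0                  # hypothetical tree with four leaves
startX = -0.5 / totalW        # item 1: start half a cell left of the first cell

# Leaf centers: each new leaf advances xOff by one cell width
leafXs = []
xOff = startX
for _ in range(int(totalW)):
    xOff += 1.0 / totalW
    leafXs.append(xOff)
print(leafXs)                 # [0.125, 0.375, 0.625, 0.875]

def nodeX(xOffBefore, leafNum, totalW):
    """x position of a decision node, as in the firstPt formula (item 2)."""
    return xOffBefore + (1.0 + float(leafNum)) / 2.0 / totalW

# A subtree drawn after the first leaf and covering the remaining three leaves
# ends up at the mean of those three leaf centers
subtreeX = nodeX(leafXs[0], 3, totalW)
print(subtreeX, sum(leafXs[1:]) / 3)   # 0.625 0.625

# Item 3: the root covers all leaves and starts from the initial xOff
rootX = nodeX(startX, totalW, totalW)
print(rootX)                  # 0.5
```

Because the root always evaluates to x = 0.5, passing (0.5, 1.0) as its parent point makes the arrow to the root collapse to zero length.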

Summary: Gradually increasing the x coordinate and gradually decreasing the y coordinate in this way takes both the number of leaves and the depth of the tree into account, so the logical proportions of the diagram are easy to determine. Even if the image size changes, the tree is still drawn proportionally.

II. Using decision trees to predict contact lens types

Here is an example of using decision trees to predict the type of contact lenses a patient needs to wear. The general steps of the overall prediction are:

    1. Collect data: use the small dataset provided in the book

    2. Prepare data: preprocess the text data, for example by parsing the data rows

    3. Analyze data: quickly examine the data and use the createPlot() function to draw the final tree diagram

    4. Train the decision tree: use the createTree() function

    5. Test the decision tree: write a simple test function to verify the decision tree's output and the drawing

    6. Use the decision tree: store the trained decision tree so that it can be used at any time

      Create a new script file, savetree.py, that saves the trained decision tree to disk; this uses Python's pickle module to serialize the object. The storeTree() function stores the tree in the file filename (.txt) in the current directory, and getTree(filename) reads the decision tree back from the file filename (.txt) in the current directory.

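# -*- coding: utf-8 -*-
import pickle

def storeTree(tree, filename):
    # pickle writes bytes, so open the file in binary mode
    fw = open(filename, 'wb')
    pickle.dump(tree, fw)
    fw.close()

def getTree(filename):
    # read the pickled tree back, again in binary mode
    fr = open(filename, 'rb')
    tree = pickle.load(fr)
    fr.close()
    return tree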
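A quick round trip shows that a stored tree comes back unchanged. This is a minimal sketch: the two functions are restated with `with` blocks so that the snippet runs on its own, and the temporary directory and the file name `tree.txt` are assumptions made only for this example.

```python
import os
import pickle
import tempfile

def storeTree(tree, filename):
    # Binary mode is required for pickle on Python 3
    with open(filename, 'wb') as fw:
        pickle.dump(tree, fw)

def getTree(filename):
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
path = os.path.join(tempfile.mkdtemp(), 'tree.txt')  # hypothetical location
storeTree(tree, path)
restored = getTree(path)
print(restored == tree)  # True
```

Note that pickle files should only be loaded from sources you trust, since unpickling can execute arbitrary code.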

The following code implements an example of a decision tree predicting contact lens types. It uses the contact lens dataset, which contains observations of many patients' eye conditions together with the type of contact lens recommended by the physician. The lens types are hard material (hard), soft material (soft), and not suitable for contact lenses (no lenses). The data comes from the UCI database. At the end, the code calls the createPlot() function written earlier to draw the tree diagram.

# -*- coding: utf-8 -*-
import tree
import plottree
import savetree

# Each row of lenses.txt is tab-separated: four features, then the lens type
fr = open('lenses.txt')
lensesData = [data.strip().split('\t') for data in fr.readlines()]
fr.close()
lensesLabel = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = tree.buildTree(lensesData, lensesLabel)
# print(lensesData)
print(lensesTree)
# createPlot() shows the figure and returns nothing, so it is not printed
plottree.createPlot(lensesTree)

As you can see, the decision tree construction and drawing code implemented earlier gives very intuitive results on a different dataset. Following the different branches of the decision tree leads to the type of contact lenses that each kind of patient needs to wear.
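Following the branches can also be done in code. The sketch below is an assumption-laden illustration, not code from these notes: `classify` is a helper written for this example, and it walks the small 'no surfacing' test tree from the drawing section rather than the real lens tree, which depends on lenses.txt.

```python
def classify(tree, featLabels, testVec):
    """Follow branches from the root until a non-dict value (a leaf) is reached."""
    firstKey = next(iter(tree))                 # feature tested at this node
    featIndex = featLabels.index(firstKey)      # where that feature sits in testVec
    child = tree[firstKey][testVec[featIndex]]  # branch chosen by the feature value
    if isinstance(child, dict):
        return classify(child, featLabels, testVec)
    return child

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
labels = ['no surfacing', 'flippers']

print(classify(tree, labels, [1, 0]))  # no
print(classify(tree, labels, [1, 1]))  # yes
```

A feature value that never appeared during training has no branch in the dictionary, so this sketch would raise a KeyError for it.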

III. Summary of the decision trees used in this chapter

Returning to the algorithm level, the code above is based on the ID3 decision tree construction algorithm. It is a classic algorithm, but it has quite a few shortcomings. In practice, decision trees often run into the problem of overfitting: too many branches or matching options can hurt the decisions. To reduce overfitting, algorithm designers usually prune the tree in practical situations. Simply put, if a leaf node adds only a little information, that node can be deleted.
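What "adds only a little information" means can be made concrete with the information gain that ID3 uses to choose splits. The following is a rough sketch, not the book's pruning procedure: `entropy` and `infoGain` are helper names chosen for this illustration, and the threshold `minGain` is a made-up value.

```python
from collections import Counter
from math import log2

def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    total = len(labels)
    return -sum((n / total) * log2(n / total) for n in Counter(labels).values())

def infoGain(parentLabels, childLabelGroups):
    """Entropy of the parent minus the weighted entropy of the child groups."""
    total = len(parentLabels)
    remainder = sum(len(g) / total * entropy(g) for g in childLabelGroups)
    return entropy(parentLabels) - remainder

parent = ['yes', 'yes', 'no', 'no']
usefulSplit = infoGain(parent, [['yes', 'yes'], ['no', 'no']])
uselessSplit = infoGain(parent, [['yes', 'no'], ['yes', 'no']])
print(usefulSplit)    # 1.0
print(uselessSplit)   # 0.0

minGain = 0.1  # hypothetical threshold: splits below it are pruning candidates
print(uselessSplit < minGain)  # True
```

A split whose gain falls below the threshold separates the classes hardly better than no split at all, which is the kind of node that pruning removes.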

In addition, there are several other popular decision tree construction algorithms, namely C4.5, C5.0 and CART, which I still need to study further later.

Reference: http://blog.sina.com.cn/s/blog_7399ad1f01014wec.html

Copyright notice: This is an original article by the blogger and may not be reproduced without the blogger's permission.
