《機器學習實戰》學習筆記:繪製樹形圖&使用決策樹預測隱形眼鏡類型,
上一節實現了決策樹,但只是使用包含樹結構資訊的嵌套字典來實現,其表示形式較難理解,顯然,繪製直觀的二叉樹圖是十分必要的。Python沒有提供內建的繪製樹工具,需要自己編寫函數,結合Matplotlib庫建立自己的樹形圖。這一部分的代碼多而複雜,涉及二維座標運算;書裡的代碼雖然可用,但函數和各種變數非常多,感覺非常淩亂,同時大量使用遞迴,因此只能反覆研究,反反覆複用了一天多時間,才差不多搞懂,因此需要備忘一下。
一.繪製屬性圖
這裡使用Matplotlib的註解工具annotations實現決策樹繪製的各種細節,包括產生節點處的文字框、添加文本注釋、提供對文字著色等等。在畫一整顆樹之前,最好先掌握單個樹節點的繪製。一個簡單一實例如下:
# -*- coding: utf-8 -*-"""Created on Fri Sep 04 01:15:01 2015@author: Herbert"""import matplotlib.pyplot as pltnonLeafNodes = dict(boxstyle = "sawtooth", fc = "0.8")leafNodes = dict(boxstyle = "round4", fc = "0.8")line = dict(arrowstyle = "<-")def plotNode(nodeName, targetPt, parentPt, nodeType): createPlot.ax1.annotate(nodeName, xy = parentPt, xycoords = \ 'axes fraction', xytext = targetPt, \ textcoords = 'axes fraction', va = \ "center", ha = "center", bbox = nodeType, \ arrowprops = line)def createPlot(): fig = plt.figure(1, facecolor = 'white') fig.clf() createPlot.ax1 = plt.subplot(111, frameon = False) plotNode('nonLeafNode', (0.2, 0.1), (0.4, 0.8), nonLeafNodes) plotNode('LeafNode', (0.8, 0.1), (0.6, 0.8), leafNodes) plt.show()createPlot()
輸出結果:
該執行個體中,plotNode()
函數用於繪製箭頭和節點,該函數每調用一次,將繪製一個箭頭和一個節點。後面對於該函數有比較詳細的解釋。createPlot()
函數建立了輸出映像的對話方塊並對齊進行一些簡單的設定,同時調用了兩次plotNode()
,產生一對節點和指向節點的箭頭。
繪製整顆樹
這部分的函數和變數較多,為方便日後擴充功能,需要給出必要的標註:
# -*- coding: utf-8 -*-"""Created on Fri Sep 04 01:15:01 2015@author: Herbert"""import matplotlib.pyplot as plt# 部分代碼是對繪製圖形的一些定義,主要定義了文字框和剪頭的格式nonLeafNodes = dict(boxstyle = "sawtooth", fc = "0.8")leafNodes = dict(boxstyle = "round4", fc = "0.8")line = dict(arrowstyle = "<-")# 使用遞迴計算樹的葉子節點數目def getLeafNum(tree): num = 0 firstKey = tree.keys()[0] secondDict = tree[firstKey] for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': num += getLeafNum(secondDict[key]) else: num += 1 return num# 同葉子節點計算函數,使用遞迴計算決策樹的深度 def getTreeDepth(tree): maxDepth = 0 firstKey = tree.keys()[0] secondDict = tree[firstKey] for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': depth = getTreeDepth(secondDict[key]) + 1 else: depth = 1 if depth > maxDepth: maxDepth = depth return maxDepth# 在前面例子已實現的函數,用於注釋形式繪製節點和箭頭def plotNode(nodeName, targetPt, parentPt, nodeType): createPlot.ax1.annotate(nodeName, xy = parentPt, xycoords = \ 'axes fraction', xytext = targetPt, \ textcoords = 'axes fraction', va = \ "center", ha = "center", bbox = nodeType, \ arrowprops = line)# 用於繪製剪頭線上的標註,涉及座標計算,其實就是兩個點座標的中心處添加標註 def insertText(targetPt, parentPt, info): xCoord = (parentPt[0] - targetPt[0]) / 2.0 + targetPt[0] yCoord = (parentPt[1] - targetPt[1]) / 2.0 + targetPt[1] createPlot.ax1.text(xCoord, yCoord, info)# 實現整個樹的繪製邏輯和座標運算,使用的遞迴,重要的函數# 其中兩個全域變數plotTree.xOff和plotTree.yOff# 用於追蹤已繪製的節點位置,並放置下個節點的恰當位置def plotTree(tree, parentPt, info): # 分別調用兩個函數算出樹的葉子節點數目和樹的深度 leafNum = getLeafNum(tree) treeDepth = getTreeDepth(tree) firstKey = tree.keys()[0] # the text label for this node firstPt = (plotTree.xOff + (1.0 + float(leafNum)) / 2.0/plotTree.totalW,\ plotTree.yOff) insertText(firstPt, parentPt, info) plotNode(firstKey, firstPt, parentPt, nonLeafNodes) secondDict = tree[firstKey] plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': plotTree(secondDict[key], firstPt, str(key)) else: plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), \ firstPt, leafNodes) insertText((plotTree.xOff, plotTree.yOff), firstPt, str(key)) plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD# 以下函數執行真正的繪圖操作,plotTree()函數只是樹的一些邏輯和座標運算def createPlot(inTree): fig = plt.figure(1, facecolor = 'white') fig.clf() createPlot.ax1 = plt.subplot(111, frameon = False) #, **axprops) # 全域變數plotTree.totalW和plotTree.totalD # 用於儲存樹的寬度和樹的深度 plotTree.totalW = float(getLeafNum(inTree)) plotTree.totalD = float(getTreeDepth(inTree)) plotTree.xOff = -0.5 / plotTree.totalW plotTree.yOff = 1.0 plotTree(inTree, (0.5, 1.0), ' ') plt.show()# 一個小的測試集def retrieveTree(i): listOfTrees = [{'no surfacing':{0: 'no', 1:{'flippers':{0:'no', 1:'yes'}}}},\ {'no surfacing':{0: 'no', 1:{'flippers':{0:{'head':{0:'no', \ 1:'yes'}}, 1:'no'}}}}] return listOfTrees[i]createPlot(retrieveTree(1)) # 調用測試集中一棵樹進行繪製
retrieveTree()
函數中包含兩顆獨立的樹,分別輸入參數即可返回樹的參數tree
,最後執行createPlot(tree)
即得到畫圖的結果,如下所示:
書中關於遞迴計算樹的葉子節點和深度這部分十分簡單,在編寫繪製屬性圖的函數時,難度在於這本書中一些繪圖座標的取值以及在計算節點座標所作的處理,書中對於這部分的解釋比較散亂。部落格:http://www.cnblogs.com/fantasy01/p/4595902.html 給出了十分詳盡的解釋,包括座標的求解和公式的分析,以下只摘取一部分作為瞭解:
這裡說一下具體繪製的時候是利用自訂,如:
這裡繪圖,作者選取了一個很聰明的方式,並不會因為樹的節點的增減和深度的增減而導致繪製出來的圖形出現問題,當然不能太密集。這裡利用整 棵樹的葉子節點數作為份數將整個x軸的長度進行平均切分,利用樹的深度作為份數將y軸長度作平均切分,並利用plotTree.xOff作為最近繪製的一 個葉子節點的x座標,當再一次繪製葉子節點座標的時候才會plotTree.xOff才會發生改變;用plotTree.yOff作為當前繪製的深 度,plotTree.yOff是在每遞迴一層就會減一份(上邊所說的按份平均切分),其他時候是利用這兩個座標點去計算非葉子節點,這兩個參數其實就可 以確定一個點座標,這個座標確定的時候就是繪製節點的時候
plotTree
函數的整體步驟分為以下三步:
以下是plotTree
和createPlot
函數的詳細解析,因此把兩個函數的代碼單獨拿出來了:
# 實現整個樹的繪製邏輯和座標運算,使用的遞迴,重要的函數# 其中兩個全域變數plotTree.xOff和plotTree.yOff# 用於追蹤已繪製的節點位置,並放置下個節點的恰當位置def plotTree(tree, parentPt, info): # 分別調用兩個函數算出樹的葉子節點數目和樹的深度 leafNum = getLeafNum(tree) treeDepth = getTreeDepth(tree) firstKey = tree.keys()[0] # the text label for this node firstPt = (plotTree.xOff + (1.0 + float(leafNum)) / 2.0/plotTree.totalW,\ plotTree.yOff) insertText(firstPt, parentPt, info) plotNode(firstKey, firstPt, parentPt, nonLeafNodes) secondDict = tree[firstKey] plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': plotTree(secondDict[key], firstPt, str(key)) else: plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), \ firstPt, leafNodes) insertText((plotTree.xOff, plotTree.yOff), firstPt, str(key)) plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD# 以下函數執行真正的繪圖操作,plotTree()函數只是樹的一些邏輯和座標運算def createPlot(inTree): fig = plt.figure(1, facecolor = 'white') fig.clf() createPlot.ax1 = plt.subplot(111, frameon = False) #, **axprops) # 全域變數plotTree.totalW和plotTree.totalD # 用於儲存樹的寬度和樹的深度 plotTree.totalW = float(getLeafNum(inTree)) plotTree.totalD = float(getTreeDepth(inTree)) plotTree.xOff = -0.5 / plotTree.totalW plotTree.yOff = 1.0 plotTree(inTree, (0.5, 1.0), ' ') plt.show()
首先代碼對整個畫圖區間根據葉子節點數和深度進行平均切分,並且x
和y
軸的總長度均為1
,如同:
解釋如下:
1.圖中的方形為非葉子節點的位置,@
是葉子節點的位置,因此的一個表格的長度應該為: 1/plotTree.totalW
,但是葉子節點的位置應該為@
所在位置,則在開始的時候 plotTree.xOff
的賦值為: -0.5/plotTree.totalW
,即意為開始x
軸位置為第一個表格左邊的半個表格距離位置,這樣作的好處是在以後確定@
位置時候可以直接加整數倍的 1/plotTree.totalW
。
2.plotTree函數中的一句代碼如下:
firstPt = (plotTree.xOff + (1.0 + float(leafNum)) / 2.0/ plotTree.totalW, plotTree.yOff)
其中,變數plotTree.xOff
即為最近繪製的一個葉子節點的x
軸座標,在確定當前節點位置時每次只需確定當前節點有幾個葉子節點,因此其葉子節點所佔的總距離就確定了即為: float(numLeafs)/plotTree.totalW
,因此當前節點的位置即為其所有葉子節點所佔距離的中間即一半為: float(numLeafs)/2.0/plotTree.totalW
,但是由於開始plotTree.xOff
賦值並非從0
開始,而是左移了半個表格,因此還需加上半個表格距離即為: 1/2/plotTree.totalW
,則加起來便為: (1.0 + float(numLeafs))/2.0/plotTree.totalW
,因此位移量確定,則x
軸的位置變為: plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW
3.關於plotTree()
函數的參數
plotTree(inTree, (0.5, 1.0), ' ')
對plotTree()
函數的第二個參數賦值為(0.5, 1.0)
,因為開始的根節點並不用劃線,因此父節點和當前節點的位置需要重合,利用2中的確定當前節點的位置為(0.5, 1.0)
。
總結:利用這樣的逐漸增加x
軸的座標,以及逐漸降低y
軸的座標能能夠很好的將樹的葉子節點數和深度考慮進去,因此圖的邏輯比例就很好的確定了,即使映像尺寸改變,我們仍然可以看到按比例繪製的樹形圖。
二.使用決策樹預測隱形眼鏡類型
這裡實現一個例子,即利用決策樹預測一個患者需要佩戴的隱形眼鏡類型。以下是整個預測的大體步驟:
# -*- coding: utf-8 -*-"""Created on Sat Sep 05 01:56:04 2015@author: Herbert"""import pickledef storeTree(tree, filename): fw = open(filename, 'w') pickle.dump(tree, fw) fw.close()def getTree(filename): fr = open(filename) return pickle.load(fr)
以下代碼實現了決策樹預測隱形眼鏡模型的執行個體,使用的資料集是隱形眼鏡資料集,它包含很多患者的眼部狀況的觀察條件以及醫生推薦的隱形眼鏡類型,其中隱形眼鏡類型包括:硬材質(hard)
、軟材質(soft)
和不適合佩戴隱形眼鏡(no lenses)
, 資料來源於UCI資料庫。代碼最後調用了之前準備好的createPlot()
函數繪製樹形圖。
# -*- coding: utf-8 -*-"""Created on Sat Sep 05 14:21:43 2015@author: Herbert"""import treeimport plotTreeimport saveTreefr = open('lenses.txt')lensesData = [data.strip().split('\t') for data in fr.readlines()]lensesLabel = ['age', 'prescript', 'astigmatic', 'tearRate']lensesTree = tree.buildTree(lensesData, lensesLabel)#print lensesDataprint lensesTreeprint plotTree.createPlot(lensesTree)
可以看到,前期實現了決策樹的構建和繪製,使用不同的資料集都可以得到很直觀的結果,可以看到,沿著決策樹的不同分支,可以得到不同患者需要佩戴的隱形眼鏡的類型。
三.關於本章使用的決策樹的總結
回到決策樹的演算法層面,以上代碼的實現基於ID3決策樹構造演算法,它是一個非常經典的演算法,但其實缺點也不少。實際上決策樹的使用中常常會遇到一個問題,即“過度匹配”。有時候,過多的分支選擇或匹配選項會給決策帶來負面的效果。為了減少過度匹配的問題,通常演算法設計者會在一些實際情況中選擇“剪枝”。簡單說來,如果葉子節點只能增加少許資訊,則可以刪除該節點。
另外,還有幾種目前很流行的決策樹構造演算法:C4.5、C5.0和CART,後期需繼續深入研究。
參考資料:http://blog.sina.com.cn/s/blog_7399ad1f01014wec.html
著作權聲明:本文為博主原創文章,未經博主允許不得轉載。