機器學習實戰(二)決策樹DT(Decision Tree、ID3演算法)
阿新 • • 發佈:2018-11-02
目錄
學習完機器學習實戰的決策樹,簡單的做個筆記。文中部分描述屬於個人消化後的理解,僅供參考。
所有程式碼和資料可以訪問 我的 github
如果這篇文章對你有一點小小的幫助,請給個關注喔~我會非常開心的~
0. 前言
決策樹(Decision Tree)的執行流程很好理解,如下圖所示(圖源:西瓜書),在樹上的每一個結點進行判斷,選擇分支,直到走到葉子結點,得出分類:
- 優點:計算複雜度不高、輸出結果易於理解、對缺失值不敏感
- 缺點:可能會產生過擬合
- 適用資料型別:數值型和標稱型(數值型資料需要離散化)
決策樹構建中,目標就是找到當前哪個特徵在劃分資料時起到決定性作用,劃分資料有多種辦法,如資訊增益(ID3)、資訊增益率(C4.5)、基尼係數(CART),本篇主要介紹資訊增益(ID3演算法)。
1. 資訊增益(ID3)
首先,介紹夏農熵(entropy),熵定義為資訊的期望值,熵越高,說明資訊的混亂程度越高:
其中, 表示資料集, 表示資料集中的每一個類別, 表示這個屬於類別的資料佔所有資料的比例。
資訊增益(information gain)定義為原始的熵減去當前的熵,增益越大,說明當前熵越小,說明資料混亂程度越小:
其中, 表示按照此特徵劃分的子集數量, 表示第 個子集, 表示子集的資訊熵, 表示子集資料佔所有資料的比例。
注:資訊增益更偏向於選擇取值較多的特徵,這是它的缺點。
2. 決策樹(Decision Tree)
演算法流程可簡單表示為:
- 遍歷當前資料所有的特徵,計算資訊增益最大的特徵,作為當前劃分資料的結點,並去除此特徵
- 對劃分後每個分支上的子集繼續進行步驟
- 如果當前子集內的資料都是同一型別,則停止劃分,標記葉子結點
- 如果子集內資料還未統一型別,而已經沒有特徵,則採用多數表決原則
3. 實戰案例
以下將展示書中的三個案例的程式碼段,所有程式碼和資料可以在github中下載:
3.1. 隱形眼鏡案例
# coding:utf-8
from math import log
import operator
import pickle
"""
隱形眼鏡案例
"""
# 計算夏農熵
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries
shannonEnt -= prob * log(prob, 2)
return shannonEnt
# 按照給定特徵劃分資料集
def splitDataSet(dataSet, axis, value):
retDataSet = []
# 只選擇第 axis 列的值為 value 的資料
# 去除這個特徵,取資料[:axis] 和 [axis+1:] 段
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis + 1:])
retDataSet.append(reducedFeatVec)
return retDataSet
# 選擇最好的資料集劃分方式
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1
# 遍歷每一個特徵
for i in range(numFeatures):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
# 遍歷這個特徵的所有特徵值
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
# 判斷這個子集佔所有資料集的比例
prob = len(subDataSet) / float(len(dataSet))
# 新的資訊熵 = 所有子集的資訊熵乘以比例再求和
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
# 多數表決原則
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(),
key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
# 建立決策樹
# labels 為特徵的標籤
def createTree(dataSet, labels):
# 獲取當前資料集最後一列的類別資訊
classList = [example[-1] for example in dataSet]
# 如果最後一列都是一種類別
if classList.count(classList[0]) == len(classList):
return classList[0]
# 如果當前資料集沒有可劃分的特徵
if len(dataSet[0]) == 1:
return majorityCnt(classList)
# 獲取最好的劃分資料集特徵
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel: {}}
# 在特徵標籤中刪除當前特徵
del (labels[bestFeat])
# 獲取這個特徵的列,遍歷此特徵的所有特徵值
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
# 特徵有幾個取值,這個結點就有幾個分支
# 每個取值,都劃分出子集,遞迴建樹
myTree[bestFeatLabel][value] = createTree(
splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
# 分類函式
def classify(inputTree, featLabels, testVec):
# 獲取第一個特徵
firstStr = list(inputTree.keys())[0]
# 獲取這個特徵下的鍵值對的值
secondDict = inputTree[firstStr]
# 獲取這個特徵的索引
featIndex = featLabels.index(firstStr)
# 遍歷每一個分支
for key in secondDict.keys():
if testVec[featIndex] == key:
# 判斷當前分支下是否還有分支
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel
if __name__ == '__main__':
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)
print(lensesTree)
3.2. 儲存決策樹
# 儲存樹
def storeTree(inputTree, filename):
fw = open(filename, 'wb')
pickle.dump(inputTree, fw)
fw.close()
# 取出儲存的樹
def grabTree(filename):
fr = open(filename, 'rb')
return pickle.load(fr)
3.3. 決策樹畫圖表示
# coding:utf-8
import matplotlib.pyplot as plt
# 解決顯示中文問題
from pylab import *
mpl.rcParams['font.sans-serif'] = ['SimHei']
"""
決策樹畫圖
"""
# 建立樹的字典
def retrieveTree(i):
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {
'prescript': {'hyper': {'age': {'pre': 'no lenses', 'young': 'hard', 'presbyopic': 'no lenses'}},
'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'young': 'soft', 'presbyopic': {
'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}}}}}}}
]
return listOfTrees[i]
# 獲取葉節點的數目
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
# 獲取樹的層數
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
# 使用文字註解繪製樹節點
decisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')
# 畫節點
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va='center', ha='center', bbox=nodeType,
arrowprops=arrow_args)
# 在父子節點間填充文字資訊
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
# 畫樹
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
# 主要畫圖函式
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
if __name__ == '__main__':
myTree = retrieveTree(1)
createPlot(myTree)
如果這篇文章對你有一點小小的幫助,請給個關注喔~我會非常開心的~