1. 程式人生 > >機器學習實戰(二)決策樹DT(Decision Tree、ID3演算法)

機器學習實戰(二)決策樹DT(Decision Tree、ID3演算法)

目錄

0. 前言

1. 資訊增益(ID3)

2. 決策樹(Decision Tree)

3. 實戰案例

3.1. 隱形眼鏡案例

3.2. 儲存決策樹

3.3. 決策樹畫圖表示


學習完機器學習實戰的決策樹,簡單的做個筆記。文中部分描述屬於個人消化後的理解,僅供參考。

所有程式碼和資料可以訪問 我的 github

如果這篇文章對你有一點小小的幫助,請給個關注喔~我會非常開心的~

0. 前言

決策樹(Decision Tree)的執行流程很好理解,如下圖所示(圖源:西瓜書),在樹上的每一個結點進行判斷,選擇分支,直到走到葉子結點,得出分類:

  • 優點:計算複雜度不高、輸出結果易於理解、對缺失值不敏感
  • 缺點:可能會產生過擬合
  • 適用資料型別:數值型和標稱型(數值型資料需要離散化)

決策樹構建中,目標就是找到當前哪個特徵在劃分資料時起到決定性作用,劃分資料有多種辦法,如資訊增益(ID3)、資訊增益率(C4.5)、基尼係數(CART),本篇主要介紹資訊增益(ID3演算法)。

1. 資訊增益(ID3)

首先,介紹夏農熵(entropy),熵定義為資訊的期望值,熵越高,說明資訊的混亂程度越高

Ent(D)=-\sum_{k=1}^{\left|\gamma \right|}p(k)\log_{2}p(k)

其中,D 表示資料集,k 表示資料集中的每一個類別,p(k) 表示這個屬於類別的資料佔所有資料的比例。

資訊增益(information gain)定義為原始的熵減去當前的熵,增益越大,說明當前熵越小,說明資料混亂程度越小

Gain(D,a)=Ent(D)-\sum_{v=1}^{V}\frac{\left|D^v\right|}{\left|D\right|}Ent(D^v)

其中,V 表示按照此特徵劃分的子集數量,v 表示第 v 個子集,Ent(D^v) 表示子集的資訊熵,\frac{\left|D^v\right|}{\left|D\right|} 表示子集資料佔所有資料的比例。

注:資訊增益更偏向於選擇取值較多的特徵,這是它的缺點。

2. 決策樹(Decision Tree)

演算法流程可簡單表示為:

  1. 遍歷當前資料所有的特徵,計算資訊增益最大的特徵,作為當前劃分資料的結點,並去除此特徵
  2. 對劃分後每個分支上的子集繼續進行步驟 1
     
  3. 如果當前子集內的資料都是同一型別,則停止劃分,標記葉子結點
  4. 如果子集內資料還未統一型別,而已經沒有特徵,則採用多數表決原則

3. 實戰案例

以下將展示書中的三個案例的程式碼段,所有程式碼和資料可以在github中下載:

3.1. 隱形眼鏡案例

# coding:utf-8
from math import log
import operator
import pickle

"""
隱形眼鏡案例
"""


# 計算夏農熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


# 按照給定特徵劃分資料集
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    # 只選擇第 axis 列的值為 value 的資料
    # 去除這個特徵,取資料[:axis] 和 [axis+1:] 段
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


# 選擇最好的資料集劃分方式
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    # 遍歷每一個特徵
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        # 遍歷這個特徵的所有特徵值
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            # 判斷這個子集佔所有資料集的比例
            prob = len(subDataSet) / float(len(dataSet))
            # 新的資訊熵 = 所有子集的資訊熵乘以比例再求和
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


# 多數表決原則
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


# 建立決策樹
# labels 為特徵的標籤
def createTree(dataSet, labels):
    # 獲取當前資料集最後一列的類別資訊
    classList = [example[-1] for example in dataSet]
    # 如果最後一列都是一種類別
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # 如果當前資料集沒有可劃分的特徵
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # 獲取最好的劃分資料集特徵
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # 在特徵標籤中刪除當前特徵
    del (labels[bestFeat])
    # 獲取這個特徵的列,遍歷此特徵的所有特徵值
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        # 特徵有幾個取值,這個結點就有幾個分支
        # 每個取值,都劃分出子集,遞迴建樹
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


# 分類函式
def classify(inputTree, featLabels, testVec):
    # 獲取第一個特徵
    firstStr = list(inputTree.keys())[0]
    # 獲取這個特徵下的鍵值對的值
    secondDict = inputTree[firstStr]
    # 獲取這個特徵的索引
    featIndex = featLabels.index(firstStr)
    # 遍歷每一個分支
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            # 判斷當前分支下是否還有分支
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel


if __name__ == '__main__':
    fr = open('lenses.txt')
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = createTree(lenses, lensesLabels)
    print(lensesTree)

3.2. 儲存決策樹

# 儲存樹
def storeTree(inputTree, filename):
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()


# 取出儲存的樹
def grabTree(filename):
    fr = open(filename, 'rb')
    return pickle.load(fr)

3.3. 決策樹畫圖表示

# coding:utf-8
import matplotlib.pyplot as plt

# 解決顯示中文問題
from pylab import *

mpl.rcParams['font.sans-serif'] = ['SimHei']

"""
決策樹畫圖
"""


# 建立樹的字典
def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {
                       'prescript': {'hyper': {'age': {'pre': 'no lenses', 'young': 'hard', 'presbyopic': 'no lenses'}},
                                     'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'young': 'soft', 'presbyopic': {
                       'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}}}}}}}
                   ]
    return listOfTrees[i]


# 獲取葉節點的數目
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


# 獲取樹的層數
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


# 使用文字註解繪製樹節點
decisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')


# 畫節點
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType,
                            arrowprops=arrow_args)


# 在父子節點間填充文字資訊
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


# 畫樹
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


# 主要畫圖函式
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


if __name__ == '__main__':
    myTree = retrieveTree(1)
    createPlot(myTree)

如果這篇文章對你有一點小小的幫助,請給個關注喔~我會非常開心的~