C++單刷《機器學習實戰》之二——決策樹
演算法概述:決策樹是用於分類的一種常用方法,根據資料集特徵值的不同,構造決策樹來將資料集不斷分成子資料集,直至決策樹下的每個分支都是同一類或用完所有的特徵值。
決策樹的一般流程:
(1)收集資料
(2)準備資料:樹構造演算法只適用於標稱型資料,因此數值型資料必須離散化,最好轉為bool型別。
(3)分析資料:尋找能夠最好地劃分資料集的特徵。
(4)訓練:構造樹的資料結構。
(5)測試:使用決策樹計算錯誤率。
(6)使用;將訓練好的決策樹用於分類。
1.獲取資料
對訓練集的要求:需要了解資料的特徵數量,每一個特徵所對應的特徵值,每一個數據的標籤。
在C++中我們不再通過矩陣來存取資料集,通過以下結構體來描述每一個數據,並將資料存入vector中:
struct data
{
int featNum; //特徵數量
vector<bool> features; //特徵值
string label; //標籤
};
vector<data> dataset; //資料集
書中給出了一個海洋生物資料的表格:
參考該表格建立資料集:
data d1; d1.featNum = 2; d1.features.push_back(true); d1.features.push_back(true); d1.label = "yes"; dataset.push_back(d1);
2.資訊增益
決策樹通過資料的某一項特徵來將資料集分類,有的特徵值可以較好的對資料分類,有些則不行。例如為了將貓和狗區分開來,選取是否會游泳,是否會爬樹這些特徵則分類效果較好,若選擇是否四肢著地奔跑這個特徵則分類效果會很差。構造決策樹的第一步就是要選擇能使得分類效果最好的特徵進行第一步分類,在分好的子資料集中再選擇次好的特徵值進一步分類,直至決策樹下每一分支都為同一類或用完所有的特徵值。
下圖為書中一個決策樹的流程圖,用於對郵件進行分類:
如何判斷一個特徵值用於分類的效果好壞,我們需要了解資訊增益的概念,並簡單介紹一下夏農熵,不過多討論資訊理論的內容。
夏農熵用於描述資訊的無序程度,舉個栗子,箱子
我們分類的最終目的,就是希望在資料集中取出的任一樣本,能以較大概率判斷其屬於某一類,使夏農熵達到最小。分類之前和之後資料的夏農熵之差稱為資訊增益。
如果待分類的資料可能劃分在多個分類之中,則符號xi的資訊定義為
其中p(xi)是選擇該分類的概率。
為了計算熵,我們需要計算所有類別所有可能值所包含的資訊,通過以下公式得到:
其中n是分類的數目,H是最終計算出的熵。有興趣的同學可以動手算算例子中從A中取球和從B 中取球的過程所對應的夏農熵。
計算夏農熵的程式碼實現:
double calcShannonEnt(vector<data> myData)
// 該函式返回特定資料集的夏農熵
// myData:要計算夏農熵的資料集
{
size_t numEntries = myData.size();
map<string,size_t> labelCounts;
for (auto it = myData.begin(); it != myData.end();it++)
{
string currentLabel = it->label;
if (labelCounts.count(currentLabel) == 0)
{
labelCounts[currentLabel] = 0;
}
labelCounts[currentLabel] += 1;
}
double shannonEnt = 0.0;
for (auto it_map = labelCounts.begin(); it_map != labelCounts.end(); it_map++)
{
double prob = (double)(it_map->second) / (double)numEntries;
shannonEnt -= prob * log2(prob);
}
return shannonEnt;
}
計算海洋生物資料集的夏農熵:
createDataset(); //建立資料集
cout << calcShannonEnt(dataset) << endl; //計算資料集的夏農熵
執行結果:
3.劃分資料集
先寫個按照給定特徵劃分資料集的函式:
vector<data> splitDataSet(vector<data> myData, int axis, bool value)
//按照給定特徵和特徵值劃分資料集
//myData:資料集 axis:給定特徵 value:該特徵的特徵值
{
vector<data> retDataSet;
for (auto it = myData.begin(); it != myData.end(); it++)
{
auto it_feat = it->features.begin();
auto it_axis = it_feat + axis;
data d;
d.featNum = it->featNum - 1;
if (*(it_axis) == value)
{
for (; it_feat != it->features.end(); it_feat++)
{
if (it_feat == it_axis)
continue;
bool temp = *(it_feat);
d.features.push_back(temp);
}
d.label = it->label;
retDataSet.push_back(d);
}
}
return retDataSet;
}
測試該函式:
int main()
{
createDataset(); //建立資料集
cout << "原始資料集:" << endl;
outputData(dataset); //原始資料集
int axis = 0;
bool value = true;
cout << endl;
cout << "選取第" << axis + 1 << "個特徵," << "特徵值為:" << value << endl;
cout << endl;
vector<data> retData = splitDataSet(dataset, axis, value); //找出第一個特徵中特徵值為i 的資料集
cout << "劃分後的子資料集:" << endl;
outputData(retData); //輸出
return 0;
} //輸出
執行結果:
從執行結果可以看到,前3條資料滿足第一個特徵值為1這個條件,劃分出了相應的子集。
用特徵對資料集進行分類後會使得夏農熵減小,獲得一定的資訊增益,找出使得資訊增益最大的分類方法進行分類。程式碼清單如下:
size_t chooseBestFeatureToSplit(vector<data> myData)
//找出最好的資料集劃分方式,即找出最合適的特徵用於分類
//myData:資料集
{
double baseEntropy = calcShannonEnt(myData); //計算原始資料集的夏農熵
double bestInfoGain = 0.0; //資訊增益的最大值
size_t bestFeature = -1; //最“好”的特徵
// auto it_feat = myData.begin()->features.begin();
for (int i = 0; i < myData.begin()->featNum;i++)
{
auto it = myData.begin();
set<bool> featSet;
for (; it != myData.end(); it++)
{
featSet.insert(it->features[i]);
}
double newEntory = 0;
for (auto it_feat = featSet.begin(); it_feat != featSet.end(); it_feat++) //計算每種劃分方式的夏農熵
{
vector<data> subDataSet = splitDataSet(dataset, i, *(it_feat));
double prob = (double)subDataSet.size() / (double)dataset.size();
newEntory += prob*calcShannonEnt(subDataSet);
}
double infoGain = baseEntropy - newEntory;
if (infoGain > bestInfoGain) //計算最好的資訊增益
{
bestInfoGain = infoGain;
bestFeature = i;
}
}
return bestFeature; //返回最好的特徵
}
該函式用於尋找使得分類後資訊增益最大的特徵。
測試該函式:
int main()
{
createDataset(); //建立資料集
cout << "原始資料集:" << endl;
cout << "不浮出水面可以生存" << '\t' << "是否有腳蹼" << '\t' << "屬於魚類" << endl;
outputData(dataset); //原始資料集
cout << endl;
cout << "最好的特徵為第 " << chooseBestFeatureToSplit(dataset) +1 <<" 個特徵"<< endl;
return 0;
}
執行結果:
4.構建決策樹
構建決策樹的過程,就是選取特徵來劃分資料集,直到劃分出所有的類別或用完所有的特徵屬性。
當所有的特徵屬性都用完後,可能劃分出來的資料集裡面仍然有不同的類別,此時通過多數表決的方式來決定該資料集的分類。
string majorityCnt(vector<data> myData)
//若葉子節點下有多個類別,採用多數表決的方式決定該葉子節點的分類
{
string result;
map<string, size_t> labelCounts;
for (auto it = myData.begin(); it != myData.end(); it++)
{
string currentLabel = it->label;
if (labelCounts.count(currentLabel) == 0)
{
labelCounts[currentLabel] = 0;
}
labelCounts[currentLabel] += 1;
}
auto it = labelCounts.begin();
result = it->first;
size_t num = it->second;
for (; it != labelCounts.end(); it++)
{
if (it->second > num)
{
num = it->second;
result = it->first;
}
}
return result;
}
接下來構建決策樹:
node* createTree(vector<data> myData)
//構造決策樹
{
node* root = new node();
auto it = myData.begin();
set<string> labels_set;
for (; it != myData.end(); it++)
{
labels_set.insert(it->label);
}
it = myData.begin();
if (myData.size() == 1 || labels_set.size() == 1) //若資料集只有一項或者只有一類,則返回該分類
{
string text = it->label;
root->label = text;
return root;
}
if (it->featNum == 0) //若資料集下特徵數量為0,則返回出現次數最多的分類
{
root->label = majorityCnt(myData);
return root;
}
size_t best_feat = chooseBestFeatureToSplit(myData); //選擇最好的特徵進行分類
root->feature = best_feat;
vector<data> left_data = splitDataSet(myData, best_feat, false); //將資料集按特徵分為兩類
vector<data> right_data = splitDataSet(myData, best_feat, true);
root->left = createTree(left_data); //建立左子樹和右子樹
root->right = createTree(right_data);
return root;
}
5.執行分類
決策樹構建完畢後,利用該決策樹執行分類:
string classify(node* tree, data input)
//利用構建好的決策樹執行分類
{
string result;
if (tree->label != "") //判斷是否為葉節點, 找到葉節點,返回結果
{
result = tree->label;
}
else
{
node* sub_tree = new node; //不是葉節點,執行遞迴遍歷
size_t best_feat = tree->feature; //在該節點執行分類用到的特徵
bool feat_val = input.features[best_feat];
if (!feat_val) //若特徵值為false,轉到左子樹,若為true,轉到右子樹,遞迴搜尋
{
sub_tree = tree->left;
}
else
{
sub_tree = tree->right;
}
input.featNum -= 1;
size_t index = 0;
for (auto it = input.features.begin(); it != input.features.end(); it++) //去掉用過的特徵值
{
if (index == best_feat)
{
input.features.erase(it);
break;
}
index++;
}
result = classify(sub_tree, input);
}
return result;
}
主函式測試:
int main()
{
createDataset(); //建立資料集
cout << "原始資料集:" << endl;
cout << "不浮出水面可以生存" << '\t' << "是否有腳蹼" << '\t' << "屬於魚類" << endl;
outputData(dataset); //原始資料集
node* tree = createTree(dataset);
data d;
bool val;
cout << "輸入特徵值:" << endl;
for (int i = 0; i < 2; i++)
{
cin >> val;
d.features.push_back(val);
}
string result = classify(tree, d);
cout << result << endl;
}
測試結果:
完整程式碼:
#include <iostream>
#include <cmath>
#include<map>
#include<string>
#include<sstream>
#include<fstream>
#include<vector>
#include<set>
#include<algorithm>
using namespace std;
struct data
{
int featNum; //特徵數量
vector<bool> features; //特徵值
string label; //標籤
};
vector<data> dataset; //資料集
struct node
{
size_t feature;
node* left;
node* right;
string label;
// node();
};
void createDataset()
{
//建立資料集
data d1;
d1.featNum = 2;
d1.features.push_back(true);
d1.features.push_back(true);
d1.label = "yes";
dataset.push_back(d1);
data d2;
d2.featNum = 2;
d2.features.push_back(true);
d2.features.push_back(true);
d2.label = "yes";
dataset.push_back(d2);
data d3;
d3.featNum = 2;
d3.features.push_back(true);
d3.features.push_back(false);
d3.label = "no";
dataset.push_back(d3);
data d4;
d4.featNum = 2;
d4.features.push_back(false);
d4.features.push_back(true);
d4.label = "no";
dataset.push_back(d4);
data d5;
d5.featNum = 2;
d5.features.push_back(false);
d5.features.push_back(true);
d5.label = "no";
dataset.push_back(d5);
}
//bool dataSet[5][3] = { { true, true, true }, { true, true, true }, { true, false, false }, { false, true, false },{false, true, false } };
//string labels[2] = { "no surfacing", "flippers" };
double calcShannonEnt(vector<data> myData)
// 該函式返回特定資料集的夏農熵
// myData:要計算夏農熵的資料集
{
size_t numEntries = myData.size();
map<string,size_t> labelCounts;
for (auto it = myData.begin(); it != myData.end();it++)
{
string currentLabel = it->label;
if (labelCounts.count(currentLabel) == 0)
{
labelCounts[currentLabel] = 0;
}
labelCounts[currentLabel] += 1;
}
double shannonEnt = 0.0;
for (auto it_map = labelCounts.begin(); it_map != labelCounts.end(); it_map++)
{
double prob = (double)(it_map->second) / (double)numEntries;
shannonEnt -= prob * log2(prob);
}
return shannonEnt;
}
vector<data> splitDataSet(vector<data> myData, int axis, bool value)
//按照給定特徵和特徵值劃分資料集
//myData:資料集 axis:給定特徵的索引 value:該特徵的特徵值
{
vector<data> retDataSet; //劃分的子資料集
for (auto it = myData.begin(); it != myData.end(); it++)
{
auto it_feat = it->features.begin();
auto it_axis = it_feat + axis;
data d;
d.featNum = it->featNum - 1;
if (*(it_axis) == value)
{
for (; it_feat != it->features.end(); it_feat++)
{
if (it_feat == it_axis)
continue;
bool temp = *(it_feat);
d.features.push_back(temp);
}
d.label = it->label;
retDataSet.push_back(d);
}
}
return retDataSet;
}
size_t chooseBestFeatureToSplit(vector<data> myData)
//找出最好的資料集劃分方式,即找出最合適的特徵用於分類
//myData:資料集
{
double baseEntropy = calcShannonEnt(myData); //計算原始資料集的夏農熵
double bestInfoGain = 0.0; //資訊增益的最大值
size_t bestFeature = -1; //最“好”的特徵
// auto it_feat = myData.begin()->features.begin();
for (int i = 0; i < myData.begin()->featNum;i++)
{
auto it = myData.begin();
set<bool> featSet;
for (; it != myData.end(); it++)
{
featSet.insert(it->features[i]);
}
double newEntory = 0;
for (auto it_feat = featSet.begin(); it_feat != featSet.end(); it_feat++) //計算每種劃分方式的夏農熵
{
vector<data> subDataSet = splitDataSet(dataset, i, *(it_feat));
double prob = (double)subDataSet.size() / (double)dataset.size();
newEntory += prob*calcShannonEnt(subDataSet);
}
double infoGain = baseEntropy - newEntory;
if (infoGain > bestInfoGain) //計算最好的資訊增益
{
bestInfoGain = infoGain;
bestFeature = i;
}
}
return bestFeature; //返回最好的特徵
}
void outputData(vector<data> myData)
{
for (auto it = myData.begin(); it != myData.end(); it++)
{
auto it_feat = it->features.begin();
for (; it_feat != it->features.end(); it_feat++)
{
cout << *(it_feat) << '\t';
}
cout << it->label << endl;
}
}
string majorityCnt(vector<data> myData)
//若葉子節點下有多個類別,採用多數表決的方式決定該葉子節點的分類
{
string result;
map<string, size_t> labelCounts;
for (auto it = myData.begin(); it != myData.end(); it++)
{
string currentLabel = it->label;
if (labelCounts.count(currentLabel) == 0)
{
labelCounts[currentLabel] = 0;
}
labelCounts[currentLabel] += 1;
}
auto it = labelCounts.begin();
result = it->first;
size_t num = it->second;
for (; it != labelCounts.end(); it++)
{
if (it->second > num)
{
num = it->second;
result = it->first;
}
}
return result;
}
node* createTree(vector<data> myData)
//構造決策樹
{
node* root = new node();
auto it = myData.begin();
set<string> labels_set;
for (; it != myData.end(); it++)
{
labels_set.insert(it->label);
}
it = myData.begin();
if (myData.size() == 1 || labels_set.size() == 1) //若資料集只有一項或者只有一類,則返回該分類
{
string text = it->label;
root->label = text;
return root;
}
if (it->featNum == 0) //若資料集下特徵數量為0,則返回出現次數最多的分類
{
root->label = majorityCnt(myData);
return root;
}
size_t best_feat = chooseBestFeatureToSplit(myData); //選擇最好的特徵進行分類
root->feature = best_feat;
vector<data> left_data = splitDataSet(myData, best_feat, false); //將資料集按特徵分為兩類
vector<data> right_data = splitDataSet(myData, best_feat, true);
root->left = createTree(left_data); //建立左子樹和右子樹
root->right = createTree(right_data);
return root;
}
string classify(node* tree, data input)
//利用構建好的決策樹執行分類
{
string result;
if (tree->label != "") //判斷是否為葉節點, 找到葉節點,返回結果
{
result = tree->label;
}
else
{
node* sub_tree = new node; //不是葉節點,執行遞迴遍歷
size_t best_feat = tree->feature; //在該節點執行分類用到的特徵
bool feat_val = input.features[best_feat];
if (!feat_val) //若特徵值為false,轉到左子樹,若為true,轉到右子樹,遞迴搜尋
{
sub_tree = tree->left;
}
else
{
sub_tree = tree->right;
}
input.featNum -= 1;
size_t index = 0;
for (auto it = input.features.begin(); it != input.features.end(); it++) //去掉用過的特徵值
{
if (index == best_feat)
{
input.features.erase(it);
break;
}
index++;
}
result = classify(sub_tree, input);
}
return result;
}
int main()
{
createDataset(); //建立資料集
cout << "原始資料集:" << endl;
cout << "不浮出水面可以生存" << '\t' << "是否有腳蹼" << '\t' << "屬於魚類" << endl;
outputData(dataset); //原始資料集
node* tree = createTree(dataset);
data d;
bool val;
cout << "輸入特徵值:" << endl;
for (int i = 0; i < 2; i++)
{
cin >> val;
d.features.push_back(val);
}
string result = classify(tree, d);
cout << "屬於魚類:" << result << endl;
}
//createDataset(); //建立資料集
//cout << "原始資料集:" << endl;
//cout << "不浮出水面可以生存" << '\t' << "是否有腳蹼" << '\t' << "屬於魚類" << endl;
//outputData(dataset); //原始資料集
//cout << endl;
//cout << "最好的劃分特徵為第 " << chooseBestFeatureToSplit(dataset) +1 <<" 個特徵"<< endl;
//int axis = 0;
//bool value = true;
//createDataset(); //建立資料集
//cout << "原始資料集:" << endl;
//cout << "不浮出水面可以生存" << '\t' << "是否有腳蹼" << '\t' << "屬於魚類" << endl;
//outputData(dataset); //原始資料集
//cout << "選取第" << axis + 1 << "個特徵," << "特徵值為:" << value << endl;
//cout << endl;
//vector<data> retData = splitDataSet(dataset, axis, value); //找出第一個特徵中特徵值為i 的資料集
//cout << "劃分後的子資料集:" << endl;
//cout << "是否有腳蹼" << '\t' << "屬於魚類" << endl;
//outputData(retData); //輸出
//auto it = retData.begin();
//cout << "特徵數量" << '\t'<<it->featNum<<endl;
//vector<data> retData2 = splitDataSet(retData, axis, value);
//outputData(retData2);
//auto it2 = retData2.begin();
//cout << "特徵數量" << '\t' << it2->featNum << endl;
// cout << calcShannonEnt(dataset) << endl; //計算資料集的夏農熵
// cout << chooseBestFeatureToSplit(dataset) << endl;
// vector<data> retData = splitDataSet(dataset, 0, 0);
// outputData(retData);
// cout << calcShannonEnt(dataset) << endl;