1. 程式人生 > >決策樹——CART——之R語言rpart包

決策樹——CART——之R語言rpart包

R是一種用於統計計算與作圖的開源軟體,同時也是一種程式語言,它廣泛應用於企業和學術界的資料分析領域,正在成為最通用的語言之一。由於近幾年資料探勘、大資料等概念的走紅,R也越來越多的被人關注。

一、環境準備

作業系統windows

下載安裝地址:http://mirrors.xmu.edu.cn/CRAN/

下載安裝RStudio一個非常實用的R語言的IDE,是一個免費的軟體

二、下載安裝rpart

點選InstallPackages

packages中輸入:rpart->Install

等待安裝完成。

注:Installfrom中可以選擇安裝方式,圖中顯示的是從CRAN中通過網路連線下載,也可以選擇在本地檔案中尋找package

安裝。

三、下載資料

我試驗的資料是從UCI上下載的資料,當然也可以用rpart包中自帶的資料集。

將資料解壓後,會有許多.csv格式的資料檔案,在這裡本次試驗選擇的是bank.csv

四、使用R讀取資料

開啟RStudio

在控制檯

或者新建一個R Script

接下來我們在R Script中書寫程式碼,同樣,也可以在控制檯上一行一行的書寫,一條一條的執行,但是,程式碼換行時要按shift+Enter

 bank <- read.csv("D:/data/MachineLearning/bank/bank.csv",header=TRUE,sep=";") #讀取bank.csv資料檔案
 #注意:windows檔案路徑複製後文件的分隔符為“\”,但是R語言中不識別這種分隔符,她只識別“/”,
 #header=TRUE表示使用檔案的頭標籤,預設為FALSE,sep=";"表示資料用分號分隔,預設為"",
 bank_train <- bank[1:4000,] #對讀入的資料人為分割為訓練組和測試組,
 bank_test <- bank[4001:4521,1:16]
 bank_test1 <- bank[4001:4521,]
 library(rpart)  #在使用包前首先要使用該命令匯入包,也可以在Packages中包前的框框中打鉤
 fit <- rpart(y~age+job+marital+education+default+balance+housing+loan+contact
              +day+month+duration+campaign+pdays+previous+poutcome,method="class",
              data=bank_train)
 # 我們可以使用help(rpart)來獲取rpart的使用幫助,幫助文件Usage如下
 # rpart(formula, data, weights, subset, na.action = na.rpart, method,
 # model = FALSE, x = FALSE, y = TRUE, parms, control, cost, ...)
 # 在這裡我們只設置formula,data,model這三個引數
 plot(fit,uniform=TRUE,main="Classification Tree for Bank") #畫決策樹圖
 text(fit,use.n=TRUE,all=TRUE)
 #至此,第一個決策樹圖畫好了,第一個訓練的模型儲存在fit中
 
 #下面我們對測試資料進行預測(此處預測的是y值是yes or no)
 result <- predict(fit,bank_test,type="class") 
 # 在控制檯中直接輸入result即可檢視預測的結果,由於數目較多,我們寫一個小的程式,將預測
 # 結果同真實值比較一下,看正確率有多少
 # 詳情見 count_result.R
 # 我們寫完的函式儲存在本地磁碟中,使用時必須指明路徑,使用source()函式
 source("D:/work/R_work/count_result.R")
 count_result(result,bank_test1) #結果為0.9021
 
 #通過觀察資料,我們可以發現,在poutcome與contact屬性中,有許多unknown的值,
 #通過summary(bank)我們可以看到,unknown值在其所在屬性框中所佔比例過大,而且該
 #值其實為缺失值,所以我們使用rpart()函式中的na.action引數,來處理缺失值
 #由於R只識別NA缺失值,所以我們需要對資料框中的unknown值進行處理
 n <- nrow(bank)  #獲得data的行數
 for (i in 1:n){
   if(bank[i,9]=="unknown"){   #判斷第i,9個數據是否為unknown
     bank[i,9] <- NA           #將第i,9個數據替換為NA
   }
   if(bank[i,16]=="unknown"){
     bank[i,16] <- NA
   }
 }
 #我們已知第9、16列為含有unknown的屬性框
 fit2 <- rpart(y~age+job+marital+education+default+balance+housing+loan+contact
               +              +day+month+duration+campaign+pdays+previous+poutcome,method="class",
               +              data=bank_train,na.action=na.rpart)
 plot(fit,uniform=TRUE,main="Classification Tree for Bank") #畫決策樹圖
 text(fit,use.n=TRUE,all=TRUE)
 result2 <- predict(fit2,bank_test,type="class")
 count_result(result2)# 結果仍為0.9021表示之前的關於缺失值的推測不準確
 
 #下邊我們探索使用rpart()的control引數設定 
 fit3 <- rpart(y~age+job+marital+education+default+balance+housing+loan+contact
               +              +day+month+duration+campaign+pdays+previous+poutcome,method="class",
               +              data=bank_train,na.action=na.rpart,control=rpart.control(minsplit=20,cp=0.001))
 result3 <- predict(fit3,bank_test,type="class")
 count_result(result3,bank_test1)# 結果為0.90403 預測的準確度有所上升
 #下邊我們隊minsplit(最小分割點)設大一點 40
 fit4 <- rpart(y~age+job+marital+education+default+balance+housing+loan+contact
               +              +day+month+duration+campaign+pdays+previous+poutcome,method="class",
               +              data=bank_train,na.action=na.rpart,control=rpart.control(minsplit=40,cp=0.001))
 result4 <- predict(fit4,bank_test,type="class")
 count_result(result4,bank_test1) #結果為0.9136 這說明隨著分割點的增多,預測的準確率越高,

關於rpart的其他函式的功能探索,請繼續關注...

count_result.R

count_result <- function(result,data_test){
  n <- length(result)
  count_right <- 0
  i <- 1
  for (i in 1:n){
    if (result[i]==data_test[i,17]){
      count_right = count_right+1
    }
  }
  print(count_right/n)
}