1. 程式人生 > >【JAVA實現】樸素貝葉斯分類演算法

【JAVA實現】樸素貝葉斯分類演算法

       之前部落格提到的KNN演算法以及決策樹演算法都是要求分類器給出“該資料例項屬於哪一類”這類問題的明確答案,正因為如此,才出現了使用決策樹分類時,有時無法判定某一測試例項屬於哪一類別。使用樸素貝葉斯演算法則可以避免這個問題,它給出了這個例項屬於某一類別的概率值,然後通過比較概率值,可以找到該例項最有可能屬於哪一類別。

       該演算法可以用如下形式表示:


       直接求解概率值很困難,因此我們可以通過下列式子進行變化。


       因為分母對於所有類別都是固定值,所以我們只要求能使得分子最大化的類別即可。又因為樸素貝葉斯假設各特徵屬性條件獨立,所以有:


       樸素貝葉斯分類器通常有兩種實現方式:一種基於貝努利模型實現,一種基於多項式模型實現。貝努利實現方式也稱“詞集模型”,其不考慮詞在文件中出現的次數,只考慮出不出現,因此在這個意義上相當於假設詞是等權重的。而多項式模型也稱“詞袋模型”,它考慮詞在文件中的出現次數。本文采用的是多項式模型。

       這次的案例使用的是使用樸素貝葉斯過濾垃圾郵件。訓練集如下所示:


       第一個資料夾下放的是25篇正常的純文字郵件,第二個資料夾放的是25篇純文字垃圾郵件。

       首先是郵件內容的封裝。

package naivebayesian;

import java.util.List;

public class Email {
	
	private List<String> wordList;
	private int flag;
	
	public int getFlag() {
		return flag;
	}
	public void setFlag(int flag) {
		this.flag = flag;
	}
	public List<String> getWordList() {
		return wordList;
	}
	public void setWordList(List<String> wordList) {
		this.wordList = wordList;
	}
	
}
       第一個欄位為每封郵件分好詞的集合,第二個欄位表示這封郵件是否是垃圾郵件。
package naivebayesian;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;

public class NaiveBayesian {
	
	private List<Double> p0Vec = null;
	//垃圾郵件中每個詞出現的概率
	private List<Double> p1Vec = null;
	//垃圾郵件出現的概率
	private double pSpamRatio;

	/**
	 * 初始化資料集
	 * 
	 * @return
	 */
	public List<Email> initDataSet() {
		List<Email> dataSet = new ArrayList<Email>();
		BufferedReader bufferedReader1 = null;
		BufferedReader bufferedReader2 = null;
		try {
			for (int i = 1; i < 26; i++) {
				bufferedReader1 = new BufferedReader(new InputStreamReader(
						new FileInputStream(
								"/home/shenchao/Desktop/MLSourceCode/machinelearninginaction/Ch04/email/ham/"
										+ i + ".txt")));
				StringBuilder sb1 = new StringBuilder();
				String string = null;
				while ((string = bufferedReader1.readLine()) != null) {
					sb1.append(string);
				}
				Email hamEmail = new Email();
				hamEmail.setWordList(textParse(sb1.toString()));
				hamEmail.setFlag(0);

				bufferedReader2 = new BufferedReader(new InputStreamReader(
						new FileInputStream(
								"/home/shenchao/Desktop/MLSourceCode/machinelearninginaction/Ch04/email/spam/"
										+ i + ".txt")));
				StringBuilder sb2 = new StringBuilder();
				while ((string = bufferedReader2.readLine()) != null) {
					sb2.append(string);
				}
				Email spamEmail = new Email();
				spamEmail.setWordList(textParse(sb2.toString()));
				spamEmail.setFlag(1);

				dataSet.add(hamEmail);
				dataSet.add(spamEmail);
			}
			return dataSet;
		} catch (Exception e) {
			e.printStackTrace();
			throw new RuntimeException();
		} finally {
			try {
				bufferedReader1.close();
				bufferedReader2.close();
			} catch (IOException e) {
				e.printStackTrace();
			}

		}
	}

	/**
	 * 分詞,英文的分詞相比中文的分詞要簡單很多,這裡使用的分隔符為除單詞、數字外的任意字串
	 * 如果使用中文,則可以使用中科院的一套分詞系統,分詞效果還算不錯
	 * 
	 * @param originalString
	 * @return
	 * @return
	 */
	private List<String> textParse(String originalString) {
		String[] s = originalString.split("\\W");
		List<String> wordList = new ArrayList<String>();
		for (String string : s) {
			if (string.contains(" ")) {
				continue;
			}
			if (string.length() > 2) {
				wordList.add(string.toLowerCase());
			}
		}
		return wordList;
	}

	/**
	 * 構建單詞集,此長度等於向量長度
	 * 
	 * @return
	 */
	public Set<String> createVocabList(List<Email> dataSet) {
		Set<String> set = new LinkedHashSet<String>();
		for (Email email : dataSet) {
			for (String string : email.getWordList()) {
				set.add(string);
			}
		}
		return set;
	}

	/**
	 * 將郵件轉換為向量
	 * 
	 * @param vocabSet
	 * @param inputSet
	 * @return
	 */
	public List<Integer> setOfWords2Vec(Set<String> vocabSet, Email email) {
		List<Integer> returnVec = new ArrayList<Integer>();
		for (String word : vocabSet) {
			returnVec.add(calWordFreq(word, email));
		}
		return returnVec;
	}

	/**
	 * 計算一個詞在某個集合中的出現次數
	 * 
	 * @return
	 */
	private int calWordFreq(String word, Email email) {
		int num = 0;
		for (String string : email.getWordList()) {
			if (string.equals(word)) {
				++num;
			}
		}
		return num;
	}

	public void trainNB(Set<String> vocabSet, List<Email> dataSet) {
		// 訓練文字的數量
		int numTrainDocs = dataSet.size();
		// 訓練集中垃圾郵件的概率
		pSpamRatio = (double) calSpamNum(dataSet) / numTrainDocs;

		// 記錄每個類別下每個詞的出現次數
		List<Integer> p0Num = new ArrayList<Integer>();
		List<Integer> p1Num = new ArrayList<Integer>();
		// 記錄每個類別下一共出現了多少詞,為防止分母為0,所以在此預設值為2
		double p0Denom = 2.0, p1Denom = 2.0;
		for (Email email : dataSet) {
			List<Integer> list = setOfWords2Vec(vocabSet, email);
			// 如果是垃圾郵件
			if (email.getFlag() == 1) {
				p1Num = vecAddVec(p1Num, list);
				//計算該類別下出現的所有單詞數目
				p1Denom += calTotalWordNum(list);
			}else {
				p0Num = vecAddVec(p0Num, list);
				p0Denom += calTotalWordNum(list);
			}
		}
		p0Vec = calWordRatio(p0Num, p0Denom);
		p1Vec = calWordRatio(p1Num, p1Denom);
	}

	/**
	 * 兩個向量相加
	 * 
	 * @param vec1
	 * @param vec2
	 * @return
	 */
	private List<Integer> vecAddVec(List<Integer> vec1,
			List<Integer> vec2) {
		if (vec1.size() == 0) {
			return vec2;
		}
		List<Integer> list = new ArrayList<Integer>();
		for (int i = 0; i < vec1.size(); i++) {
			list.add(vec1.get(i) + vec2.get(i));
		}
		return list;
	}
	
	/**
	 * 計算垃圾郵件的數量
	 * 
	 * @param dataSet
	 * @return
	 */
	private int calSpamNum(List<Email> dataSet) {
		int time = 0;
		for (Email email : dataSet) {
			time += email.getFlag();
		}
		return time;
	}
	
	/**
	 * 統計出現的所有單詞數
	 * @param list
	 * @return
	 */
	private int calTotalWordNum(List<Integer> list) {
		int num = 0;
		for (Integer integer : list) {
			num += integer;
		}
		return num;
	}
	
	/**
	 * 計算每個單詞在該類別下的出現概率,為防止分子為0,導致樸素貝葉斯公式為0,設定分子的預設值為1
	 * @param list
	 * @param wordNum
	 * @return
	 */
	private List<Double> calWordRatio(List<Integer> list, double wordNum) {
		List<Double> vec = new ArrayList<Double>();
		for (Integer i : list) {
			vec.add(Math.log((double)(i+1) / wordNum));
		}
		return vec;
	}
	
	/**
	 * 比較不同類別 p(w0,w1,w2...wn | ci)*p(ci) 的大小   <br>
	 *  p(w0,w1,w2...wn | ci) = p(w0|ci)*p(w1|ci)*p(w2|ci)... <br>
	 *  由於防止下溢,對中間計算值都取了對數,因此上述公式化為log(p(w0,w1,w2...wn | ci)) + log(p(ci)),即
	 *  化為多個式子相加得到結果
	 *  
	 * @param email
	 * @return 返回概率最大值 
	 */
	public int classifyNB(List<Integer> emailVec) {
		double p0 = calProbabilityByClass(p0Vec, emailVec) + Math.log(1 - pSpamRatio);
		double p1 = calProbabilityByClass(p1Vec, emailVec) + Math.log(pSpamRatio);
		if (p0 > p1) {
			return 0;
		}else {
			return 1;
		}
	}
	
	private double calProbabilityByClass(List<Double> vec,List<Integer> emailVec) {
		double sum = 0.0;
		for (int i = 0; i < vec.size(); i++) {
			sum += (vec.get(i) * emailVec.get(i));
		}
		return sum;
	}
	
	public double testingNB() {
		List<Email> dataSet = initDataSet();
		List<Email> testSet = new ArrayList<Email>();
		//隨機取前10作為測試樣本
		for (int i = 0; i < 10; i++) {
			Random random = new Random();
			int n = random.nextInt(50-i);
			testSet.add(dataSet.get(n));
			//從訓練樣本中刪除這10條測試樣本
			dataSet.remove(n);
		}
		Set<String> vocabSet = createVocabList(dataSet);
		//訓練樣本
		trainNB(vocabSet, dataSet);
		
		int errorCount = 0;
		for (Email email : testSet) {
			if (classifyNB(setOfWords2Vec(vocabSet, email)) != email.getFlag()) {
				++errorCount;
			}
		}
//		System.out.println("the error rate is: " + (double) errorCount / testSet.size());
		return (double) errorCount / testSet.size();
	}

	public static void main(String[] args) {
		NaiveBayesian bayesian = new NaiveBayesian();
		double d = 0;
		for (int i = 0; i < 50; i++) {
			d +=bayesian.testingNB();
		}
		System.out.println("total error rate is: " + d / 50);
	}

}
       可能比較亂,我把所有的模組寫在了一起。首先是從兩個資料夾中讀取郵件,建立資料集。讀取同時,會對郵件的內容進行分詞,這裡的分詞規則比較簡單,分隔符為非文字字元,可以根據自己規則定義。

       利用貝葉斯分類器對文件進行分類時,要計算多個概率的乘積以獲得文件屬於某個類別的概率,如果其中一個概率值為0,那麼最後乘積也為0,因此為降低這種影響,可以將所有詞的出現次數初始化為1,將分母初始化為2。另一個問題是下溢,這是由於太多很小的數相乘造成的,導致最後四捨五入會得到0.解決辦法是對乘積取自然對數。

       最後,從訓練樣本中隨機選取10條作為測試集,進行錯誤率計算,重複十次,取平均數得到最後的錯誤率為4%。發現結果還是不錯的。

       通過特徵之間的條件獨立性假設,可以降低對資料量的需求。但是獨立性假設是指一個詞的出現概率並不依賴於文件中的其他詞,可見這個假設過於簡單,可能會影響最後的準確率。

       如有什麼問題,歡迎大家和我一起學習探討。