1. 程式人生 > >CBA演算法---基於關聯規則進行分類的演算法

CBA演算法---基於關聯規則進行分類的演算法

介紹

CBA演算法全稱是Classification base of Association,就是基於關聯規則進行分類的演算法,說到關聯規則,我們就會想到Apriori和FP-Tree演算法都是關聯規則挖掘演算法,而CBA演算法正是利用了Apriori挖掘出的關聯規則,然後做分類判斷,所以在某種程度上說,CBA演算法也可以說是一種整合挖掘演算法。

演算法原理

CBA演算法作為分類演算法,他的分類情況也就是給定一些預先知道的屬性,然後叫你判斷出他的決策屬性是哪個值。判斷的依據就是Apriori演算法挖掘出的頻繁項,如果一個項集中包含預先知道的屬性,同時也包含分類屬性值,然後我們計算此頻繁項能否匯出已知屬性值推出決策屬性值的關聯規則,如果滿足規則的最小置信度的要求,那麼可以把頻繁項中的決策屬性值作為最後的分類結果。具體的演算法細節如下:

1、輸入資料記錄,就是一條條的屬性值。

2、對屬性值做數字的替換(按照列從上往下尋找屬性值),就類似於Apriori中的一條條事務記錄。

3、根據這個轉化後的事務記錄,進行Apriori演算法計算,挖掘出頻繁項集。

4、輸入查詢的屬性值,找出符合條件的頻繁項集(需要包含查詢屬性和分類決策屬性),如果能夠推匯出這樣的關聯規則,就算分類成功,輸出分類結果。

這裡以之前我做的CART演算法的測試資料為CBA演算法的測試資料,如下:

Rid Age Income Student CreditRating BuysComputer
1 13 High No Fair CLassNo
2 11 High No Excellent CLassNo
3 25 High No Fair CLassYes
4 45 Medium No Fair CLassYes
5 50 Low Yes Fair CLassYes
6 51 Low Yes Excellent CLassNo
7 30 Low Yes Excellent CLassYes
8 13 Medium No Fair CLassNo
9 9 Low Yes Fair CLassYes
10 55 Medium Yes Fair CLassYes
11 14 Medium Yes Excellent CLassYes
12 33 Medium No Excellent CLassYes
13 33 High Yes Fair CLassYes
14 41 Medium No Excellent CLassNo
屬性值對應的數字替換圖:
Medium=5, CLassYes=12, Excellent=10, Low=6, Fair=9, CLassNo=11, Young=1, Middle_aged=2, Yes=8, No=7, High=4, Senior=3
體會之後的資料變為了下面的事務項:
Rid Age Income Student CreditRating BuysComputer 
1 1 4 7 9 11 
2 1 4 7 10 11 
3 2 4 7 9 12 
4 3 5 7 9 12 
5 3 6 8 9 12 
6 3 6 8 10 11 
7 2 6 8 10 12 
8 1 5 7 9 11 
9 1 6 8 9 12 
10 3 5 8 9 12 
11 1 5 8 10 12 
12 2 5 7 10 12 
13 2 4 8 9 12 
14 3 5 7 10 11 
把每條記錄看出事務項,就和Apriori演算法的輸入格式基本一樣了,後面就是進行連線運算和剪枝步驟等Apriori演算法的步驟了,在這裡就不詳細描述了,Apriori演算法的實現可以點選這裡進行了解。

演算法的程式碼實現

測試資料就是上面的內容。

CBATool.java:

package DataMining_CBA;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import DataMining_CBA.AprioriTool.AprioriTool;
import DataMining_CBA.AprioriTool.FrequentItem;

/**
 * CBA演算法(關聯規則分類)工具類
 * 
 * @author lyq
 * 
 */
public class CBATool {
	// 年齡的類別劃分
	public final String AGE = "Age";
	public final String AGE_YOUNG = "Young";
	public final String AGE_MIDDLE_AGED = "Middle_aged";
	public final String AGE_Senior = "Senior";

	// 測試資料地址
	private String filePath;
	// 最小支援度閾值率
	private double minSupportRate;
	// 最小置信度閾值,用來判斷是否能夠成為關聯規則
	private double minConf;
	// 最小支援度
	private int minSupportCount;
	// 屬性列名稱
	private String[] attrNames;
	// 類別屬性所代表的數字集合
	private ArrayList<Integer> classTypes;
	// 用二維陣列儲存測試資料
	private ArrayList<String[]> totalDatas;
	// Apriori演算法工具類
	private AprioriTool aprioriTool;
	// 屬性到數字的對映圖
	private HashMap<String, Integer> attr2Num;
	private HashMap<Integer, String> num2Attr;

	public CBATool(String filePath, double minSupportRate, double minConf) {
		this.filePath = filePath;
		this.minConf = minConf;
		this.minSupportRate = minSupportRate;
		readDataFile();
	}

	/**
	 * 從檔案中讀取資料
	 */
	private void readDataFile() {
		File file = new File(filePath);
		ArrayList<String[]> dataArray = new ArrayList<String[]>();

		try {
			BufferedReader in = new BufferedReader(new FileReader(file));
			String str;
			String[] tempArray;
			while ((str = in.readLine()) != null) {
				tempArray = str.split(" ");
				dataArray.add(tempArray);
			}
			in.close();
		} catch (IOException e) {
			e.getStackTrace();
		}

		totalDatas = new ArrayList<>();
		for (String[] array : dataArray) {
			totalDatas.add(array);
		}
		attrNames = totalDatas.get(0);
		minSupportCount = (int) (minSupportRate * totalDatas.size());

		attributeReplace();
	}

	/**
	 * 屬性值的替換,替換成數字的形式,以便進行頻繁項的挖掘
	 */
	private void attributeReplace() {
		int currentValue = 1;
		int num = 0;
		String s;
		// 屬性名到數字的對映圖
		attr2Num = new HashMap<>();
		num2Attr = new HashMap<>();
		classTypes = new ArrayList<>();

		// 按照1列列的方式來,從左往右邊掃描,跳過列名稱行和id列
		for (int j = 1; j < attrNames.length; j++) {
			for (int i = 1; i < totalDatas.size(); i++) {
				s = totalDatas.get(i)[j];
				// 如果是數字形式的,這裡只做年齡類別轉換,其他的數字情況類似
				if (attrNames[j].equals(AGE)) {
					num = Integer.parseInt(s);
					if (num <= 20 && num > 0) {
						totalDatas.get(i)[j] = AGE_YOUNG;
					} else if (num > 20 && num <= 40) {
						totalDatas.get(i)[j] = AGE_MIDDLE_AGED;
					} else if (num > 40) {
						totalDatas.get(i)[j] = AGE_Senior;
					}
				}

				if (!attr2Num.containsKey(totalDatas.get(i)[j])) {
					attr2Num.put(totalDatas.get(i)[j], currentValue);
					num2Attr.put(currentValue, totalDatas.get(i)[j]);
					if (j == attrNames.length - 1) {
						// 如果是組後一列,說明是分類類別列,記錄下來
						classTypes.add(currentValue);
					}

					currentValue++;
				}
			}
		}

		// 對原始的資料作屬性替換,每條記錄變為類似於事務資料的形式
		for (int i = 1; i < totalDatas.size(); i++) {
			for (int j = 1; j < attrNames.length; j++) {
				s = totalDatas.get(i)[j];
				if (attr2Num.containsKey(s)) {
					totalDatas.get(i)[j] = attr2Num.get(s) + "";
				}
			}
		}
	}

	/**
	 * Apriori計算全部頻繁項集
	 * @return
	 */
	private ArrayList<FrequentItem> aprioriCalculate() {
		String[] tempArray;
		ArrayList<FrequentItem> totalFrequentItems;
		ArrayList<String[]> copyData = (ArrayList<String[]>) totalDatas.clone();
		
		// 去除屬性名稱行
		copyData.remove(0);
		// 去除首列ID
		for (int i = 0; i < copyData.size(); i++) {
			String[] array = copyData.get(i);
			tempArray = new String[array.length - 1];
			System.arraycopy(array, 1, tempArray, 0, tempArray.length);
			copyData.set(i, tempArray);
		}
		aprioriTool = new AprioriTool(copyData, minSupportCount);
		aprioriTool.computeLink();
		totalFrequentItems = aprioriTool.getTotalFrequentItems();

		return totalFrequentItems;
	}

	/**
	 * 基於關聯規則的分類
	 * 
	 * @param attrValues
	 *            預先知道的一些屬性
	 * @return
	 */
	public String CBAJudge(String attrValues) {
		int value = 0;
		// 最終分類類別
		String classType = null;
		String[] tempArray;
		// 已知的屬性值
		ArrayList<String> attrValueList = new ArrayList<>();
		ArrayList<FrequentItem> totalFrequentItems;

		totalFrequentItems = aprioriCalculate();
		// 將查詢條件進行逐一屬性的分割
		String[] array = attrValues.split(",");
		for (String record : array) {
			tempArray = record.split("=");
			value = attr2Num.get(tempArray[1]);
			attrValueList.add(value + "");
		}

		// 在頻繁項集中尋找符合條件的項
		for (FrequentItem item : totalFrequentItems) {
			// 過濾掉不滿足個數頻繁項
			if (item.getIdArray().length < (attrValueList.size() + 1)) {
				continue;
			}

			// 要保證查詢的屬性都包含在頻繁項集中
			if (itemIsSatisfied(item, attrValueList)) {
				tempArray = item.getIdArray();
				classType = classificationBaseRules(tempArray);

				if (classType != null) {
					// 作屬性替換
					classType = num2Attr.get(Integer.parseInt(classType));
					break;
				}
			}
		}

		return classType;
	}

	/**
	 * 基於關聯規則進行分類
	 * 
	 * @param items
	 *            頻繁項
	 * @return
	 */
	private String classificationBaseRules(String[] items) {
		String classType = null;
		String[] arrayTemp;
		int count1 = 0;
		int count2 = 0;
		// 置信度
		double confidenceRate;

		String[] noClassTypeItems = new String[items.length - 1];
		for (int i = 0, k = 0; i < items.length; i++) {
			if (!classTypes.contains(Integer.parseInt(items[i]))) {
				noClassTypeItems[k] = items[i];
				k++;
			} else {
				classType = items[i];
			}
		}

		for (String[] array : totalDatas) {
			// 去除ID數字號
			arrayTemp = new String[array.length - 1];
			System.arraycopy(array, 1, arrayTemp, 0, array.length - 1);
			if (isStrArrayContain(arrayTemp, noClassTypeItems)) {
				count1++;

				if (isStrArrayContain(arrayTemp, items)) {
					count2++;
				}
			}
		}

		// 做置信度的計算
		confidenceRate = count1 * 1.0 / count2;
		if (confidenceRate >= minConf) {
			return classType;
		} else {
			// 如果不滿足最小置信度要求,則此關聯規則無效
			return null;
		}
	}

	/**
	 * 判斷單個字元是否包含在字元陣列中
	 * 
	 * @param array
	 *            字元陣列
	 * @param s
	 *            判斷的單字元
	 * @return
	 */
	private boolean strIsContained(String[] array, String s) {
		boolean isContained = false;

		for (String str : array) {
			if (str.equals(s)) {
				isContained = true;
				break;
			}
		}

		return isContained;
	}

	/**
	 * 陣列array2是否包含於array1中,不需要完全一樣
	 * 
	 * @param array1
	 * @param array2
	 * @return
	 */
	private boolean isStrArrayContain(String[] array1, String[] array2) {
		boolean isContain = true;
		for (String s2 : array2) {
			isContain = false;
			for (String s1 : array1) {
				// 只要s2字元存在於array1中,這個字元就算包含在array1中
				if (s2.equals(s1)) {
					isContain = true;
					break;
				}
			}

			// 一旦發現不包含的字元,則array2陣列不包含於array1中
			if (!isContain) {
				break;
			}
		}

		return isContain;
	}

	/**
	 * 判斷頻繁項集是否滿足查詢
	 * 
	 * @param item
	 *            待判斷的頻繁項集
	 * @param attrValues
	 *            查詢的屬性值列表
	 * @return
	 */
	private boolean itemIsSatisfied(FrequentItem item,
			ArrayList<String> attrValues) {
		boolean isContained = false;
		String[] array = item.getIdArray();

		for (String s : attrValues) {
			isContained = true;

			if (!strIsContained(array, s)) {
				isContained = false;
				break;
			}

			if (!isContained) {
				break;
			}
		}

		if (isContained) {
			isContained = false;

			// 還要驗證是否頻繁項集中是否包含分類屬性
			for (Integer type : classTypes) {
				if (strIsContained(array, type + "")) {
					isContained = true;
					break;
				}
			}
		}

		return isContained;
	}

}
呼叫類Client.java:
package DataMining_CBA;

import java.text.MessageFormat;

/**
 * CBA演算法--基於關聯規則的分類演算法
 * @author lyq
 *
 */
public class Client {
	public static void main(String[] args){
		String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
		String attrDesc = "Age=Senior,CreditRating=Fair";
		String classification = null;
		
		//最小支援度閾值率
		double minSupportRate = 0.2;
		//最小置信度閾值
		double minConf = 0.7;
		
		CBATool tool = new CBATool(filePath, minSupportRate, minConf);
		classification = tool.CBAJudge(attrDesc);
		System.out.println(MessageFormat.format("{0}的關聯分類結果為{1}", attrDesc, classification));
	}
}
程式碼的結果為:
頻繁1項集:
{1,},{2,},{3,},{4,},{5,},{6,},{7,},{8,},{9,},{10,},{11,},{12,},
頻繁2項集:
{1,7,},{1,9,},{1,11,},{2,12,},{3,5,},{3,8,},{3,9,},{3,12,},{4,7,},{4,9,},{5,7,},{5,9,},{5,10,},{5,12,},{6,8,},{6,12,},{7,9,},{7,10,},{7,11,},{7,12,},{8,9,},{8,10,},{8,12,},{9,12,},{10,11,},{10,12,},
頻繁3項集:
{1,7,11,},{3,9,12,},{6,8,12,},{8,9,12,},
頻繁4項集:

頻繁5項集:

頻繁6項集:

頻繁7項集:

頻繁8項集:

頻繁9項集:

頻繁10項集:

頻繁11項集:

Age=Senior,CreditRating=Fair的關聯分類結果為CLassYes
上面的有些項集為空說明沒有此項集。Apriori演算法類可以在這裡進行查閱,這裡只展示了CBA演算法的部分。

演算法的分析

我在準備實現CBA演算法的時候就預見到了這個演算法就是對Apriori演算法的一個包裝,在於2點,輸入資料的格式進行數字的轉換,還有就是輸出的時候做屬性對數字的替換,核心還是在於Apriori演算法的項集頻繁挖掘。

程式實現時遇到的問題

在這期間遇到了一個bug就是頻繁1項集在排序的時候出現了問題,後來發現原因是String.CompareTo(),原本應該是1,2,....11,12,用了前面這個方法後會變成1,10,2,。。就是10會比2小的情況,後來查了String.CompareTo()的比較規則,明白了他是一位位比較Ascall碼值,因為10的1比2小,最後果斷的改回了用Integer的比較方法了。這個問題別看是個小問題,1項集如果沒有排好序,後面的連線操作可能會出現少情況的可能,這個之前吃過這樣的虧了。

我對CBA演算法的理解

CBA演算法和巧妙的利用了關聯規則進行類別的分類,有別與其他的分類演算法。他的演算法好壞又會依靠Apriori演算法的執行好壞。