1. 程式人生 > >資料探勘筆記-聚類-SpectralClustering-原理與簡單實現

資料探勘筆記-聚類-SpectralClustering-原理與簡單實現

譜聚類(Spectral Clustering, SC)是一種基於圖論的聚類方法——將帶權無向圖劃分為兩個或兩個以上的最優子圖,使子圖內部儘量相似,而子圖間距離儘量距離較遠,以達到常見的聚類的目的。其中的最優是指最優目標函式不同,可以是Min Cut、Nomarlized Cut、Ratio Cut等。譜聚類能夠識別任意形狀的樣本空間且收斂於全域性最優解,其基本思想是利用樣本資料的相似矩陣(拉普拉斯矩陣)進行特徵分解後得到的特徵向量進行聚類。

Spectral Clustering 演算法步驟:

1)根據資料構造一個GraphGraph的每一個節點對應一個數據點,將相似的點連線起來,並且邊的權重用於表示資料之間的相似度。把這個Graph

用鄰接矩陣的形式表示出來,記為 W

2)把W的每一列元素活者行元素加起來得到N個數,把它們放在對角線上(其他地方都是零),組成一個N*N的度矩陣,記為

3)根據度矩陣與鄰接矩陣得出拉普拉斯矩陣 L = D - W 

4)求出拉普拉斯矩陣L的前k個特徵值(除非特殊說明,否則k指按照特徵值的大小從小到大的順序)以及對應的特徵向量。

5)把這k個特徵(列)向量排列在一起組成一個N*k的矩陣,將其中每一行看作k維空間中的一個向量,並使用 K-Means演算法進行聚類。聚類的結果中每一行所屬的類別就是原來Graph中的節點亦即最初的N個數據點分別所屬的類別。

示例


Spectral Clustering 

和傳統的聚類方法(如 K-Means等)對比:
1)和 K-Medoids 類似,Spectral Clustering 只需要資料之間的相似度矩陣就可以了,而不必像K-means那樣要求資料必須是 維歐氏空間中的向量。Spectral Clustering 所需要的所有資訊都包含在 中。不過一般 W 並不總是等於最初的相似度矩陣——回憶一下,W 是我們構造出來的 Graph 的鄰接矩陣表示,通常我們在構造 Graph 的時候為了方便進行聚類,更加強到“區域性”的連通性,亦即主要考慮把相似的點連線在一起,比如:我們可以設定一個閾值,如果兩個點的相似度小於這個閾值,就把他們看作是不連線的。另一種構造 Graph 鄰接的方法是將 n 個與節點最相似的點與其連線起來。

2)由於抓住了主要矛盾,忽略了次要的東西,因此比傳統的聚類演算法更加健壯一些,對於不規則的誤差資料不是那麼敏感,而且效能也要好一些。許多實驗都證明了這一點。事實上,在各種現代聚類演算法的比較中,K-means 通常都是作為 baseline 而存在的。實際上 Spectral Clustering 是在用特徵向量的元素來表示原來的資料,並在這種“更好的表示形式”上進行 K-Means 。實際上這種“更好的表示形式”是用 Laplacian Eig進行降維的後的結果。而降維的目的正是“抓住主要矛盾,忽略次要的東西”。
3計算複雜度比 K-means 要小。這個在高維資料上表現尤為明顯。例如文字資料,通常排列起來是維度非常高(比如幾千或者幾萬)的稀疏矩陣,對稀疏矩陣求特徵值和特徵向量有很高效的辦法,得到的結果是一些 k 維的向量(通常 k 不會很大),在這些低維的資料上做 K-Means 運算量非常小。但是對於原始資料直接做 K-Means 的話,雖然最初的資料是稀疏矩陣,但是 K-Means 中有一個求 Centroid 的運算,就是求一個平均值:許多稀疏的向量的平均值求出來並不一定還是稀疏向量,事實上,在文字資料裡,很多情況下求出來的 Centroid 向量是非常稠密,這時再計算向量之間的距離的時候,運算量就變得非常大,直接導致普通的 K-Means 巨慢無比,而 Spectral Clustering 等工序更多的演算法則迅速得多的結果。

Java簡單實現程式碼如下:

public class SpectralClusteringBuilder {

	public static int DIMENSION = 30;
	
	public static double THRESHOLD = 0.01;

	public Data getInitData() {
		Data data = new Data();
		try {
			String path = SpectralClustering.class.getClassLoader()
					.getResource("測試").toURI().getPath();
			DocumentSet documentSet = DocumentLoader.loadDocumentSet(path);
			List<Document> documents = documentSet.getDocuments();
			DocumentUtils.calculateTFIDF_0(documents);
			DocumentUtils.calculateSimilarity(documents, new CosineDistance());
			Map<String, Map<String, Double>> nmap = new HashMap<String, Map<String, Double>>();
			Map<String, String> cmap = new HashMap<String, String>();
			for (Document document : documents) {
				String name = document.getName();
				cmap.put(name, document.getCategory());
				Map<String, Double> similarities = nmap.get(name);
				if (null == similarities) {
					similarities = new HashMap<String, Double>();
					nmap.put(name, similarities);
				}
				for (DocumentSimilarity similarity : document.getSimilarities()) {
					if (similarity.getDoc2().getName().equalsIgnoreCase(similarity.getDoc1().getName())) {
						similarities.put(similarity.getDoc2().getName(), 0.0);
					} else {
						similarities.put(similarity.getDoc2().getName(), similarity.getDistance());
					}
				}
			}
			String[] docnames = nmap.keySet().toArray(new String[0]);
			data.setRow(docnames);
			data.setColumn(docnames);
			data.setDocnames(docnames);
			int len = docnames.length;
			double[][] original = new double[len][len];
			for (int i = 0; i < len; i++) {
				Map<String, Double> similarities = nmap.get(docnames[i]);
				for (int j = 0; j < len; j++) {
					double distance = similarities.get(docnames[j]);
					original[i][j] = distance;
				}
			}
			data.setOriginal(original);
			data.setCmap(cmap);
			data.setNmap(nmap);
		} catch (Exception e) {
			e.printStackTrace();
		}
		return data;
	}

	/**
	 * 獲取距離閥值在一定範圍內的點
	 * @param data
	 * @return
	 */
	public double[][] getWByDistance(Data data) {
		Map<String, Map<String, Double>> nmap = data.getNmap();
		String[] docnames = data.getDocnames();
		int len = docnames.length;
		double[][] w = new double[len][len];
		for (int i = 0; i < len; i++) {
			Map<String, Double> similarities = nmap.get(docnames[i]);
			for (int j = 0; j < len; j++) {
				double distance = similarities.get(docnames[j]);
				w[i][j] = distance < THRESHOLD ? 1 : 0;
			}
		}
		return w;
	}
	
	/**
	 * 獲取距離最近的K個點
	 * @param data
	 * @return
	 */
	public double[][] getWByKNearestNeighbors(Data data) {
		Map<String, Map<String, Double>> nmap = data.getNmap();
		String[] docnames = data.getDocnames();
		int len = docnames.length;
		double[][] w = new double[len][len];
		for (int i = 0; i < len; i++) {
			List<Map.Entry<String, Double>> similarities = 
					new ArrayList<Map.Entry<String, Double>>(nmap.get(docnames[i]).entrySet());
			sortSimilarities(similarities, DIMENSION);
			for (int j = 0; j < len; j++) {
				String name = docnames[j];
				boolean flag = false;
				for (Map.Entry<String, Double> entry : similarities) {
					if (name.equalsIgnoreCase(entry.getKey())) {
						flag = true;
						break;
					}
				}
				w[i][j] = flag ? 1 : 0;
			}
		}
		return w;
	}

	/**
	 * 垂直求和
	 * @param W
	 * @return
	 */
	public double[][] getVerticalD(double[][] W) {
		int row = W.length;
		int column = W[0].length;
		double[][] d = new double[row][column];
		for (int j = 0; j < column; j++) {
			double sum = 0;
			for (int i = 0; i < row; i++) {
				sum += W[i][j];
			}
			d[j][j] = sum;
		}
		return d;
	}

	/**
	 * 水平求和
	 * @param W
	 * @return
	 */
	public double[][] getHorizontalD(double[][] W) {
		int row = W.length;
		int column = W[0].length;
		double[][] d = new double[row][column];
		for (int i = 0; i < row; i++) {
			double sum = 0;
			for (int j = 0; j < column; j++) {
				sum += W[i][j];
			}
			d[i][i] = sum;
		}
		return d;
	}
	
	/**
	 * 相似度排序,並取前K個,倒敘
	 * @param similarities
	 * @param k
	 */
	public void sortSimilarities(List<Map.Entry<String, Double>> similarities, int k) {
		Collections.sort(similarities, new Comparator<Map.Entry<String, Double>>() {
			@Override
			public int compare(Entry<String, Double> o1,
					Entry<String, Double> o2) {
				return o2.getValue().compareTo(o1.getValue());
			}
		});
		while (similarities.size() > k) {
			similarities.remove(similarities.size() - 1);
		}
	}

	public void print(double[][] values) {
		for (int i = 0, il = values.length; i < il; i++) {
			for (int j = 0, jl = values[0].length; j < jl; j++) {
				System.out.print(values[i][j] + "  ");
			}
			System.out.println("\n");
		}
	}

	// 隨機生成中心點,並生成初始的K個聚類
	public List<DataPointCluster> genInitCluster(List<DataPoint> points, int k) {
		List<DataPointCluster> clusters = new ArrayList<DataPointCluster>();
		Random random = new Random();
		Set<String> categories = new HashSet<String>();
		while (clusters.size() < k) {
			DataPoint center = points.get(random.nextInt(points.size()));
			String category = center.getCategory();
			if (categories.contains(category))
				continue;
			categories.add(category);
			DataPointCluster cluster = new DataPointCluster();
			cluster.setCenter(center);
			cluster.getDataPoints().add(center);
			clusters.add(cluster);
		}
		return clusters;
	}

	// 將點歸入到聚類中
	public void handleCluster(List<DataPoint> points,
			List<DataPointCluster> clusters, int iterNum) {
		for (DataPoint point : points) {
			DataPointCluster maxCluster = null;
			double maxDistance = Integer.MIN_VALUE;
			for (DataPointCluster cluster : clusters) {
				DataPoint center = cluster.getCenter();
				double distance = DistanceUtils.cosine(point.getValues(),
						center.getValues());
				if (distance > maxDistance) {
					maxDistance = distance;
					maxCluster = cluster;
				}
			}
			if (null != maxCluster) {
				maxCluster.getDataPoints().add(point);
			}
		}
		// 終止條件定義為原中心點與新中心點距離小於一定閥值
		// 當然也可以定義為原中心點等於新中心點
		boolean flag = true;
		for (DataPointCluster cluster : clusters) {
			DataPoint center = cluster.getCenter();
			DataPoint newCenter = cluster.computeMediodsCenter();
			double distance = DistanceUtils.cosine(newCenter.getValues(),
					center.getValues());
			if (distance > 0.5) {
				flag = false;
				cluster.setCenter(newCenter);
			}
		}
		if (!flag && iterNum < 25) {
			for (DataPointCluster cluster : clusters) {
				cluster.getDataPoints().clear();
			}
			handleCluster(points, clusters, ++iterNum);
		}
	}

	/**
	 * KMeans方法
	 * @param dataPoints
	 */
	public void kmeans(List<DataPoint> dataPoints) {
		List<DataPointCluster> clusters = genInitCluster(dataPoints, 4);
		handleCluster(dataPoints, clusters, 0);
		int success = 0, failure = 0;
		for (DataPointCluster cluster : clusters) {
			String category = cluster.getCenter().getCategory();
			for (DataPoint dataPoint : cluster.getDataPoints()) {
				String dpCategory = dataPoint.getCategory();
				if (category.equals(dpCategory)) {
					success++;
				} else {
					failure++;
				}
			}
		}
		System.out.println("total: " + (success + failure) + " success: "
				+ success + " failure: " + failure);
	}

	public void build() {
		Data data = getInitData();
		double[][] w = getWByKNearestNeighbors(data);
		double[][] d = getHorizontalD(w);
		Matrix W = new Matrix(w);
		Matrix D = new Matrix(d);
		Matrix L = D.minus(W);
		EigenvalueDecomposition eig = L.eig();
		double[][] v = eig.getV().getArray();
		double[][] vs = new double[v.length][DIMENSION];
		for (int i = 0, li = v.length; i < li; i++) {
			for (int j = 1, lj = DIMENSION; j <= lj; j++) {
				vs[i][j-1] = v[i][j];
			}
		}
		Matrix V = new Matrix(vs);
		Matrix O = new Matrix(data.getOriginal());
	    double[][] t = O.times(V).getArray();
	    List<DataPoint> dataPoints = new ArrayList<DataPoint>();
		for (int i = 0; i < t.length; i++) {
			DataPoint dataPoint = new DataPoint();
			dataPoint.setCategory(data.getCmap().get(data.getColumn()[i]));
			dataPoint.setValues(t[i]);
			dataPoints.add(dataPoint);
		}
		for (int n = 0; n < 10; n++) {
			kmeans(dataPoints);
		}
	}

	public static void main(String[] args) {
		new SpectralClusteringBuilder().build();
	}
}

程式碼託管:https://github.com/fighting-one-piece/repository-datamining.git