1. 程式人生 > >資料探勘筆記-聚類-KMeans-原理與簡單實現

資料探勘筆記-聚類-KMeans-原理與簡單實現

K中心點演算法(K-medoids)提出了新的質點選取方式,而不是像k-means演算法那樣簡單地採用均值計算。在K中心點演算法中,每次迭代後的質點都是從聚類的樣本點中選取,而選取的標準就是當該樣本點成為新的質點後能提高類簇的聚類質量,使得類簇更緊湊。該演算法使用絕對誤差標準來定義一個類簇的緊湊程度。


如果某樣本點成為質點後,絕對誤差能小於原質點所造成的絕對誤差,那麼K中心點演算法認為該樣本點是可以取代原質點的。在一次迭代中重新計算類簇質點的時候,我們選擇絕對誤差最小的那個樣本點成為新的質點。這種做法較好地解決了對離群點/噪聲資料敏感的問題,但時間複雜度上升至O(k(m-k)^2)。計算量顯然比KMeans大,一般只適合小資料量。

二分KMeans

二分KMeans是對基本KMeans的直接擴充,它基於一種簡單想法:為了得到K個簇,將所有點集合分裂成兩個簇,從這些簇中選取一個繼續分裂,直到產生K個簇。

二分k均值(bisecting k-means)演算法的主要思想是:首先將所有點作為一個簇,然後將該簇一分為二。之後選擇能最大程度降低聚類代價函式(也就是誤差平方和)的簇劃分為兩個簇。以此進行下去,直到簇的數目等於使用者給定的數目k為止。

       以上隱含著一個原則:因為聚類的誤差平方和能夠衡量聚類效能,該值越小表示資料點越接近於它們的質心,聚類效果就越好。所以我們就需要對誤差平方和最大的簇進行再一次的劃分,因為誤差平方和越大,表示該簇聚類越不好,越有可能是多個簇被當成一個簇了,所以我們首先需要對這個簇進行劃分。

二分k均值演算法的虛擬碼如下:

將所有資料點看成一個簇

當簇數目小於k時

       對每一個簇

              計算總誤差

              在給定的簇上面進行k-均值聚類(k=2)

              計算將該簇一分為二後的總誤差

       選擇使得誤差最小的那個簇進行劃分操作

下面用Java來簡單實現演算法;為了簡單起見,點只使用二維座標。

/**
 * Minimal k-means clustering over two-dimensional points read from a classpath resource.
 *
 * <p>Each pass assigns every point to the cluster whose center is nearest according to the
 * inherited {@code euclideanDistance} helper, then recomputes each cluster's center with
 * {@code PointCluster#computeMeansCenter()}. Passes repeat until no center moves by more than
 * {@link #THRESHOLD}.
 */
public class KMeansCluster extends AbstractCluster {
	
	/** A center that moves by more than this distance in one pass triggers another pass. */
	public static final double THRESHOLD = 1.0;
	
	/** Safety cap on the number of passes so a non-converging input cannot loop forever. */
	private static final int MAX_ITERATIONS = 1000;
	
	/** Classpath resource holding one whitespace-separated "x y" pair per line. */
	private static final String DATA_RESOURCE = "kmeans1.txt";
	
	/**
	 * Loads the sample points from the classpath resource.
	 *
	 * <p>Reading stops at the end of the stream or at the first empty line. Loading is
	 * best-effort: if the resource is missing or a line cannot be parsed, the problem is
	 * reported on stderr and the points read so far are returned.
	 *
	 * @return the points that were read; never {@code null}, possibly empty
	 */
	public List<Point> initData() {
		List<Point> points = new ArrayList<Point>();
		InputStream in = null;
		BufferedReader br = null;
		try {
			in = KMeansCluster.class.getClassLoader().getResourceAsStream(DATA_RESOURCE);
			if (null == in) {
				// getResourceAsStream signals "not found" with null; report it explicitly
				// instead of letting it surface as an anonymous NullPointerException.
				System.err.println("resource not found on classpath: " + DATA_RESOURCE);
				return points;
			}
			br = new BufferedReader(new InputStreamReader(in, "UTF-8"));
			String line = br.readLine();
			while (null != line && !"".equals(line)) {
				StringTokenizer tokenizer = new StringTokenizer(line);
				double x = Double.parseDouble(tokenizer.nextToken());
				double y = Double.parseDouble(tokenizer.nextToken());
				points.add(new Point(x, y));
				line = br.readLine();
			}
		} catch (Exception e) {
			e.printStackTrace();
		} finally {
			// Close the outermost wrapper first; closeQuietly tolerates null.
			IOUtils.closeQuietly(br);
			IOUtils.closeQuietly(in);
		}
		return points;
	}
	
	/**
	 * Picks random initial centers and creates one cluster per center.
	 *
	 * <p>Centers are drawn without replacement while unused points remain, so the same list
	 * entry is not used to seed two clusters. Only when {@code k} exceeds the number of points
	 * are the remaining centers drawn with replacement. Each cluster initially contains its
	 * own center point.
	 *
	 * @param points candidate points; must not be empty when {@code k > 0}
	 * @param k number of clusters to create
	 * @return {@code k} freshly created clusters (empty list when {@code k <= 0})
	 * @throws IllegalArgumentException if {@code k > 0} and {@code points} is empty
	 */
	public List<PointCluster> genInitCluster(List<Point> points, int k) {
		List<PointCluster> clusters = new ArrayList<PointCluster>();
		if (k <= 0) {
			return clusters;
		}
		if (points.isEmpty()) {
			throw new IllegalArgumentException(
					"cannot choose " + k + " initial centers from an empty point list");
		}
		Random random = new Random();
		// Work on a copy so the caller's list order is left untouched.
		List<Point> pool = new ArrayList<Point>(points);
		int len = pool.size();
		for (int i = 0; i < k; i++) {
			Point center;
			if (i < len) {
				// Partial Fisher-Yates shuffle: swap a random not-yet-used entry into slot i.
				int pick = i + random.nextInt(len - i);
				center = pool.get(pick);
				pool.set(pick, pool.get(i));
				pool.set(i, center);
			} else {
				// More clusters requested than points available: repeats are unavoidable.
				center = pool.get(random.nextInt(len));
			}
			PointCluster cluster = new PointCluster();
			cluster.setCenter(center);
			cluster.getPoints().add(center);
			clusters.add(cluster);
		}
		return clusters;
	}
	
	/**
	 * Assigns the points to the clusters and refines the centers until they settle.
	 *
	 * <p>Points already stored in the clusters are discarded at the start of every pass, so
	 * the seed point placed there by {@link #genInitCluster(List, int)} is not counted twice.
	 * Iteration ends when no center moved by more than {@link #THRESHOLD} in a pass, or after
	 * {@code MAX_ITERATIONS} passes.
	 *
	 * @param points the points to distribute
	 * @param clusters the clusters to fill; their centers and point lists are modified
	 */
	public void handleCluster(List<Point> points, List<PointCluster> clusters) {
		for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++) {
			assignPoints(points, clusters);
			if (!updateCenters(clusters)) {
				return;
			}
		}
	}
	
	/** Empties every cluster, then puts each point into the cluster with the nearest center. */
	private void assignPoints(List<Point> points, List<PointCluster> clusters) {
		for (PointCluster cluster : clusters) {
			cluster.getPoints().clear();
		}
		for (Point point : points) {
			PointCluster minCluster = null;
			// Start from +infinity; Integer.MAX_VALUE would reject any larger real distance.
			double minDistance = Double.POSITIVE_INFINITY;
			for (PointCluster cluster : clusters) {
				double distance = euclideanDistance(point, cluster.getCenter());
				if (distance < minDistance) {
					minDistance = distance;
					minCluster = cluster;
				}
			}
			if (null != minCluster) {
				minCluster.getPoints().add(point);
			}
		}
	}
	
	/**
	 * Recomputes each cluster's center and applies it when it moved by more than THRESHOLD.
	 *
	 * @return {@code true} if at least one center was replaced
	 */
	private boolean updateCenters(List<PointCluster> clusters) {
		boolean moved = false;
		for (PointCluster cluster : clusters) {
			Point center = cluster.getCenter();
			System.out.println("center: " + center);
			if (cluster.getPoints().isEmpty()) {
				// No members means no mean to compute; keep the current center.
				continue;
			}
			Point newCenter = cluster.computeMeansCenter();
			System.out.println("new center: " + newCenter);
			double distance = euclideanDistance(center, newCenter);
			System.out.println("distaince: " + distance);
			if (distance > THRESHOLD) {
				moved = true;
				cluster.setCenter(newCenter);
			}
		}
		return moved;
	}
	
	/**
	 * Runs the full clustering: random initialisation followed by iterative refinement.
	 *
	 * @param points the points to cluster
	 * @param k number of clusters
	 * @return the resulting clusters
	 */
	public List<PointCluster> cluster(List<Point> points, int k) {
		List<PointCluster> clusters = genInitCluster(points, k);
		handleCluster(points, clusters);
		return clusters;
	}
	
	/** Loads the sample data, clusters it into four groups and prints the result. */
	public void build() {
		List<Point> points = initData();
		List<PointCluster> clusters = cluster(points, 4);
		printClusters(clusters);
	}

	public static void main(String[] args) {
		KMeansCluster builder = new KMeansCluster();
		builder.build();
	}
}
KMediodsCluster
/**
 * Minimal k-medoids clustering over two-dimensional points read from a classpath resource.
 *
 * <p>Each pass assigns every point to the cluster whose center is nearest according to the
 * inherited {@code euclideanDistance} helper, then asks each cluster for a new center via
 * {@code PointCluster#computeMediodsCenter()}, which picks one of the cluster's own points.
 * Passes repeat until no center moves by more than {@link #THRESHOLD}.
 */
public class KMediodsCluster extends AbstractCluster {
	
	/** A center that moves by more than this distance in one pass triggers another pass. */
	public static final double THRESHOLD = 2.0;
	
	/** Safety cap on the number of passes so a non-converging input cannot loop forever. */
	private static final int MAX_ITERATIONS = 1000;
	
	/** Classpath resource holding one whitespace-separated "x y" pair per line. */
	private static final String DATA_RESOURCE = "kmeans1.txt";
	
	/**
	 * Loads the sample points from the classpath resource.
	 *
	 * <p>Reading stops at the end of the stream or at the first empty line. Loading is
	 * best-effort: if the resource is missing or a line cannot be parsed, the problem is
	 * reported on stderr and the points read so far are returned.
	 *
	 * @return the points that were read; never {@code null}, possibly empty
	 */
	public List<Point> initData() {
		List<Point> points = new ArrayList<Point>();
		InputStream in = null;
		BufferedReader br = null;
		try {
			in = KMediodsCluster.class.getClassLoader().getResourceAsStream(DATA_RESOURCE);
			if (null == in) {
				// getResourceAsStream signals "not found" with null; report it explicitly
				// instead of letting it surface as an anonymous NullPointerException.
				System.err.println("resource not found on classpath: " + DATA_RESOURCE);
				return points;
			}
			br = new BufferedReader(new InputStreamReader(in, "UTF-8"));
			String line = br.readLine();
			while (null != line && !"".equals(line)) {
				StringTokenizer tokenizer = new StringTokenizer(line);
				double x = Double.parseDouble(tokenizer.nextToken());
				double y = Double.parseDouble(tokenizer.nextToken());
				points.add(new Point(x, y));
				line = br.readLine();
			}
		} catch (Exception e) {
			e.printStackTrace();
		} finally {
			// Close the outermost wrapper first; closeQuietly tolerates null.
			IOUtils.closeQuietly(br);
			IOUtils.closeQuietly(in);
		}
		return points;
	}
	
	/**
	 * Picks random initial centers and creates one cluster per center.
	 *
	 * <p>Centers are drawn without replacement while unused points remain, so the same list
	 * entry is not used to seed two clusters. Only when {@code k} exceeds the number of points
	 * are the remaining centers drawn with replacement. Each cluster initially contains its
	 * own center point.
	 *
	 * @param points candidate points; must not be empty when {@code k > 0}
	 * @param k number of clusters to create
	 * @return {@code k} freshly created clusters (empty list when {@code k <= 0})
	 * @throws IllegalArgumentException if {@code k > 0} and {@code points} is empty
	 */
	public List<PointCluster> genInitCluster(List<Point> points, int k) {
		List<PointCluster> clusters = new ArrayList<PointCluster>();
		if (k <= 0) {
			return clusters;
		}
		if (points.isEmpty()) {
			throw new IllegalArgumentException(
					"cannot choose " + k + " initial centers from an empty point list");
		}
		Random random = new Random();
		// Work on a copy so the caller's list order is left untouched.
		List<Point> pool = new ArrayList<Point>(points);
		int len = pool.size();
		for (int i = 0; i < k; i++) {
			Point center;
			if (i < len) {
				// Partial Fisher-Yates shuffle: swap a random not-yet-used entry into slot i.
				int pick = i + random.nextInt(len - i);
				center = pool.get(pick);
				pool.set(pick, pool.get(i));
				pool.set(i, center);
			} else {
				// More clusters requested than points available: repeats are unavoidable.
				center = pool.get(random.nextInt(len));
			}
			PointCluster cluster = new PointCluster();
			cluster.setCenter(center);
			cluster.getPoints().add(center);
			clusters.add(cluster);
		}
		return clusters;
	}
	
	/**
	 * Assigns the points to the clusters and refines the medoids until they settle.
	 *
	 * <p>Points already stored in the clusters are discarded at the start of every pass, so
	 * the seed point placed there by {@link #genInitCluster(List, int)} is not counted twice.
	 * Iteration ends when no center moved by more than {@link #THRESHOLD} in a pass, or after
	 * {@code MAX_ITERATIONS} passes.
	 *
	 * @param points the points to distribute
	 * @param clusters the clusters to fill; their centers and point lists are modified
	 */
	public void handleCluster(List<Point> points, List<PointCluster> clusters) {
		for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++) {
			assignPoints(points, clusters);
			if (!updateCenters(clusters)) {
				return;
			}
		}
	}
	
	/** Empties every cluster, then puts each point into the cluster with the nearest center. */
	private void assignPoints(List<Point> points, List<PointCluster> clusters) {
		for (PointCluster cluster : clusters) {
			cluster.getPoints().clear();
		}
		for (Point point : points) {
			PointCluster minCluster = null;
			// Start from +infinity; Integer.MAX_VALUE would reject any larger real distance.
			double minDistance = Double.POSITIVE_INFINITY;
			for (PointCluster cluster : clusters) {
				double distance = euclideanDistance(point, cluster.getCenter());
				if (distance < minDistance) {
					minDistance = distance;
					minCluster = cluster;
				}
			}
			if (null != minCluster) {
				minCluster.getPoints().add(point);
			}
		}
	}
	
	/**
	 * Recomputes each cluster's medoid and applies it when it moved by more than THRESHOLD.
	 *
	 * @return {@code true} if at least one center was replaced
	 */
	private boolean updateCenters(List<PointCluster> clusters) {
		boolean moved = false;
		for (PointCluster cluster : clusters) {
			Point center = cluster.getCenter();
			System.out.println("center: " + center);
			if (cluster.getPoints().isEmpty()) {
				// An empty cluster has no candidate medoid; keep the current center.
				continue;
			}
			Point newCenter = cluster.computeMediodsCenter();
			System.out.println("new center: " + newCenter);
			if (null == newCenter) {
				// No medoid could be determined; measuring a distance to null would throw.
				continue;
			}
			double distance = euclideanDistance(center, newCenter);
			System.out.println("distaince: " + distance);
			if (distance > THRESHOLD) {
				moved = true;
				cluster.setCenter(newCenter);
			}
		}
		return moved;
	}
	
	/**
	 * Runs the full clustering: random initialisation followed by iterative refinement.
	 *
	 * @param points the points to cluster
	 * @param k number of clusters
	 * @return the resulting clusters
	 */
	public List<PointCluster> cluster(List<Point> points, int k) {
		List<PointCluster> clusters = genInitCluster(points, k);
		handleCluster(points, clusters);
		return clusters;
	}
	
	/** Loads the sample data, clusters it into four groups and prints the result. */
	public void build() {
		List<Point> points = initData();
		List<PointCluster> clusters = cluster(points, 4);
		printClusters(clusters);
	}
	
	public static void main(String[] args) {
		KMediodsCluster builder = new KMediodsCluster();
		builder.build();
	}
}
PointCluster
/**
 * A cluster of two-dimensional points together with its current center.
 *
 * <p>The member list is created lazily, so {@link #getPoints()} never returns {@code null}.
 */
public class PointCluster {

	/** Current center of the cluster; {@code null} until {@link #setCenter(Point)} is called. */
	private Point center = null;

	/** Member points; created on first access through {@link #getPoints()}. */
	private List<Point> points = null;

	public Point getCenter() {
		return center;
	}

	public void setCenter(Point center) {
		this.center = center;
	}

	/**
	 * Returns the live, mutable member list, creating it on first use.
	 *
	 * @return the member list; never {@code null}
	 */
	public List<Point> getPoints() {
		if (null == points) {
			points = new ArrayList<Point>();
		}
		return points;
	}

	public void setPoints(List<Point> points) {
		this.points = points;
	}
	
	/**
	 * Computes the arithmetic mean of the member points.
	 *
	 * @return a new point at the mean x/y of the members, or the current center (which may be
	 *         {@code null}) when the cluster has no members, since dividing by zero would
	 *         otherwise yield a NaN point
	 */
	public Point computeMeansCenter() {
		List<Point> members = getPoints();
		if (members.isEmpty()) {
			return center;
		}
		double sumX = 0.0;
		double sumY = 0.0;
		for (Point point : members) {
			sumX += point.getX();
			sumY += point.getY();
		}
		int len = members.size();
		return new Point(sumX / len, sumY / len);
	}
	
	/**
	 * Selects the medoid: the member whose summed Manhattan distance to all members is
	 * smallest. Ties are resolved in favour of the earliest member in the list.
	 *
	 * @return the medoid, or the current center (which may be {@code null}) when the cluster
	 *         has no members
	 */
	public Point computeMediodsCenter() {
		List<Point> members = getPoints();
		if (members.isEmpty()) {
			return center;
		}
		Point targetPoint = null;
		// +infinity rather than Integer.MAX_VALUE, so large distance sums are still comparable.
		double distance = Double.POSITIVE_INFINITY;
		for (Point candidate : members) {
			double d = 0.0;
			for (Point other : members) {
				d += manhattanDistance(candidate, other);
			}
			// Always accept the first candidate so a non-empty cluster never yields null.
			if (null == targetPoint || d < distance) {
				distance = d;
				targetPoint = candidate;
			}
		}
		return targetPoint;
	}
	
	/**
	 * Computes the sum of squared errors of the cluster: the sum over all members of the
	 * squared Euclidean distance to the current center.
	 *
	 * @return the SSE; {@code 0.0} for a cluster without members
	 */
	public double computeSSE() {
		double result = 0.0;
		for (Point point : getPoints()) {
			double d = euclideanDistance(point, center);
			// Square the distance: a plain sum of distances is not a sum of squared errors.
			result += d * d;
		}
		return result;
	}
	
	/** Returns the Manhattan (L1) distance between two points. */
	protected double manhattanDistance(Point a, Point b) {
		return Math.abs(a.getX() - b.getX()) + Math.abs(a.getY() - b.getY());
	}
	
	/** Returns the Euclidean (L2) distance between two points. */
	protected double euclideanDistance(Point a, Point b) {
		double dx = a.getX() - b.getX();
		double dy = a.getY() - b.getY();
		return Math.sqrt(dx * dx + dy * dy);
	}
}
程式碼託管:https://github.com/fighting-one-piece/repository-datamining.git