python實現mean-shift聚類演算法
阿新 • • 發佈:2020-06-10
本文例項為大家分享了python實現mean-shift聚類演算法的具體程式碼,供大家參考,具體內容如下
1、新建MeanShift.py檔案
"""Mean-shift clustering with a pluggable kernel, implemented with NumPy."""
import numpy as np

# Predefined thresholds.
# A point stops being shifted once a single update moves it less than this.
STOP_THRESHOLD = 1e-4
# A converged point closer than this to a cluster centre joins that cluster.
CLUSTER_THRESHOLD = 1e-1


def distance(a, b):
    """Return the Euclidean distance between the points ``a`` and ``b``."""
    return np.linalg.norm(np.array(a) - np.array(b))


def gaussian_kernel(distance, bandwidth):
    """Return the Gaussian kernel value for ``distance`` and ``bandwidth``.

    This is the density of a zero-mean normal distribution with standard
    deviation ``bandwidth`` evaluated at ``distance``.
    """
    return (1 / (bandwidth * np.sqrt(2 * np.pi))) * np.exp(
        -0.5 * ((distance / bandwidth)) ** 2
    )


class mean_shift(object):
    """Mean-shift clusterer.

    Every point is repeatedly moved to the kernel-weighted mean of the
    original data set until it moves less than ``STOP_THRESHOLD``; the
    converged points are then grouped using ``CLUSTER_THRESHOLD``.
    """

    def __init__(self, kernel=gaussian_kernel):
        """Store ``kernel``, a callable ``kernel(distance, bandwidth)``."""
        self.kernel = kernel

    def fit(self, points, kernel_bandwidth):
        """Shift every point to convergence and assign cluster ids.

        Args:
            points: array-like of shape (n_points, n_dims).
            kernel_bandwidth: bandwidth forwarded to the kernel.

        Returns:
            A tuple ``(shift_points, cluster_ids)``: the float array of
            converged positions and a list with one cluster id per point.
        """
        # Work in floats: with an integer array the shifted coordinates
        # would be truncated when written back into ``shift_points``.
        points = np.asarray(points, dtype=float)
        shift_points = points.copy()
        shifting = [True] * len(shift_points)
        while True:
            max_dist = 0
            for i in range(len(shift_points)):
                if not shifting[i]:
                    continue
                p_shift_init = shift_points[i].copy()
                shift_points[i] = self._shift_point(
                    shift_points[i], kernel_bandwidth, points
                )
                dist = distance(shift_points[i], p_shift_init)
                max_dist = max(max_dist, dist)
                shifting[i] = dist > STOP_THRESHOLD
            if max_dist < STOP_THRESHOLD:
                break
        cluster_ids = self._cluster_points(shift_points.tolist())
        return shift_points, cluster_ids

    def _shift_point(self, point, kernel_bandwidth, points):
        """Return the kernel-weighted mean of ``points`` as seen from ``point``.

        The kernel is called once per data point with a scalar distance, so
        kernels that do not accept arrays keep working. Any number of
        dimensions is supported.
        """
        weights = np.array(
            [self.kernel(distance(point, p), kernel_bandwidth) for p in points],
            dtype=float,
        )
        scale = weights.sum()
        if scale == 0:
            # Every weight vanished (or there is no data): leave the point
            # where it is instead of dividing by zero.
            return np.array(point, dtype=float)
        return weights.dot(points) / scale

    def _cluster_points(self, points):
        """Group converged points and return one cluster id per point.

        A point joins the first existing centre that lies closer than
        ``CLUSTER_THRESHOLD``; otherwise it becomes the centre of a new
        cluster.
        """
        cluster_ids = []
        cluster_centers = []
        for point in points:
            for idx, center in enumerate(cluster_centers):
                if distance(point, center) < CLUSTER_THRESHOLD:
                    cluster_ids.append(idx)
                    # Stop at the first match so each point gets exactly
                    # one id even when several centres are within range.
                    break
            else:
                cluster_ids.append(len(cluster_centers))
                cluster_centers.append(point)
        return cluster_ids
2、呼叫上述py檔案
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 09 11:02:08 2018

@author: muli

Demo script: cluster three synthetic blobs with MeanShift and plot them.
"""
import random

import matplotlib.pyplot as plt
import numpy as np
# ``sklearn.datasets.samples_generator`` was removed from scikit-learn;
# ``make_blobs`` is exported from ``sklearn.datasets`` directly.
from sklearn.datasets import make_blobs

import MeanShift


def colors(n):
    """Return ``n`` random RGB colours as 3-tuples of floats in [0, 1]."""
    return [
        (random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1))
        for _ in range(n)
    ]


def main():
    """Generate sample data, run mean-shift and show a coloured scatter plot."""
    centers = [[-1, -1], [-1, 1], [1, 1]]
    X, _ = make_blobs(n_samples=300, centers=centers, cluster_std=0.4)

    mean_shifter = MeanShift.mean_shift()
    _, mean_shift_result = mean_shifter.fit(X, kernel_bandwidth=0.5)

    np.set_printoptions(precision=3)
    print('input: {}'.format(X))
    print('assined clusters: {}'.format(mean_shift_result))

    # One colour per cluster, looked up by cluster id for every sample.
    color = colors(np.unique(mean_shift_result).size)
    point_colors = [color[cluster_id] for cluster_id in mean_shift_result]
    plt.scatter(X[:, 0], X[:, 1], color=point_colors)
    plt.show()


if __name__ == '__main__':
    main()
結果如圖所示:
參考連結
以上就是本文的全部內容,希望對大家的學習有所幫助,也希望大家多多支援我們。