1. 程式人生 > >【詳解】CS231n assignment1KNN中不使用迴圈計算距離:從原理到程式

【詳解】CS231n assignment1KNN中不使用迴圈計算距離:從原理到程式

本文主要講述不使用迴圈結構來計算兩個矩陣的歐氏距離, 設訓練集矩陣為train,size為num_train * num_features,設驗證集矩陣為validate,size為num_test,num_features。
因此我們計算每一個驗證集樣本到訓練集樣本的距離,就是將訓練集矩陣train的某一行拿出來與驗證集矩陣validate的某一行計算歐式距離。
這在兩層迴圈中就是這麼做的,相比大家都明白。但是不使用迴圈可以有一點難受。本文就是從演算法原理上到程式本身上去解釋怎麼做的。

首先,我不得不承認,我沒有認識到不使用迴圈和使用迴圈中時間的對比情況,先看一個我實際執行的時間結果:可以看見不使用迴圈要快非常多,只與為什麼兩層迴圈比一層迴圈還快我不是那麼明白。
在這裡插入圖片描述


本來我還沒覺得有不使用迴圈的必要,做完這個實驗,我才開始認真考慮到時間成本,並且花了一些時間在不使用迴圈來求距離上。

以下正文

第一步:完全平方公式

我們明白,所謂的歐氏距離就是先求差,再求平方和,以及求二次方根;我們假設求兩個向量的歐式距離:
在這裡插入圖片描述
不是一般性,我假設x=x1, x2, x3, y=y1, y2, y3,因此,歐氏距離也就是
在這裡插入圖片描述
我們可以做一個變換,根據
在這裡插入圖片描述
我們可以知道,當x=x1, x2, x3, y=y1, y2, y3時,有
在這裡插入圖片描述

第二步:維度驗證

前面,我們假設訓練集矩陣為train,size為num_train * num_features,設驗證集矩陣為validate,size為num_test,num_features。
那麼,假如我們計算好了之後,距離矩陣應該是怎樣的維度呢?
我們假設讓它這樣排列:第i行第j列的距離表示驗證集的第i行向量和訓練集的第j行向量的距離。
因此,這個距離矩陣dist應該是num_validate, num_train。
在計算過程中,我們始終要注意維度是否合理

第三步:計算方法

根據上面解釋,我們將矩陣距離換成多項來運算,也就是
在這裡插入圖片描述
而我們知道,電腦是擅長矩陣運算的。
我們將訓練矩陣train看做a, 將驗證集validate看做b,我們就是要求(a-b)^2
但是這裡都是矩陣,資料是二維的。

比較清楚的,如果我們想要將兩個資料的兩行進行處理,一個方法就是將其中一個矩陣轉置,這樣就變成了一行與一列進行運算,這樣非常適合矩陣運算。

第四步:程式實現

特別注意這裡的資料維度

def compute_distances_no_loops(self, X):
    """
    Compute the distance between each test point in X and each training point
    in self.X_train using no explicit loops.

    Input / Output: Same as compute_distances_two_loops
    """
	num_test = X.shape[0]
	num_train = self.X_train.shape[0]
	dists = np.zeros((num_test, num_train)) 
	
	ab = np.dot(X, self.X_train.T)  # num_test * num_train
	a2 = np.sum(np.square(X), axis=1).reshape(-1, 1)   # num_test * 1
	b2 = np.sum(np.square(self.X_train.T), axis=0).reshape(1, -1)  # 1 * num_train
	dists = -2 * ab + a2 + b2 # 不同維度計算會自動 broadcast
	dists = np.sqrt(dists)

	return dists

程式中,我就是先將X矩陣(也就是驗證集)進行轉置,使其和目標矩陣(距離矩陣)的行數相同。
轉置之後,計算ab就變成矩陣乘法;
對於最後的距離矩陣來說,每一行都是驗證集的對應行與訓練集的距離,所以a^2是相同的;
與此相同,距離矩陣的每一列都是訓練集的對應行與驗證集的距離,所以都加上b^2
最後的加法是可以進行broadcast。會自動的將a2 和 b2 分別加到每一行和每一列

後記

感覺還是沒能說清楚。
在此提醒大家不要著急,尤其是新學者,可能多遇到幾次就好了。
祝各位最後都能順利理解。這裡只能是拋磚引玉了。