1. 程式人生 > >K-SVD字典學習及其實現(Python)

K-SVD字典學習及其實現(Python)

演算法思想

演算法求解思路為交替迭代的進行稀疏編碼和字典更新兩個步驟. K-SVD在構建字典步驟中,K-SVD不僅僅將原子依次更新,對於原子對應的稀疏矩陣中行向量也依次進行了修正. 不像MOP,K-SVD不需要對矩陣求逆,而是利用SVD數學分析方法得到了一個新的原子和修正的係數向量.

固定係數矩陣X和字典矩陣D,字典的第k個原子為dk,同時dk對應的稀疏矩陣為X中的第k個行向量xkT. 假設當前更新進行到原子dk,樣本矩陣和字典逼近的誤差為:

YDX2F=Yj=1KdjxjT2F=(YjkdjxjT)dkxjT2F=EkdkxkT2F

在得到當前誤差矩陣E

k後,需要調整dkXkT,使其乘積與Ek的誤差儘可能的小.

如果直接對dkXkT進行更新,可能導致xkT不稀疏. 所以可以先把原有向量xkT中零元素去除,保留非零項,構成向量xkR,然後從誤差矩陣Ek中取出相應的列向量,構成矩陣ERk. 對ERk進行SVD(Singular Value Decomposition)分解,有ERk=UΔVT,由U的第一列更新dk,由V的第一列乘以Δ(1,1)所得結果更新xkR.

Python實現

import numpy as np
from sklearn import linear_model
import scipy.misc
from
matplotlib import pyplot as plt class KSVD(object): def __init__(self, n_components, max_iter=30, tol=1e-6, n_nonzero_coefs=None): """ 稀疏模型Y = DX,Y為樣本矩陣,使用KSVD動態更新字典矩陣D和稀疏矩陣X :param n_components: 字典所含原子個數(字典的列數) :param max_iter: 最大迭代次數 :param tol: 稀疏表示結果的容差 :param n_nonzero_coefs: 稀疏度 """
self.dictionary = None self.sparsecode = None self.max_iter = max_iter self.tol = tol self.n_components = n_components self.n_nonzero_coefs = n_nonzero_coefs def _initialize(self, y): """ 初始化字典矩陣 """ u, s, v = np.linalg.svd(y) self.dictionary = u[:, :self.n_components] def _update_dict(self, y, d, x): """ 使用KSVD更新字典的過程 """ for i in range(self.n_components): index = np.nonzero(x[i, :])[0] if len(index) == 0: continue d[:, i] = 0 r = (y - np.dot(d, x))[:, index] u, s, v = np.linalg.svd(r, full_matrices=False) d[:, i] = u[:, 0].T x[i, index] = s[0] * v[0, :] return d, x def fit(self, y): """ KSVD迭代過程 """ self._initialize(y) for i in range(self.max_iter): x = linear_model.orthogonal_mp(self.dictionary, y, n_nonzero_coefs=self.n_nonzero_coefs) e = np.linalg.norm(y - np.dot(self.dictionary, x)) if e < self.tol: break self._update_dict(y, self.dictionary, x) self.sparsecode = linear_model.orthogonal_mp(self.dictionary, y, n_nonzero_coefs=self.n_nonzero_coefs) return self.dictionary, self.sparsecode if __name__ == '__main__': im_ascent = scipy.misc.ascent().astype(np.float) ksvd = KSVD(300) dictionary, sparsecode = ksvd.fit(im_ascent) plt.figure() plt.subplot(1, 2, 1) plt.imshow(im_ascent) plt.subplot(1, 2, 2) plt.imshow(dictionary.dot(sparsecode)) plt.show()

執行結果:
KSVD字典學習結果