1. 程式人生 > 資訊 >步子太快容易犧牲精度,梯度下降複雜度獲嚴格數學證明

步子太快容易犧牲精度,梯度下降複雜度獲嚴格數學證明

梯度下降是機器學習中求最小值最常用的一種演算法。儘管這種演算法應用廣泛,但是人們關於它計算複雜度的理論研究卻寥寥無幾。

在今年 ACM 舉辦的計算機理論頂會 STOC 上,牛津大學和利物浦大學的學者們,給我們證明了這個理論問題的答案。

他們得到了梯度下降演算法的計算複雜度,等於兩類計算機問題的交集。

這篇文章也成為了 STOC 2021 的最佳論文。

梯度下降的複雜度

四位作者研究人員將目光放在了 TFNP 中兩個子集問題的交集。

第一個子集稱為 PLS(多項式區域性搜尋)。

這是一系列問題,涉及在特定區域中尋找函式的最小值或最大值。

屬於 PLS 的一個典型例子是規劃一條路線的任務,以最短的路線經過一些城市,且只能通過切換城市的順序來改變行程。

通過調整順序可以很容易看出哪些路線縮短了行程,最終你會找到某一條路線,無法進一步縮短路程,這條路線 x 就是你要找到的最小值。

用數學公式來表示就是:(p 是求路線總長度的函式,g (x) 表示改變 x 得到的新路線)

TFNP 問題的第二個子集是 PPAD(有向圖上的多項式奇偶校驗引數)。

這個問題的解來自更復雜的過程,比如 Brouwer 不動點定理,即對於滿足一定條件的連續函式,存在一個點保持不變。

例如,如果你攪動一杯水,Brouwer 不動點定理保證絕對會有一個水分子會回到它最初的位置。

用數學公式來表示就是:

實際應用中,我們不可能要求找到以上兩個問題絕對精確的解,只要誤差小於規定的值 ε 即可,也就是:

PLS 和 PPAD 這兩類問題的交集本身形成了一類稱為 PLS∩PPAD 的問題。

然而,直到現在,研究人員都無法找到 PLS∩PPAD 完全問題的一個天然的例子。所謂的完全問題,就是某類問題中最典型、最難的問題。

現在,來自牛津大學和利物浦大學的學者們終於找到了,梯度下降問題(GD)就是,它等價於 PLS 與 PPAD 的交集。

PPAD∩PLS 是可以通過在有界域上執行梯度下降來解決的所有問題的類別。

而 PLS 與 PPAD 的交集,被他們證明等價於 CLS(連續局域搜尋問題)

PLS 與 PPAD 的任意解(either-solution)就是 PLS∩PPAD 完全問題的解。

到了這裡,梯度下降演算法與這兩個問題有什麼聯絡呢?

請看梯度下降演算法的迭代公式:

在求解實際問題,我們也是在尋找區域性最小值的近似解。我們可以設定兩種計算終止條件:

1、如果 x’與 x 這兩個點的損失函式小於精度 ε:

那麼計算終止,這與前面 PLS 中的 Real-Local-Opt 問題類似。

2、如果 x’與 x 這兩個點的空間距離小於精度 ε:

那麼計算終止,這與前面 PPAD 中的 Brouwer 不動點問題類似。

第一種相當於是 PLS,第二種相當於是 PPAD。

該結果意味著,梯度下降演算法精度和速度之間存在基本聯絡,為獲得更高精度,計算時間將會不成比例地迅速增長。

精度與時間的平衡點

實際上,吳恩達在自己的機器學習課程中已經指出,梯度下降演算法的運算複雜度和步數 n 的平方成正比。

若對精度要求高,需要將學習率 η 設定得更小。

如果機器學習研究者可能希望將實驗的精度提高到 2 倍,那麼可能不得不將梯度下降演算法的執行時間增加到 4 倍。

這表明,梯度下降在實踐中必須做出某種妥協。要麼接受不太高的精度,要麼花費更長的執行時間來換取。

例如,一些對 SGD 進行加速的優化演算法,雖然收斂速度更快,但很有可能陷入區域性最小值。要想獲得精度更高的結果,往往必須迴歸到 SGD。

對於某些精度很重要的問題,執行時長會讓梯度下降演算法變得不可行。

但這並不是說梯度下降的快速演算法不存在,但如果存在著這樣的演算法,將意味著 PLS∩PPAD 也存在快速演算法,但尋找後者的快速演算法要比前者難得多。

最後,這一問題的計算機自動證明程式碼已經開源,有興趣的朋友可以前去觀摩嘗試。

參考連結:

  • https://www.quantamagazine.org/how-big-data-carried-graph-theory-into-new-dimensions-20210819/

  • https://www.youtube.com/watch?v=as720_SRpY0&ab_channel=SIGACTEC

  • https://arxiv.org/abs/2011.01929

  • https://github.com/jfearnley/PPADPLS/