剖析DeepMind神經網絡記憶研究:模擬動物大腦實現連續學習

lll_sss 7年前發布 | 25K 次閱讀 神經網絡 深度學習

前幾天,Google DeepMind 公開了一篇新論文《克服神經網絡中的災難性遺忘(Overcoming catastrophic forgetting in neural networks)》,介紹了這家世界頂級的人工智能研究機構在「記憶(memory)」方面的研究進展,之后該公司還在其官方博客上發布了一篇介紹博客,參見機器之心的報道《 為機器賦予記憶:DeepMind 重磅研究提出彈性權重鞏固算法 》。近日,加利福尼亞大學戴維斯分校計算機科學工程和統計學的畢業生 Rylan Schaeffer 在其博客上發表了一篇對該研究的深度解讀文章,機器之心對該文章進行了編譯介紹 

引言

我發現網絡上鋪天蓋地的人工智能相關信息中的絕大多數都可以分為兩類:一是向外行人士解釋進展情況,二是向其他研究者解釋進展。我還沒找到什么好資源能讓有技術背景但對不了解更前沿的進展的人可以自己充電。我想要成為這中間的橋梁——通過為前沿研究提供(相對)簡單易懂的詳細解釋來實現。首先,讓我們從《Overcoming Catastrophic Forgetting in Neural Networks》這篇論文開始吧。

動機

實現通用人工智能的關鍵步驟是獲得連續學習的能力,也就是說,一個代理(agent)必須能在不遺忘舊任務的執行方法的同時習得如何執行新任務。然而,這種看似簡單的特性在歷史上卻一直未能實現。McCloskey 和 Cohen(1989)首先注意到了這種能力的缺失——他們首先訓練一個神經網絡學會了給一個數字加 1,然后又訓練該神經網絡學會了給數字加 2,但之后該網絡就不會給數字加 1 了。他們將這個問題稱為「災難性遺忘(catastrophic forgetting)」,因為神經網絡往往是通過快速覆寫來學習新任務,而這樣就會失去執行之前的任務所必需的參數。

克服災難性遺忘方面的進展一直收效甚微。之前曾有兩篇論文《Policy Distillation》和《Actor-Mimic: Deep Multitask and Transfer Reinforcement Learning》通過在訓練過程中提供所有任務的數據而在混合任務上實現了很好的表現。但是,如果一個接一個地引入這些任務,那么這種多任務學習范式就必須維持一個用于記錄和重放訓練數據的情景記憶系統(episodic memory system)才能獲得良好的表現。這種方法被稱為系統級鞏固(system-level consolidation),該方法受限于對記憶系統的需求,且該系統在規模上必須和被存儲的總記憶量相當;而隨著任務的增長,這種記憶的量也會增長。

然而,你可能也直覺地想到了帶著一個巨大的記憶庫來進行連續學習是錯誤的——畢竟,人類除了能學會走路,也能學會說話,而不需要維持關于學習如何走路的記憶。哺乳動物的大腦是如何實現這一能力的?Yang, Pan and Gan (2009) 說明學習是通過突觸后樹突棘(postsynaptic dendritic spines)隨時間而進行的形成和消除而實現的。樹突棘是指「神經元樹突上的突起,通常從突觸的單個軸突接收輸入」,如下所示:

具體而言,這些研究者檢查了小鼠學習過針對特定新任務的移動策略之后的大腦。當這些小鼠學會了最優的移動方式之后,研究者觀察到樹突棘形成出現了顯著增多。為了消除運動可能導致樹突棘形成的額外解釋,這些研究者還設置了一個進行運動的對照組——這個組沒有觀察到樹突棘形成。然后這些研究者還注意到,盡管大多數新形成的樹突棘后面都消失了,但仍有一小部分保留了下來。2015 年的另一項研究表明當特定的樹突棘被擦除時,其對應的技能也會隨之消失。

Kirkpatrick et. 在論文《Overcoming catastrophic forgetting in neural networks》中提到:「這些實驗發現……說明在新大腦皮層中的連續學習依賴于特定于任務的突觸鞏固(synaptic consolidation),其中知識是通過使一部分突觸更少塑性而獲得持久編碼的,所以能長時間保持穩定。」我們能使用類似的方法(根據每個神經元的重要性來改變單個神經元的可塑性)來克服災難性遺忘嗎?

這篇論文的剩余部分推導并演示了他們的初步答案:可以。

直覺

假設有兩個任務 A 和 B,我們想要一個神經網絡按順序學習它們。當我們談到一個學習任務的神經網絡時,它實際上意味著讓神經網絡調整權重和偏置(統稱參數/θ)以使得神經網絡在該任務上實現更好的表現。之前的研究表明對于大型網絡而言,許多不同的參數配置可以實現類似的表現。通常,這意味著網絡被過參數化了,但是我們可以利用這一點:過參數化(overparameterization)能使得任務 B 的配置可能接近于任務 A 的配置。作者提供了有用的圖形:

在圖中,θ?A 代表在 A 任務中表現最好的 θ 的配置,還存在多種參數配置可以接近這個表現,灰色表示這一配置的集合;在這里使用橢圓來表示是因為有些參數的調整權重比其他參數更大。如果神經網絡隨后被設置為學習任務 B 而對記住任務 A 沒有任何興趣(即遵循任務 B 的誤差梯度),則該網絡將在藍色箭頭的方向上移動其參數。B 的最優解也具有類似的誤差橢圓體,上面由白色橢圓表示。

然而,我們還想記住任務 A。如果我們只是簡單使參數固化,就會按綠色箭頭發展,則處理任務 A 和 B 的性能都將變得糟糕。最好的辦法是根據參數對任務的重要程度來選擇其固化的程度;如果這樣的話,神經網絡參數的變化方向將遵循紅色箭頭,它將試圖找到同時能夠很好執行任務 A 和 B 的配置。作者稱這種算法「彈性權重鞏固(EWC/Elastic Weight Consolidation)」。這個名稱來自于突觸鞏固(synaptic consolidation),結合「彈性的」錨定參數(對先前解決方案的約束限制參數是二次的,因此是彈性的)。

數學解釋

在這里存在兩個問題。第一,為什么錨定函數是二次的?第二,如何判定哪個參數是「重要的」?

在回答這兩個問題之前,我們先要明白從概率的角度來理解神經網絡的訓練意味著什么。假設我們有一些數據 D,我們希望找到最具可能性的參數,它被表示為 p(θ|D)。我們可以是用貝葉斯規則來計算這個條件概率。

如果我們應用對數變換,則方程可以被重寫為:

假設數據 D 由兩個獨立的(independent)部分構成,用于任務 A 的數據 DA 和用于任務 B 的數據 DB。這個邏輯適用于多于兩個任務,但在這里不用詳述。使用獨立性(independence)的定義,我們可以重寫這個方程:

看看(3)右邊的中間三個項。它們看起來很熟悉嗎?它們應該。這三個項是方程(2)的右邊,但是 D 被 DA 代替了。簡單來說,這三個項等價于給定任務 A 數據的網絡參數的條件概率的對數。這樣,我們得到了下面這個方程:

讓我們先解釋一下方程(4)。左側仍然告訴我們如何計算整個數據集的 p(θ| D),但是當求解任務 A 時學習的所有信息都包含在條件概率 p(θ| DA)中。這個條件概率可以告訴我們哪些參數在解決任務 A 中很重要。

下一步是不明確的:「真實的后驗概率是難以處理的,因此,根據 Mackay (19) 對拉普拉斯近似的研究,我們將該后驗近似為一個高斯分布,其帶有由參數θ?A 給定的均值和一個由 Fisher 信息矩陣 F 的對角線給出的對角精度。」

讓我們詳細解釋一下。首先,為什么真正的后驗概率難以處理?論文并沒有解釋,答案是:貝葉斯規則告訴我們

p(θ|DA) 取決于 p(DA)=∫p(DA|θ′)p(θ′)dθ′,其中θ′是參數空間中的參數的可能配置。通常,該積分沒有封閉形式的解,留下數值近似以作為替代。數值近似的時間復雜性相對于參數的數量呈指數級增長,因此對于具有數億或更多參數的深度神經網絡,數值近似是不實際的。

然后,Mackay 關于拉普拉斯近似的工作是什么,跟這里的研究有什么關系?我們使用θ*A 作為平均值,而非數值近似后驗分布,將其建模為多變量正態分布。方差呢?我們將把每個變量的方差指定為方差的倒數的精度。為了計算精度,我們將使用 Fisher 信息矩陣 F。Fisher 信息是「一種測量可觀察隨機變量 X 攜帶的關于 X 所依賴的概率的未知參數θ的信息的量的方法。」在我們的例子中,我們感興趣的是測量來自 DA 的每個數據所攜帶的關于θ的信息的量。Fisher 信息矩陣比數值近似計算更可行,這使得它成為一個有用的工具。

因此,我們可以為我們的網絡在任務 A 上訓練后在任務 B 上再定義一個新的損失函數。讓 LB(θ)僅作為任務 B 的損失。如果我們用 i 索引我們的參數,并且選擇標量λ來影響任務 A 對任務 B 的重要性,則在 EWC 中最小化的函數 L 是:

作者聲稱,EWC 具有相對于網絡參數的數量和訓練示例的數量是線性的運行時間。

實驗和結果

隨機模式

EWC 的第一個測試是簡單地看它能否比梯度下降(GD)更長時間地記住簡單的模式。這些研究者訓練了一個可將隨機二元模式和二元結果關聯起來的神經網絡。如果該網絡看到了一個之前見過的二元模式,那么就通過觀察其信噪比是否超過了一個閾值來評價其是否已經「記住」了該模式。使用這種簡單測試的原因是其具有一個分析解決方案。隨著模式數量的增加,EWC 和 GD 的表現都接近了它們的完美答案。但是 EWC 能夠記憶比 GD 遠遠更多的模式,如下圖所示:

MNIST

這些研究者為 EWC 所進行的第二個測試是一個修改過的 MNIST 版本。其沒有使用被給出的數據,而是生成了三個隨機的排列(permutation),并將每個排列都應用到該數據集中的每張圖像上。任務 A 是對被第一個排列轉換過的 MNIST 圖像中的數字進行分類,任務 B 是對被第二個排列轉換過的圖像中的數字進行分類,任務 C 類推。這些研究者構建了一個全連接的深度神經網絡并在任務 A、B 和 C 上對該網絡進行了訓練,同時在任務 A(在 A 上的訓練完成后)、B(在 B 上的訓練完成后)和 C(在 C 上的訓練完成后)上測試了該網絡的表現。訓練是分別使用隨機梯度下降(SGD)、使用 L2 正則化的均勻參數剛度(uniform parameter-rigidity using L2 regularization)、EWC 獨立完成的。下面是它們的結果:

如預期的一樣,SGD 出現了災難性遺忘;在任務 B 上訓練后在任務 A 上的表現出現了快速衰退,在任務 C 上訓練后更是進一步衰退。使參數更剛性能維持在第一個任務上的表現,但卻不能學習后續的任務。而 EWC 能在成功學習新任務的同時記住如何執行之前的任務。隨著任務數量的增加,EWC 也能維持相對較好的表現,相對地,帶有 dropout 正則化的 SGD 的表現會持續下降,如下所示:

Atari 2600

DeepMind 曾在更早的一篇論文中表明:當一次訓練和測試一個游戲時,深度 Q 網絡(DQN)能夠在多種 Atari 2600 游戲上實現超人類的表現。為了了解 EWC 可以如何在這種更具挑戰性的強化學習環境中實現連續學習,這些研究者對該 DQN 代理進行了修改,使其能夠使用 EWC。但是,他們還需要做出一項額外的修改:在哺乳動物的連續學習中,為了確定一個代理當前正在學習的任務,必需要一個高層面的系統,但該 DQN 代理完全不能做出這樣的確定。為了解決這個問題,研究者為其添加了一個基于 forget-me-not(FMN)過程的在線聚類算法,使得該 DQN 代理能夠為每一個推斷任務維持各自獨立的短期記憶緩存。

這就得到了一個能夠跨兩個時間尺度進行學習的 DQN 代理。在短期內,DQN 代理可以使用 SGD 等優化器(本案例中研究者使用了 RMSProp)來從經歷重放(experience replay)機制中學習。在長期內,該 DQN 代理使用 EWC 來鞏固其從各種任務上學習到的知識。研究者從 DQN 實現了人類水平表現的 Atari 游戲(共 19 個)中隨機選出了 10 個,然后該代理在每個單獨的游戲上訓練了一段時間,如下所示:

這些研究者對三個不同的 DQN 的代理進行了比較。藍色代理沒有使用 EWC,紅色代理使用了 EWC 和 forget-me-not 任務標識符。褐色代理則使用 EWC,并被提供了真實任務標簽。在一個任務上實現人類水平的表現被規范化為分數 1。如你所見,EWC 在這 10 個任務上都實現了接近人類水平的表現,而非 EWC 代理沒能在一個以上的任務上做到這一點。該代理是否被給出了一個真實標簽或是否必須對任務進行推導對結果的影響不大,但我認為這也表明了 FMN 過程的成果,而不僅僅是 EWC 的成功。

接下來的這個部分非常酷。如前所述,EWC 并沒有在這 10 個任務上實現人類水平的表現。為什么會這樣?一個可能的原因是 Fisher 信息矩陣可能對參數重要程度的估計不佳。為了實際驗證這一點,這些研究者研究了在僅一個游戲上訓練的代理的權重。不管這個游戲是什么游戲,它們都表現出了以下模式:如果權重受到了一個均勻隨機擾動的影響,隨著該擾動的增加,該代理的表現(規范化為 1)會下降;而如果權重受到的擾動得到了 Fisher 信息的對角線的逆的影響,那么該分數在面臨更大的擾動時也能保持穩定。這說明 Fisher 信息在確定參數的真正重要性方面是很好的。

然后,研究者嘗試在 null 空間中進行擾動。這本來應該是無效的,但實際上研究者觀察到了與逆 Fisher 空間中的結果類似的結果。這說明使用 Fisher 信息矩陣會導致將一些重要參數標記為不重要的情況——「因此很有可能當前實現的主要限制是其低估了參數的不確定性。」

討論

貝葉斯說明

作者們對 EWC 給出了非常好的貝葉斯解讀:「正式來說,當需要學習新的任務時,網絡參數被先驗所調節,也就是先前任務在給定數據參數上的后驗分布。這能實現先驗任務被約束的更快的參數學習率,并為重要的參數降低學習率。」

網絡重疊

在剛開始時,我提到神經網絡的過參數化(overparameterization)讓 EWC 能實現優異的表現。有一個很合理的問題就是:為了每個不同的任務將網絡劃分到特定部分,這些神經網絡能否給出更好的表現?或者通過共享表征,這些網絡是否能高效地使用其能力?為了解答這個問題,作者們測量了任務對在 Fisher 信息矩陣上的重疊情況(Fisher Overlap)。對高度類似的任務(例如只有一點不同的兩個隨機排列)而言,Fisher Overlap 相當高。即使不相似的任務,Fisher Overlap 也高于 0。隨著網絡深度的增加,Fisher Overlap 也會增加。下圖演示了該結果:

突觸可塑性的理論

研究者還討論了 EWC 可能能為神經可塑性方面的研究提供信息。級聯(Cascade)理論企圖構建突觸狀態的模型來對可塑性和穩定性建模。盡管 EWC 不能隨時間緩和參數,也因此不能遺忘先前的信息,但 EWC 和級聯都能通過讓突觸更不可塑而延展記憶穩定性。最近的一項研究提出除了存儲自身的實際權重之外,突觸也存儲當前權重的不確定性。EWC 是該思路的延展:在 EWC 中,每個突觸存儲三個值:權重、均值和方差。

總結

  • 沒有遺忘的連續學習任務對智能而言很有必要;

  • 研究表明哺乳動物大腦中的突觸鞏固有助于實現連續學習;

  • EWC 通過讓重要參數變的更可塑,從而模擬人工神經網絡中的突觸鞏固;

  • EWC 應用于神經網絡時表現出了比 SGB 更好的表現,而且在更大范圍上有一致的參數穩定性;

  • EWC 提供了說明權重鞏固是連續學習的基礎的線索證據。 

 

來自:http://www.jiqizhixin.com/article/2507

 

 本文由用戶 lll_sss 自行上傳分享,僅供網友學習交流。所有權歸原作者,若您的權利被侵害,請聯系管理員。
 轉載本站原創文章,請注明出處,并保留原始鏈接、圖片水印。
 本站是一個以用戶分享為主的開源技術平臺,歡迎各類分享!