[譯] 計算圖上的微積分:Backpropagation
引言
Backpropagation (BP) 是使得訓練深度模型在計算上可行的關鍵算法。對現代神經網絡,這個算法相較于無腦的實現可以使梯度下降的訓練速度提升千萬倍。而對于模型的訓練來說,這其實是 7 天和 20 萬年的天壤之別。
除了在深度學習中的使用,BP 本身在其他的領域中也是一種強大的計算工具,例如從天氣預報到分析數值的穩定性——只是同一種思想擁有不同的名稱而已。實際上,BP 已經在不同領域中被重復發明了數十次了(參見 Griewank (2010) )。更加一般性且與應用場景獨立的名稱叫做 反向微分 (reverse-mode differentiation)
反向微分,我自己翻譯的,如果有人知道正確的翻譯請告知。
從本質上看,BP 是一種快速求導的技術,可以作為一種不單單用在深度學習中并且可以勝任大量數值計算場景的基本的工具。
計算圖
計算圖是種很好的研究數學表達式的方式。例如,我們有這樣一個表達式e = (a + b) * ( b + 1)。其包含三個操作:兩個加法和一個乘法。為了更好的講述,我們引入兩個中間變量,c和d,這樣每個函數的輸出就有一個變量表示了。現在我們有:
下面可以創建計算圖了,我們將每個表達式和輸入的變量看做是節點。如果一個節點的值是另一個節點的輸入,就畫出一條從該節點到另一節點的邊。
計算圖是有向圖
![[譯] 計算圖上的微積分:Backpropagation](https://simg.open-open.com/show/e57ad464c62bafb15024ad7db86ec2d5.png)
這種樣式的圖在計算機科學領域到處可見,特別是在函數式程序中。他們與依賴圖(dependency graph)或者調用圖(call graph)緊密相關。同樣他們也是非常流行的深度學習框架 Theano 背后的核心抽象。
對于上面用計算圖表示的表達式,我們設置對應輸入變量的值,通過這個圖來計算每個節點的值。例如,假設a = 2, b = 1:
![[譯] 計算圖上的微積分:Backpropagation](https://simg.open-open.com/show/5fdc35aabab1f53a9c2aecd3f0523b24.png)
最終表達式的值就是6。
計算圖上的導數
如果想要理解計算圖上的導數,那么關鍵之處就是理解每條邊上的導數。如果a直接影響c,我們就想知道a如何影響了c。如果a改變了一丟丟,c會發生什么樣的變化?這種東西我們稱c關于a的偏導數。
為了計算在這幅圖中的偏導數,我們需要 和式法則( sum rule )和 乘式法則( product rule ):
![[譯] 計算圖上的微積分:Backpropagation](https://simg.open-open.com/show/874ba65972499bad0ade0fd4efbbfd6a.png)
和式法則 和 乘式法則
下面,在圖中每條邊上都有對應的導數了:
![[譯] 計算圖上的微積分:Backpropagation](https://simg.open-open.com/show/88e59df2f0624a4a52cfba8d4aec3140.png)
那如果我們想知道哪些沒有直接相連的節點之間的影響關系呢?假設就看看e如何被a影響的。如果我們以1的速度改變a,那么c也是以1的速度在改變,導致e發生了2的速度在改變。因此e是以1 * 2的關于a變化的速度在變化。
而一般的規則就是對一個點到另一個點的所有的可能的路徑進行求和,每條路徑對應于該路徑中的所有邊的導數之積。因此,為了獲得e關于b的導數:
![[譯] 計算圖上的微積分:Backpropagation](https://simg.open-open.com/show/0d1d3ce423beebd40eaa608ec4af2731.png)
e 關于 b 的導數
這個值就代表著b改變的速度通過c和d影響到e的速度。
路徑求和的法則其實就是 多元鏈式法則 ( multivariate chain rule )的 另一種思考方式。
分解路徑
路徑求和可能路徑數量很容易就會組合爆炸。
![[譯] 計算圖上的微積分:Backpropagation](https://simg.open-open.com/show/b109b54701c7e57825e1cf1822eb3e50.png)
在上面的圖中,從 X 到 Y 有三條路徑,從 Y 到 Z 也有三條。如果我們希望計算dZ/dX,那么就要對3 * 3 = 9條路徑進行求和了:
![[譯] 計算圖上的微積分:Backpropagation](https://simg.open-open.com/show/fbd93390b199af0df994b91df5330e02.png)
上面的圖有 9 條路徑,但是在圖變得更加復雜的時候,這個數量會指數級地增長。
相比于粗暴地對所有的路徑進行求和,更好的方式是進行因式分解:
![[譯] 計算圖上的微積分:Backpropagation](https://simg.open-open.com/show/24d3061d67f2a0724544c85f24cead2a.png)
注意了!這里就是 前向微分 和 反向微分 誕生的地方! 這兩個算法是通過因式分解來高效計算導數的。通過在每個幾點上反向合并路徑而非顯式地對所有的路徑求和來大幅提升計算的速度。實際上,兩個算法對每條邊的訪問都只有 一次 !
前向微分從圖的輸入開始,一步一步到達終點。在每個節點處,對輸入的路徑進行求和。每個這樣的路徑都表示輸入影響該節點的一個部分。通過將這些影響加起來,我們就得到了輸入影響該節點的全部,也就是關于輸入的導數。
![[譯] 計算圖上的微積分:Backpropagation](https://simg.open-open.com/show/795f92786c3a09120acf37b3dfe5f117.png)
盡管你可能沒有從圖的結構來考慮這個問題,前向微分其實是在學習了微積分后我們的自然的思維方式。相對的,反向微分是從圖的輸出開始,反向一步一步抵達最開始輸入處。在每個節點處,會合了所有源于該節點的路徑。
![[譯] 計算圖上的微積分:Backpropagation](https://simg.open-open.com/show/978ef4bea6a9a3b0477c43ebe24a23e5.png)
前向微分 跟蹤了輸入如何改變每個節點的情況。反向微分 則跟蹤了每個節點如何影響輸出的情況。也就是說,前向微分應用操作d/dX到每個節點,而反向微分應用操作dZ/d到每個節點。
這其實可以看做是動態規劃( dynamic programming )
計算上的勝利
現在,你可能想知道為何人人都關心 反向微分 了。因為它本身看起來像是用一種奇怪的方式和前向微分做了同樣的事情。這里有什么優點?讓我們重新看看剛開始的例子:
![[譯] 計算圖上的微積分:Backpropagation](https://simg.open-open.com/show/b6e6377551cd4b9c532db7da2fa46a13.jpg)
我們可以從b往上使用前向微分。這樣獲得了每個節點關于b的導數。
![[譯] 計算圖上的微積分:Backpropagation](https://simg.open-open.com/show/268b0adefa29d6cee18b8d702404754d.jpg)
我們已經計算得到了de/db,輸出關于一個輸入b的導數。
如果我們從e往下計算反向微分呢?這會得到e關于每個節點的導數:
![[譯] 計算圖上的微積分:Backpropagation](https://simg.open-open.com/show/d65bc4575703acfffedb14e8b77b7d9c.jpg)
我們說到反向微分給出了e關于每個節點的導數,這里的確是每·一·個節點。我們得到了de/da和de/db,e關于輸入a和b的導數。前向微分給了我們輸出關于某一個輸入的導數,而反向微分則給出了所有的導數。
這幅圖中,僅僅是兩個因子在影響,但是你想象一個擁有百萬個輸入和一個輸出的函數。前向微分需要百萬次遍歷計算圖才能得到最終的導數,而反向微分僅僅需要一次就能得到所有的導數!百萬級的速度提升多么美妙!
訓練神經網絡時,我們將衡量神經網絡表現的代價函數看做是那些決定網絡行為的參數的函數。我們希望計算出代價函數關于所有參數的偏導數,從而進行梯度下降( gradient descent )。現在,常常會遇到百萬甚至千萬級的參數的神經網絡。所以,反向微分,也就是 BP,在神經網絡中發揮了關鍵作用!
(有人要問,有使用前向微分更加合理的場景么?當然!因為反向微分得到一個輸出關于所有輸入的導數,前向微分得到了所有輸出關于一個輸出的導數。如果遇到了一個有多個輸出的函數,前向微分肯定更加快速)
這難道不是 Trivial 的嘛!?
剛剛理解 BP 本質時,我的反應是:“Oh,這不就是鏈式法則么!?為什么人們花了這么久才能夠發現!?” 我也并不是唯一有這種反應的。如果你問問“是不是還有更巧妙的計算前饋神經網絡的導數的方法?”,這個答案并不是很難。
但是我覺得,發明 BP 要比其本身看起來更加困難。你看,在BP被發明的那段時間里,人們并不非常關注前饋神經網絡。并且使用導數來訓練網絡并不是很明顯。在人們發現可以快速計算導數時,這種方法才會進入人們的視野。這里存在著循環依賴的關系。
更糟糕的是,在日常思維中很容易忽略這種循環依賴關系。使用導數來訓練神經網絡?肯定你會困在局部最優解中。更明顯的是,計算這些導數的代價非常大。僅僅因為我們知道這個觀點可行,我們并沒有立即開始研究那些不可能的原因究竟是什么。
這也許就是事后諸葛亮的好處。一旦你已經構建出問題本身,最困難的工作便搞定了。
結論
計算導數遠比你想象的要簡單。這就是這篇文章告訴你的主要觀點。實際上,這些方法是反直覺地簡單,我們人類還是會傻傻地重新發現。在深度學習中,計算導數是相當重要的一件事,同樣在其他領域中也是非常有用的知識。只不過還沒成為一種眾人皆知的事物。
還有其他可以學到的東西么?肯定有。
BP 也是一種理解導數在模型中如何流動的工具。在推斷為何某些模型優化非常困難的過程中,BP 也是特別重要的。典型的例子就是在 Recurrent Neural Network 中理解 vanishing gradient 的原因。
最后,我還要補充的是,這些技術中還有很多算法上的經驗可以借鑒。BP 和 前向微分使用了一對技巧(線性化和動態規劃)來更有效地計算導數。如果你真正理解了這些技術,你就可以有效地計算其他有趣包含導數的表達式。后面的博客也會繼續做介紹。
本文給出了關于 BP 的相對抽象的描述。強烈建議大家閱讀 Michael Nielsen 關于 BP 的講述( chapter 2 ),更加貼合神經網絡本身。
致謝
Thank you to Greg Corrado , Jon Shlens , Samy Bengio and Anelia Angelova for taking the time to proofread this post.
Thanks also to Dario Amodei , Michael Nielsen and Yoshua Bengio for discussion of approaches to explaining backpropagation. Also thanks to all those who tolerated me practicing explaining backpropagation in talks and seminar series!