神經網絡的 BP 算法學習

lrm928 8年前發布 | 37K 次閱讀 神經網絡

首先用圖來表達函數,對于以下的函數:

令:

則相應的圖表示為:

圖-1

上圖中,圓形框是函數表達式,邊表示圓形框的輸入和輸出來源。將圖r節點的輸出記為r, 則r節點的輸入為p節點和q節點的輸出,記為 p,q。

那么如何求z(z節點的輸出)關于圖中任意節點輸出的梯度呢?

圖-2

對于任意的x節點,找到所有與x節點直接關聯的節點p,q,x 只能通過p,q去影響z。那么z和x的關系可以表示如下z=z(p,q),p=p(x),q=q(x),根據復合函數求導法則:

這就是BP算法的核心公式。

在實際使用中,有前向和后向兩個流程,以圖-1為例:

在前向流程中,依次計算:

在后向計算過程中,計算過程如下:

可見當某一節點的后續節點都走完回傳過程后,那么該節點的梯度就計算出來了,然后輪到該節點,直接做回傳就可以了。

在BP過程中,最重要的事情就是維護梯度回傳的先后順序。在TensorFlow中,每個tensor就表示了一個函數節點,tensor之間的運算,例如mat、add這些都會生成新的tensor,這種關系最終都在記在graph中,在梯度求導的時候,根據圖中節點的順序,從后至前計算梯度,只要確保:回傳某節點的梯度時,該節點的后續依賴節點已經回傳完畢就行了。

需要注意的幾個地方:

1)梯度的計算可能引起歧義,如下:

根據上面的求導法則

假設:

p=x*x,

q=p*x,

z=p+q,

顯然z=x*x+(x*x)*x,則?z/?x實際為2*x+3*x*x。

在這個公式中(?q/?x)為q對x_2這個輸入的導數,x_2的值正好為x,即(?q/(?(x_2)))=p=x*x。(?z/?p)是z關于p節點輸出的導數,即(?z/?p=?z/?z*?z/(?(p_1))+?z/?q  ?q/(?(p_2)))=(1*1+1*x)=(1+x),則(?z/?x)=(1+x)*(2*x)+(1)*(x*x)=2*x+3*x*x。所以計算導數的時候一定要分清楚,是要對某節點的輸入求導還是對某節點的輸出求導。

為方便起見,在loss節點后增加一個節點C=loss,如下:

記(?C/?w)表示C關于w節點輸出的導數,這是回傳算法要求的。其他的,例如:(?q/?x)表示q關于x到q的輸入端上的導數,這在建立節點的時候前向的時候就能算出來。

則上圖的求導過程如下:

2)對于某些RNN,unroll之后,可能具有以下的形式,其求導過程如下:

3)上圖中所有的節點是以標量為例的,實際上也可以為向量,特別注意下面兩種向量的求導形式:

(a)矩陣形式

(b)點乘形式


有了上面的基礎后,就可以看看幾種常見的神經網絡的推導

1)普通的神經網絡

2)LSTM的BPTT推導,一個典型的LSTM示意圖如下:

將其中的函數塊都變為節點,一個典型的seq2seq網絡就如下圖所示:

依賴關系如下:

lstm_cell[t-1]->lstm_cell[t]->logits[t]->prob[t]

prob[0],…,prob[T-1]->loss

基本的前向過程如下:

對任意t:

幾個控制門的輸出:

c_hat=W_c x_(t-1)+U_c h_(t-1)+b_c

c=tanh(c_hat)

i_hat=W_i x_(t-1)+U_i h_(t-1)+b_i

i=sigmoid(i_hat)

f_hat=W_f x_(t-1)+U_f h_(t-1)+b_f

f=sigmoid(f_hat)

o_hat=W_o x_(t-1)+U_o h_(t-1)+b_o

o=sigmoid(o_hat)

狀態S的更新:

S[t]=c⊙i+S[t-1]⊙f

輸出H的計算:

H[t]=o⊙tan?h(S[t-1])

未歸一化的概率計算:

logits[t]=W_p H[t]+b_p

歸一化的概率計算:

prob[t]=softmax(logits[t])

損失函數的計算:

loss+=Y_t⊙(-log?(prob[t])

按照之前BP的法則,計算loss關于各個節點輸出的梯度,如下:

其余節點的計算都很簡單,就不列出計算公式了。BP的時候,需要從T步向第1步計算。具體的代碼參見:

http://git.oschina.net/bamuyy/lstm/blob/master/lstm.py?dir=0&filepath=lstm.py&oid=7dfd64de7ca2ab53ae79008ac9bbc054dcba2f5d&sha=5fd268b98baf62fc4ec3ebd3cadcc130f85894e2

 

 

來自:http://mp.weixin.qq.com/s?__biz=MjM5ODIzNDQ3Mw==&mid=2649966184&idx=1&sn=b3347f1a508ccbf799c3bf124d670131&chksm=beca366e89bdbf78fe1bf48b91c6b48b68f988a67fa99dac694748591060c9484543713aedec&scene=0

 

 本文由用戶 lrm928 自行上傳分享,僅供網友學習交流。所有權歸原作者,若您的權利被侵害,請聯系管理員。
 轉載本站原創文章,請注明出處,并保留原始鏈接、圖片水印。
 本站是一個以用戶分享為主的開源技術平臺,歡迎各類分享!