神經網絡的 BP 算法學習
首先用圖來表達函數,對于以下的函數:
令:
則相應的圖表示為:
圖-1
上圖中,圓形框是函數表達式,邊表示圓形框的輸入和輸出來源。將圖r節點的輸出記為r, 則r節點的輸入為p節點和q節點的輸出,記為 p,q。
那么如何求z(z節點的輸出)關于圖中任意節點輸出的梯度呢?
圖-2
對于任意的x節點,找到所有與x節點直接關聯的節點p,q,x 只能通過p,q去影響z。那么z和x的關系可以表示如下z=z(p,q),p=p(x),q=q(x),根據復合函數求導法則:
這就是BP算法的核心公式。
在實際使用中,有前向和后向兩個流程,以圖-1為例:
在前向流程中,依次計算:
在后向計算過程中,計算過程如下:
可見當某一節點的后續節點都走完回傳過程后,那么該節點的梯度就計算出來了,然后輪到該節點,直接做回傳就可以了。
在BP過程中,最重要的事情就是維護梯度回傳的先后順序。在TensorFlow中,每個tensor就表示了一個函數節點,tensor之間的運算,例如mat、add這些都會生成新的tensor,這種關系最終都在記在graph中,在梯度求導的時候,根據圖中節點的順序,從后至前計算梯度,只要確保:回傳某節點的梯度時,該節點的后續依賴節點已經回傳完畢就行了。
需要注意的幾個地方:
1)梯度的計算可能引起歧義,如下:
根據上面的求導法則
假設:
p=x*x,
q=p*x,
z=p+q,
顯然z=x*x+(x*x)*x,則?z/?x實際為2*x+3*x*x。
在這個公式中(?q/?x)為q對x_2這個輸入的導數,x_2的值正好為x,即(?q/(?(x_2)))=p=x*x。(?z/?p)是z關于p節點輸出的導數,即(?z/?p=?z/?z*?z/(?(p_1))+?z/?q ?q/(?(p_2)))=(1*1+1*x)=(1+x),則(?z/?x)=(1+x)*(2*x)+(1)*(x*x)=2*x+3*x*x。所以計算導數的時候一定要分清楚,是要對某節點的輸入求導還是對某節點的輸出求導。
為方便起見,在loss節點后增加一個節點C=loss,如下:
記(?C/?w)表示C關于w節點輸出的導數,這是回傳算法要求的。其他的,例如:(?q/?x)表示q關于x到q的輸入端上的導數,這在建立節點的時候前向的時候就能算出來。
則上圖的求導過程如下:
2)對于某些RNN,unroll之后,可能具有以下的形式,其求導過程如下:
3)上圖中所有的節點是以標量為例的,實際上也可以為向量,特別注意下面兩種向量的求導形式:
(a)矩陣形式
(b)點乘形式
有了上面的基礎后,就可以看看幾種常見的神經網絡的推導
1)普通的神經網絡
2)LSTM的BPTT推導,一個典型的LSTM示意圖如下:
將其中的函數塊都變為節點,一個典型的seq2seq網絡就如下圖所示:
依賴關系如下:
lstm_cell[t-1]->lstm_cell[t]->logits[t]->prob[t]
prob[0],…,prob[T-1]->loss
基本的前向過程如下:
對任意t:
幾個控制門的輸出:
c_hat=W_c x_(t-1)+U_c h_(t-1)+b_c
c=tanh(c_hat)
i_hat=W_i x_(t-1)+U_i h_(t-1)+b_i
i=sigmoid(i_hat)
f_hat=W_f x_(t-1)+U_f h_(t-1)+b_f
f=sigmoid(f_hat)
o_hat=W_o x_(t-1)+U_o h_(t-1)+b_o
o=sigmoid(o_hat)
狀態S的更新:
S[t]=c⊙i+S[t-1]⊙f
輸出H的計算:
H[t]=o⊙tan?h(S[t-1])
未歸一化的概率計算:
logits[t]=W_p H[t]+b_p
歸一化的概率計算:
prob[t]=softmax(logits[t])
損失函數的計算:
loss+=Y_t⊙(-log?(prob[t])
按照之前BP的法則,計算loss關于各個節點輸出的梯度,如下:
其余節點的計算都很簡單,就不列出計算公式了。BP的時候,需要從T步向第1步計算。具體的代碼參見:
http://git.oschina.net/bamuyy/lstm/blob/master/lstm.py?dir=0&filepath=lstm.py&oid=7dfd64de7ca2ab53ae79008ac9bbc054dcba2f5d&sha=5fd268b98baf62fc4ec3ebd3cadcc130f85894e2
來自:http://mp.weixin.qq.com/s?__biz=MjM5ODIzNDQ3Mw==&mid=2649966184&idx=1&sn=b3347f1a508ccbf799c3bf124d670131&chksm=beca366e89bdbf78fe1bf48b91c6b48b68f988a67fa99dac694748591060c9484543713aedec&scene=0