http://www.cnblogs.com/pinard/p/6509630.html
在前面我們講到了DNN,以及DNN的特例CNN的模型和前向反向傳播算法,這些算法都是前向反饋的,模型的輸出和模型本身沒有關聯關系。今天我們就討論另一類輸出和模型間有反饋的神經網絡:循環神經網絡(Recurrent Neural Networks ,以下簡稱RNN),它廣泛的用于自然語言處理中的語音識別,手寫書別以及機器翻譯等領域。
1. RNN概述
在前面講到的DNN和CNN中,訓練樣本的輸入和輸出是比較的確定的。但是有一類問題DNN和CNN不好解決,就是訓練樣本輸入是連續的序列,且序列的長短不一,比如基于時間的序列:一段段連續的語音,一段段連續的手寫文字。這些序列比較長,且長度不一,比較難直接的拆分成一個個獨立的樣本來通過DNN/CNN進行訓練。
而對于這類問題,RNN則比較的擅長。那么RNN是怎么做到的呢?RNN假設我們的樣本是基于序列的。比如是從序列索引1到序列索引ττ的。對于這其中的任意序列索引號tt,它對應的輸入是對應的樣本序列中的x(t)x(t)。而模型在序列索引號tt位置的隱藏狀態h(t)h(t),則由x(t)x(t)和在t?1t?1位置的隱藏狀態h(t?1)h(t?1)共同決定。在任意序列索引號tt,我們也有對應的模型預測輸出o(t)o(t)。通過預測輸出o(t)o(t)和訓練序列真實輸出y(t)y(t),以及損失函數L(t)L(t),我們就可以用DNN類似的方法來訓練模型,接著用來預測測試序列中的一些位置的輸出。
下面我們來看看RNN的模型。
2. RNN模型
RNN模型有比較多的變種,這里介紹最主流的RNN模型結構如下:

上圖中左邊是RNN模型沒有按時間展開的圖,如果按時間序列展開,則是上圖中的右邊部分。我們重點觀察右邊部分的圖。
這幅圖描述了在序列索引號tt附近RNN的模型。其中:
1)x(t)x(t)代表在序列索引號tt時訓練樣本的輸入。同樣的,x(t?1)x(t?1)和x(t+1)x(t+1)代表在序列索引號t?1t?1和t+1t+1時訓練樣本的輸入。
2)h(t)h(t)代表在序列索引號tt時模型的隱藏狀態。h(t)h(t)由x(t)x(t)和h(t?1)h(t?1)共同決定。
3)o(t)o(t)代表在序列索引號tt時模型的輸出。o(t)o(t)只由模型當前的隱藏狀態h(t)h(t)決定。
4)L(t)L(t)代表在序列索引號tt時模型的損失函數。
5)y(t)y(t)代表在序列索引號tt時訓練樣本序列的真實輸出。
6)U,W,VU,W,V這三個矩陣是我們的模型的線性關系參數,它在整個RNN網絡中是共享的,這點和DNN很不相同。 也正因為是共享了,它體現了RNN的模型的“循環反饋”的思想。
3. RNN前向傳播算法
有了上面的模型,RNN的前向傳播算法就很容易得到了。
對于任意一個序列索引號tt,我們隱藏狀態h(t)h(t)由x(t)x(t)和h(t?1)h(t?1)得到:
h(t)=σ(z(t))=σ(Ux(t)+Wh(t?1)+b)h(t)=σ(z(t))=σ(Ux(t)+Wh(t?1)+b) 其中σσ為RNN的激活函數,一般為tanhtanh, bb為線性關系的偏倚。
序列索引號tt時模型的輸出o(t)o(t)的表達式比較簡單:
o(t)=Vh(t)+co(t)=Vh(t)+c 在最終在序列索引號tt時我們的預測輸出為:
y^(t)=σ(o(t))y^(t)=σ(o(t)) 通常由于RNN是識別類的分類模型,所以上面這個激活函數一般是softmax。
通過損失函數L(t)L(t),比如對數似然損失函數,我們可以量化模型在當前位置的損失,即y^(t)y^(t)和y(t)y(t)的差距。
4. RNN反向傳播算法推導
有了RNN前向傳播算法的基礎,就容易推導出RNN反向傳播算法的流程了。RNN反向傳播算法的思路和DNN是一樣的,即通過梯度下降法一輪輪的迭代,得到合適的RNN模型參數U,W,V,b,cU,W,V,b,c。由于我們是基于時間反向傳播,所以RNN的反向傳播有時也叫做BPTT(back-PRopagation through time)。當然這里的BPTT和DNN也有很大的不同點,即這里所有的U,W,V,b,cU,W,V,b,c在序列的各個位置是共享的,反向傳播時我們更新的是相同的參數。
為了簡化描述,這里的損失函數我們為對數損失函數,輸出的激活函數為softmax函數,隱藏層的激活函數為tanh函數。
對于RNN,由于我們在序列的每個位置都有損失函數,因此最終的損失LL為:
L=∑t=1τL(t)L=∑t=1τL(t) 其中V,c,V,c,的梯度計算是比較簡單的:
?L?c=∑t=1τ?L(t)?c=∑t=1τ?L(t)?o(t)?o(t)?c=∑t=1τy^(t)?y(t)?L?c=∑t=1τ?L(t)?c=∑t=1τ?L(t)?o(t)?o(t)?c=∑t=1τy^(t)?y(t)?L?V=∑t=1τ?L(t)?V=∑t=1τ?L(t)?o(t)?o(t)?V=∑t=1τ(y^(t)?y(t))(h(t))T?L?V=∑t=1τ?L(t)?V=∑t=1τ?L(t)?o(t)?o(t)?V=∑t=1τ(y^(t)?y(t))(h(t))T 但是W,U,bW,U,b的梯度計算就比較的復雜了。為啥呢?比如我們看看WW在某一序列位置t的梯度損失如下:
?L(t)?W=?L(t)?o(t)?o(t)?h(t)?h(t)?W=(y^(t)?y(t))VT?h(t)?W?L(t)?W=?L(t)?o(t)?o(t)?h(t)?h(t)?W=(y^(t)?y(t))VT?h(t)?W 前面的兩部分部分偏導數都好計算,難點在?h(t)?W?h(t)?W。看似h(t)=σ(Ux(t)+Wh(t?1)+b)h(t)=σ(Ux(t)+Wh(t?1)+b),這樣?h(t)?W?h(t)?W的結果就是激活函數的導數乘以系數h(t?1)h(t?1)轉置。但是問題是:h(t?1)=σ(Ux(t?1)+Wh(t?2)+b)h(t?1)=σ(Ux(t?1)+Wh(t?2)+b),也就是h(t?1)h(t?1)中也含有WW,我們不能簡單的把h(t?1)h(t?1)當做系數,要計算?h(t)?h(t?1)?h(t)?h(t?1)的依賴關系。
也就是說,在反向傳播時,在在某一序列位置t的梯度損失由當前位置的損失和序列索引位置t+1t+1時的梯度損失兩部分共同決定。對于WW在某一序列位置t的梯度損失需要反向傳播一步步的計算。我們定義序列索引tt位置的隱藏狀態的梯度為:
δ(t)=?L(t)?h(t)δ(t)=?L(t)?h(t) 這樣我們可以像DNN一樣從δ(t+1)δ(t+1)遞推δ(t)δ(t) 。
δ(t)=?L(t)?o(t)?o(t)?h(t)+?L(t)?h(t+1)?h(t+1)?h(t)=VT(y^(t)?y(t))+WTδ(t+1)diag(1?(h(t))2)δ(t)=?L(t)?o(t)?o(t)?h(t)+?L(t)?h(t+1)?h(t+1)?h(t)=VT(y^(t)?y(t))+WTδ(t+1)diag(1?(h(t))2) 有了δ(t+1)δ(t+1),計算W,U,bW,U,b就容易了,這里給出W,U,bW,U,b的梯度計算表達式:
?L?W=∑t=1τ?L(t)?W=∑t=1τ?L(t)?h(t)?h(t)?W=∑t=1τdiag(1?(h(t))2)δ(t)(h(t?1))T?L?W=∑t=1τ?L(t)?W=∑t=1τ?L(t)?h(t)?h(t)?W=∑t=1τdiag(1?(h(t))2)δ(t)(h(t?1))T?L?b=∑t=1τ?L(t)?b=∑t=1τ?L(t)?h(t)?h(t)?b=∑t=1τdiag(1?(h(t))2)δ(t)?L?b=∑t=1τ?L(t)?b=∑t=1τ?L(t)?h(t)?h(t)?b=∑t=1τdiag(1?(h(t))2)δ(t)?L?U=∑t=1τ?L(t)?U=∑t=1τ?L(t)?h(t)?h(t)?U=∑t=1τdiag(1?(h(t))2)δ(t)(x(t))T?L?U=∑t=1τ?L(t)?U=∑t=1τ?L(t)?h(t)?h(t)?U=∑t=1τdiag(1?(h(t))2)δ(t)(x(t))T 除了梯度表達式不同,RNN的反向傳播算法和DNN區別不大,因此這里就不再重復總結了。
5. RNN小結
上面總結了通用的RNN模型和前向反向傳播算法。當然,有些RNN模型會有些不同,自然前向反向傳播的公式會有些不一樣,但是原理基本類似。
RNN雖然理論上可以很漂亮的解決序列數據的訓練,但是它也像DNN一樣有梯度消失時的問題,當序列很長的時候問題尤其嚴重。因此,上面的RNN模型一般不能直接用于應用領域。在語音識別,手寫書別以及機器翻譯等NLP領域實際應用比較廣泛的是基于RNN模型的一個特例LSTM,下一篇我們就來討論LSTM模型。
(歡迎轉載,轉載請注明出處。歡迎溝通交流: pinard.liu@eriCSSon.com)
參考資料:
1) Neural Networks and Deep Learning by By Michael Nielsen
2) Deep Learning, book by Ian Goodfellow, Yoshua Bengio, and Aaron Courville
3) UFLDL Tutorial
4)CS231n Convolutional Neural Networks for Visual Recognition, Stanford
分類: 0082. 深度學習