標籤:更新 決定 簡化 理解 div 理論 強制 統計量 為什麼
轉自:https://www.cnblogs.com/guoyaohua/p/8724433.html
郭耀華‘s Blog欲窮千裡目,更上一層樓
項目首頁:https://github.com/guoyaohua/
欲窮千裡目,更上一層樓
項目首頁:https://github.com/guoyaohua/
【深度學習】深入理解Batch Normalization批標準化
這幾天面試經常被問到BN層的原理,雖然回答上來了,但還是感覺答得不是很好,今天仔細研究了一下Batch Normalization的原理,以下為參考網上幾篇文章總結得出。
Batch Normalization作為最近一年來DL的重要成果,已經廣泛被證明其有效性和重要性。雖然有些細節處理還解釋不清其理論原因,但是實踐證明好用才是真的好,別忘了DL從Hinton對深層網路做Pre-Train開始就是一個經驗領先於理論分析的偏經驗的一門學問。本文是對論文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》的導讀。
機器學習領域有個很重要的假設:IID獨立同分布假設,就是假設訓練資料和測試資料是滿足相同分布的,這是通過訓練資料獲得的模型能夠在測試集獲得好的效果的一個基本保障。那BatchNorm的作用是什麼呢?BatchNorm就是在深度神經網路訓練過程中使得每一層神經網路的輸入保持相同分布的。
接下來一步一步的理解什麼是BN。
為什麼深度神經網路隨著網路深度加深,訓練起來越困難,收斂越來越慢?這是個在DL領域很接近本質的好問題。很多論文都是解決這個問題的,比如ReLU啟用函數,再比如Residual Network,BN本質上也是解釋並從某個不同的角度來解決這個問題的。
一、“Internal Covariate Shift”問題
從論文名字可以看出,BN是用來解決“Internal Covariate Shift”問題的,那麼首先得理解什麼是“Internal Covariate Shift”?
論文首先說明Mini-Batch SGD相對於One Example SGD的兩個優勢:梯度更新方向更準確;並行計算速度快;(為什麼要說這些?因為BatchNorm是基於Mini-Batch SGD的,所以先誇下Mini-Batch SGD,當然也是大實話);然後吐槽下SGD訓練的缺點:超參數調起來很麻煩。(作者隱含意思是用BN就能解決很多SGD的缺點)
接著引入covariate shift的概念:如果ML系統執行個體集合<X,Y>中的輸入值X的分布老是變,這不符合IID假設,網路模型很難穩定的學規律,這不得引入遷移學習才能搞定嗎,我們的ML系統還得去學習怎麼迎合這種分布變化啊。對於深度學習這種包含很多隱層的網路結構,在訓練過程中,因為各層參數不停在變化,所以每個隱層都會面臨covariate shift的問題,也就是在訓練過程中,隱層的輸入分布老是變來變去,這就是所謂的“Internal Covariate Shift”,Internal指的是深層網路的隱層,是發生在網路內部的事情,而不是covariate shift問題只發生在輸入層。
然後提出了BatchNorm的基本思想:能不能讓每個隱層節點的啟用輸入分布固定下來呢?這樣就避免了“Internal Covariate Shift”問題了。
BN不是憑空拍腦袋拍出來的好點子,它是有啟發來源的:之前的研究表明如果在影像處理中對輸入映像進行白化(Whiten)操作的話——所謂白化,就是對輸入資料分布變換到0均值,單位方差的常態分佈——那麼神經網路會較快收斂,那麼BN作者就開始推論了:映像是深度神經網路的輸入層,做白化能加快收斂,那麼其實對於深度網路來說,其中某個隱層的神經元是下一層的輸入,意思是其實深度神經網路的每一個隱層都是輸入層,不過是相對下一層來說而已,那麼能不能對每個隱層都做白化呢?這就是啟發BN產生的原初想法,而BN也確實就是這麼做的,可以理解為對深層神經網路每個隱層神經元的啟用值做簡化版本的白化操作。
二、BatchNorm的本質思想
BN的基本思想其實相當直觀:因為深層神經網路在做非線性變換前的啟用輸入值(就是那個x=WU+B,U是輸入)隨著網路深度加深或者在訓練過程中,其分布逐漸發生位移或者變動,之所以訓練收斂慢,一般是整體分布逐漸往非線性函數的取值區間的上下限兩端靠近(對於Sigmoid函數來說,意味著啟用輸入值WU+B是大的負值或正值),所以這導致反向傳播時低層神經網路的梯度消失,這是訓練深層神經網路收斂越來越慢的本質原因,而BN就是通過一定的正常化手段,把每層神經網路任意神經元這個輸入值的分布強行拉回到均值為0方差為1的標準常態分佈,其實就是把越來越偏的分布強制拉回比較標準的分布,這樣使得啟用輸入值落在非線性函數對輸入比較敏感的地區,這樣輸入的小變化就會導致損失函數較大的變化,意思是這樣讓梯度變大,避免梯度消失問題產生,而且梯度變大意味著學習收斂速度快,能大大加快訓練速度。
THAT’S IT。其實一句話就是:對於每個隱層神經元,把逐漸向非線性函數映射後向取值區間極限飽和區靠攏的輸入分布強制拉回到均值為0方差為1的比較標準的常態分佈,使得非線性變換函數的輸入值落入對輸入比較敏感的地區,以此避免梯度消失問題。因為梯度一直都能保持比較大的狀態,所以很明顯對神經網路的參數調整效率比較高,就是變動大,就是說向損失函數最優值邁動的步子大,也就是說收斂地快。BN說到底就是這麼個機制,方法很簡單,道理很深刻。
上面說得還是顯得抽象,下面更形象地表達下這種調整到底代表什麼含義。
圖1 幾個常態分佈
假設某個隱層神經元原先的啟用輸入x取值符合常態分佈,常態分佈均值是-2,方差是0.5,對應中最左端的淺藍色曲線,通過BN後轉換為均值為0,方差是1的常態分佈(對應中的深藍色圖形),意味著什麼,意味著輸入x的取值常態分佈整體右移2(均值的變化),圖形曲線更平緩了(方差增大的變化)。這個圖的意思是,BN其實就是把每個隱層神經元的啟用輸入分布從偏離均值為0方差為1的常態分佈通過平移均值壓縮或者擴大麴線尖銳程度,調整為均值為0方差為1的常態分佈。
那麼把啟用輸入x調整到這個常態分佈有什麼用?首先我們看下均值為0,方差為1的標準常態分佈代表什麼含義:
圖2 均值為0方差為1的標準常態分佈圖
這意味著在一個標準差範圍內,也就是說64%的機率x其值落在[-1,1]的範圍內,在兩個標準差範圍內,也就是說95%的機率x其值落在了[-2,2]的範圍內。那麼這又意味著什嗎?我們知道,啟用值x=WU+B,U是真正的輸入,x是某個神經元的啟用值,假設非線性函數是sigmoid,那麼看下sigmoid(x)其圖形:
圖3. Sigmoid(x)
及sigmoid(x)的導數為:G’=f(x)*(1-f(x)),因為f(x)=sigmoid(x)在0到1之間,所以G’在0到0.25之間,其對應的圖如下:
圖4 Sigmoid(x)導數圖
假設沒有經過BN調整前x的原先常態分佈均值是-6,方差是1,那麼意味著95%的值落在了[-8,-4]之間,那麼對應的Sigmoid(x)函數的值明顯接近於0,這是典型的梯度飽和區,在這個地區裡梯度變化很慢,為什麼是梯度飽和區?請看下sigmoid(x)如果取值接近0或者接近於1的時候對應導數函數取值,接近於0,意味著梯度變化很小甚至消失。而假設經過BN後,均值是0,方差是1,那麼意味著95%的x值落在了[-2,2]區間內,很明顯這一段是sigmoid(x)函數接近於線性變換的地區,意味著x的小變化會導致非線性函數值較大的變化,也即是梯度變化較大,對應導數函數圖中明顯大於0的地區,就是梯度非飽和區。
從上面幾個圖應該看出來BN在幹什麼了吧?其實就是把隱層神經元啟用輸入x=WU+B從變化不拘一格的常態分佈通過BN操作拉回到了均值為0,方差為1的常態分佈,即原始常態分佈中心左移或者右移到以0為均值,展開或者縮減形態形成以1為方差的圖形。什麼意思?就是說經過BN後,目前大部分Activation的值落入非線性函數的線性區內,其對應的導數遠離導數飽和區,這樣來加速訓練收斂過程。
但是很明顯,看到這裡,稍微瞭解神經網路的讀者一般會提出一個疑問:如果都通過BN,那麼不就跟把非線性函數替換成線性函數效果相同了?這意味著什嗎?我們知道,如果是多層的線性函數變換其實這個深層是沒有意義的,因為多層線性網路跟一層線性網路是等價的。這意味著網路的表達能力下降了,這也意味著深度的意義就沒有了。所以BN為了保證非線性獲得,對變換後的滿足均值為0方差為1的x又進行了scale加上shift操作(y=scale*x+shift),每個神經元增加了兩個參數scale和shift參數,這兩個參數是通過訓練學習到的,意思是通過scale和shift把這個值從標準常態分佈左移或者右移一點並長胖一點或者變瘦一點,每個執行個體挪動的程度不一樣,這樣等價於非線性函數的值從正中心周圍的線性區往非線性區動了動。核心思想應該是想找到一個線性和非線性較好平衡點,既能享受非線性較強表達能力的好處,又避免太靠非線性區兩頭使得網路收斂速度太慢。當然,這是我的理解,論文作者並未明確這樣說。但是很明顯這裡的scale和shift操作是會有爭議的,因為按照論文作者論文裡寫的理想狀態,就會又通過scale和shift操作把變換後的x調整回未變換的狀態,那不是饒了一圈又繞回去原始的“Internal Covariate Shift”問題裡去了嗎,感覺論文作者並未能夠清楚地解釋scale和shift操作的理論原因。
三、訓練階段如何做BatchNorm
上面是對BN的抽象分析和解釋,具體在Mini-Batch SGD下做BN怎麼做?其實論文裡面這塊寫得很清楚也容易理解。為了保證這篇文章完整性,這裡簡單說明下。
假設對於一個深層神經網路來說,其中兩層結構如下:
圖5 DNN其中兩層
要對每個隱層神經元的啟用值做BN,可以想象成每個隱層又加上了一層BN操作層,它位於X=WU+B啟用值獲得之後,非線性函數變換之前,其圖示如下:
圖6. BN操作
對於Mini-Batch SGD來說,一次訓練過程裡麵包含m個訓練執行個體,其具體BN操作就是對於隱層內每個神經元的啟用值來說,進行如下變換:
要注意,這裡t層某個神經元的x(k)不是指原始輸入,就是說不是t-1層每個神經元的輸出,而是t層這個神經元的線性啟用x=WU+B,這裡的U才是t-1層神經元的輸出。變換的意思是:某個神經元對應的原始的啟用x通過減去mini-Batch內m個執行個體獲得的m個啟用x求得的均值E(x)併除以求得的方差Var(x)來進行轉換。
上文說過經過這個變換後某個神經元的啟用x形成了均值為0,方差為1的常態分佈,目的是把值往後續要進行的非線性變換的線性區拉動,增大導數值,增強反向傳播資訊流動性,加快訓練收斂速度。但是這樣會導致網路表達能力下降,為了防止這一點,每個神經元增加兩個調節參數(scale和shift),這兩個參數是通過訓練來學習到的,用來對變換後的啟用反變換,使得網路表達能力增強,即對變換後的啟用進行如下的scale和shift操作,這其實是變換的反操作:
BN其具體操作流程,如論文中描述的一樣:
過程非常清楚,就是上述公式的流程化描述,這裡不解釋了,直接應該能看懂。
四、BatchNorm的推理(Inference)過程
BN在訓練的時候可以根據Mini-Batch裡的若干訓練執行個體進行啟用數值調整,但是在推理(inference)的過程中,很明顯輸入就只有一個執行個體,看不到Mini-Batch其它執行個體,那麼這時候怎麼對輸入做BN呢?因為很明顯一個執行個體是沒法算執行個體集合求出的均值和方差的。這可如何是好?
既然沒有從Mini-Batch資料裡可以得到的統計量,那就想其它辦法來獲得這個統計量,就是均值和方差。可以用從所有訓練執行個體中獲得的統計量來代替Mini-Batch裡面m個訓練執行個體獲得的均值和方差統計量,因為本來就打算用全域的統計量,只是因為計算量等太大所以才會用Mini-Batch這種簡化方式的,那麼在推理的時候直接用全域統計量即可。
決定了獲得統計量的資料範圍,那麼接下來的問題是如何獲得均值和方差的問題。很簡單,因為每次做Mini-Batch訓練時,都會有那個Mini-Batch裡m個訓練執行個體獲得的均值和方差,現在要全域統計量,只要把每個Mini-Batch的均值和方差統計量記住,然後對這些均值和方差求其對應的數學期望即可得出全域統計量,即:
有了均值和方差,每個隱層神經元也已經有對應訓練好的Scaling參數和Shift參數,就可以在推導的時候對每個神經元的啟用資料計算NB進行變換了,在推理過程中進行BN採取如下方式:
這個公式其實和訓練時
是等價的,通過簡單的合并計算推導就可以得出這個結論。那麼為啥要寫成這個變換形式呢?我猜作者這麼寫的意思是:在實際啟動並執行時候,按照這種變體形式可以減少計算量,為啥呢?因為對於每個隱層節點來說:
都是固定值,這樣這兩個值可以事先算好存起來,在推理的時候直接用就行了,這樣比原始的公式每一步驟都現算少了除法的運算過程,乍一看也沒少多少計算量,但是如果隱層節點個數多的話節省的計算量就比較多了。
五、BatchNorm的好處
BatchNorm為什麼NB呢,關鍵還是效果好。①不僅僅極大提升了訓練速度,收斂過程大大加快;②還能增加分類效果,一種解釋是這是類似於Dropout的一種防止過擬合的正則化表達方式,所以不用Dropout也能達到相當的效果;③另外調參過程也簡單多了,對於初始化要求沒那麼高,而且可以使用大的學習率等。總而言之,經過這麼簡單的變換,帶來的好處多得很,這也是為何現在BN這麼快流行起來的原因。
[轉] 深入理解Batch Normalization批標準化