ディープラーニング

Batch Normalization を理解する

Batch Normalizationとは?

Batch Normalizationは、各バッチのデータを使い、ノード毎に各次元を正規化することで、学習を効率的にする手法となります。大きな利点としては、学習率を高く設定しても大丈夫(パラメータのスケールに依存しなくなる)、正則化の効果があるということで、幅広く使用されています。

まずは、何故ノード毎の正規化が必要なのかということに関してです。

こちらの論文の説明から、例として、何かしらの関数F1, F2から計算できる損失関数を考えます。この時の各関数のパラメータθ1とθ2を損失関数が小さくなるように最適化するのが今やりたいことだとします。

F2の関数はF1の出力を入力としているので、F1=xとして考えると、F2のパラメータθ2は勾配降下法から、下記のようにパラメータが更新されます。

                                         (α:学習率 m:バッチサイズ)

更新幅は、F1の出力(=x)を入力とした時のF2の誤差の微分から計算されるため、当然F1の分布がF2の計算に影響を及ぼします。この時、F1(=x)の各次元のスケールが合っていないと勾配降下法を使う場合、中々収束しない(=学習が非効率になる)問題が起きる可能性があります。

同じようにニューラルネットワークでも、入力の各次元のスケールが合っていないと学習が非効率になるため、入力する前に何かしらの正規化をするのが一般的です。

しかし、最初の入力に対しては問題なくても、層が多いニューラルネットワークになると各層の出力の次元のスケールが合っている保証はなく、深ければ深いほど後の層の学習が非効率になることが考えられます。それは論文で内部共変量シフトと呼ばれています。

この各層の出力の分布が変わる内部共変量シフトを減らす取り組みがBatch Normalizationとなります。

どのように正規化するのか?

元々、Batch Normalizationが使われるまでの研究では、全ての入力次元の値を使用して正規化を行う手法が考えられていたようですが、計算コストが高い(次元間の共分散行列を計算するのが手間らしい)ためBatch Normalizationでは、各入力の次元毎に平均を0、分散を1にする正規化を提案しています。

計算は単純ですが、各ミニバッチの中で平均と分散が計算されます。また、正規化した後に平均と分散を調整する役割を持つパラメータγとβが使用されているのがポイントかと思われます。(γとβも学習パラメータとなる)

γとβの気持ちとしては、必ずしも平均0分散1の入力分布が良いわけではなく、これらのパラメータで平均と分散を調整できること、γ=分散、β=平均とすれば変数によっては正規化無しで学習できるので表現力向上が期待できるということだと思われます。

また、BN(Batch Normalization)は活性化関数を通す前に適用することが推奨されており、パラメータβがあるのでバイアス項bは不要となります。

(W:パラメータ, u:入力, b:バイアス)

推論時の計算方法について

推論時は学習時とは違い、バッチ毎の平均ではなく、学習データのバッチ毎の平均と分散の期待値を使用して、各次元の学習済みパラメータγ、βを用いて変換を行っています。

なるほど…と思いましたが、この方法で行う場合、各ミニバッチの平均がバラバラだと(各ミニバッチに入るサンプルが偏っていしまっている場合)、直感的にはうまくいかなさそうにも思えます。

得られる効用

1つ目が学習率を大きく設定しても、安定した学習ができるということが言われています。

一般的に、ディープラーニングでは学習率が大きすぎると勾配爆発or消失が起きる可能性があります。しかし、Batch Normalizationでは、正規化されるためパラメータの大きさ(定数倍a)によって順伝搬の出力が変わりません。

さらに、逆伝搬の計算を考えると下記のようにパラメータのスケールが大きいほど、更新幅が小さくなるので安定する方向になるということでした。

2つ目に正則化の効果があるということです。こちらについては、理論的な説明というよりは、実験からの見解のようですが、各ミニバッチ毎に正規化を行うので、同じレコードでもミニバッチが違うと結果が変わることが、過学習を抑える効果に繋がっているという見解でした。

実験結果

論文では、3層のシンプルなMLPでMNISTを用いて10クラス分類をする実験を行っています。

単純にBNの有り無しでの比較から、BNを適用している場合は、最初から収束性が良いことが(a)の図から確認できます。

(BN適用有無の結果の比較-論文より引用)

また、BNの狙いである各層における入力が正規化されて安定した学習ができていることを示すために、3層目の入力の分位数(15, 50, 85)が学習を進めていく中でどのように変化しているかを示したのが(b), (c)の図になります。

(b)から、BN無しだと学習初期は入力が安定しておらず、学習が非効率になっているとわかる一方で、BNを適用した場合は安定しており、狙い通りになっていることがわかります。

その他、速度・精度の面からBNの効果を加速させる方法としては、学習率を上げること、Dropoutは抜いても過学習せずスピードアップすること、ミニバッチ内の学習データが同質のものにならないようにすることなどが挙げられていました。

特に、最後のミニバッチ内での学習データが同じようなものだと、過学習してしまうのは直感的にもイメージできますね。個人的には、パラメータのγとβも同質のデータだと、各出力によって更新させる方向が大きく変わってしまう可能性があり、学習が非効率になるような気がします。

所感

何気なく使用しているBatch Normalizationですが、奥が深いですね…。特に、ミニバッチ内に同質のサンプルを与えない方が良いという理由がちょっと理解できた気がします。