機械学習

LightGBMでFocal lossを実装してcustom objective functionを理解する

Focal Lossとは?

facebookが開発した損失関数で、分類問題を解く時にイージーサンプル(予測が容易なサンプル)のロスを小さくすることで、ハードサンプルを集中的に学習させることができる損失関数になります。

分類で一般的に使用されるcross entropy lossと非常に似ている損失曲線をしています。

(参照:Focal Loss for Dense Object Detection)

図でFLがFocal Lossを表しており、γはこちらで与えるパラメータになります。γの値が大きいほど、分類が良くできている対象(容易な対象)に対してロスが0に近いことがわかります。

例えば、学習データに大量のイージーサンプルが含まれている場合、Cross entropy lossではイージーサンプルに対してもロスが発生している(図の青線)ので、それらに対してモデルのパラメータが更新されてしまい偏ったモデルができてしまうことが考えられます。

これに対して、容易なサンプルをダウンサンプリング(除外する)必要がありますが、Focal lossでは、自然にイージーサンプルを除外できることが利点になるかと思います。

yaaku

ノイズが含まれるデータに対しては逆に精度が落ちそう…

LightGBMのCustom objective function

勾配ブースティング系のライブラリはデフォルトで様々な損失関数や評価指標が用意されており、タスクに応じて簡単に切り替えることができます。

また、今回紹介したFocal lossのような、公式ではサポートされていない損失関数や、何かしらのオリジナルの損失関数がある場合でも、損失関数の1,2階微分を計算する関数を与えるだけで最適化してくれます。(中で目的関数をテイラー展開により二次の項までの近似しているため)

最初に標準実装されているBinary Cross EntropyをCustom objective で定義してみます。

データは昔参加したSantanderのコンペのデータを使用します。

kaggleのSantanderで銀メダルを取れたので解法と反省点をまとめておく3~4月の期間kaggleでSantanderというコンペに参加し、運よく銀メダルを受賞できました。 公開されている解法と自身の取...

普通に二値分類を行う場合

まずは、普通に二値分類を行います。

こちらをベースラインとして、損失関数を変更していきます。

損失関数とメトリックを自作する場合

上記と同じ結果を期待して、loglossの1階微分(grad)と2階微分(hessian)を計算する関数とメトリックをfobj, fevalに与えます。余談ですが、勾配の計算はこちらのサイトが便利です。

yaaku

最初に出した結果と違う…

勾配ブースティング系のアルゴリズムは前の木の予測値から、誤差を計算して、誤差にフィッティングするように学習が進みますが、学習の最初の木を構築するための初期値の設定が必要になります。

二値分類の場合、標準のライブラリーでは学習データの正例・負例の対数オッズ比で初期値が設定されているので、学習データを作成する際にそのように初期値を与える必要があります。

最初に計算した値と一致しました。

初期値は与え方が悪いと精度はあまり変わらないですが、収束が遅くなるのは確実なので、複雑なカスタムロスを使用して異常に学習が遅い場合は、初期値の与え方を変えてみると良いかもしれません。

Focal Lossによる学習

続いて、Focal lossを損失関数として学習をさせてみます。Focal Lossは下記の式で表されます。

(参照:Focal Loss for Dense Object Detection)

αはCross entropyでも使用される、正例・負例のロスの重みパラメータとなります。(論文ではαを付与させた方が精度が良かったようです)

上記の式をloglossの形に変換すると下記のように書けます。

これを目的関数として、1階微分と2階微分を計算します。微分の計算はscipyのderivativeメソッドを使用します。

微分の値の違いで結果が微妙に変わってしまっていますが、γ=0としてαを使用しなければ、普通のloglossを損失関数とした時と同じとなります。

これをベースラインとしてγを0(Focal loss無し)からγ=0.5まで値を変えた時のロスの変化を確認しました。(今回はαは使用していません)

γが0.1時にベースラインより精度が良くなっており、微妙ですがFocal Lossにより精度が向上していることがわかります。

まとめ

LightGBMでFocal lossを使用してみました。特に初期値を与えないと収束が遅くなることは知らなかったですし、微分を計算するscipyのメソッドも勉強になりました。

今回のケースでは、劇的な精度向上はしませんでしたが、例えばどちらかのクラスのデータが冗長すぎる場合などに適用してみる価値がありそうです。

最後になりましたが、下記ブログで勉強させていただきました。感謝です。m(__)m