kaggle

深層距離学習(Deep metric learning)を理解する

Deep metric learningとは?

多クラス分類を予測するニューラルネットワークを学習した後、各入力データに対する中間表現を比較することで、入力データ同士の類似度(ユークリッド距離やcosine類似度)を評価する手法で、日本語では深層距離学習と呼ばれています。

応用例として顔認識などがあり、人をラベルとしたデータからモデルを作成して、各レコードの最終層の前の層の出力を中間表現として比較することができます(下図のfeature descriptor)。

L2-constrained Softmax Loss for Discriminative Face Verification

しかし、単純にモデルを作成しても、中間表現がクラスにより異なるように学習しているわけではないので、改善の余地があります。上記の論文では、feature descriptorのL2ノルムが固定長(学習パラメータまたは任意の値)になるように正規化して学習することで精度を向上させています。

正規化することで何故よくなるのかというと、普通に学習を行う場合easy sampleに対して重みが最適化されてしまい、hard sampleの精度が極端に悪くなることが実験的にわかっている(easy sampleのfeature descriptorのL2ノルムが大きくなるように学習が進み、hard sampleは無視してしまう)ようで、L2ノルムに制約を与えることでバランスよく学習を行いたい気持ちのようです。

L2-constrained Softmax Loss for Discriminative Face Verification

深層距離学習に関しては、kaggleのShopee – Price Match Guaranteeというコンペで使用されていました。

テーブルデータに対する深層距離学習の適用は、(私の知ってる範囲では)あまり見られませんが、katsuさんのmetric learningの記事が参考になりました(ココで初めて距離学習の存在を知りました)

距離学習の応用としては、予測精度を上げたい対象に類似した学習データに重みを付けしたり、testに近い対象をvalidationとして学習を行うなど色々と考えられます。

ArcFaceを用いた深層距離学習

最初に冒頭で書いたsoftmax lossを使用した学習では、2クラス分類のタスクを考えた時、各クラス確率と境界条件は下記のようになります。

これに対して、より境界付近の対象をクラス中心に近づけるように学習を行うのが
ArcFace(Additive Angular Margin Loss for Deep Face Recognition)を使用した学習で、kaggleのコンペではこちらが使われています(他にもCosFaceやSphereFaceなどがありますが、精度的にはArcFaceが強いようです)。

ArcFaceでは、入力xとそれに対する重みwを両方正規化してバイアス項を無視することで普通の全結合層であるx・wの積をcos類似度と考えます。この時、正解クラスに対しては、角度θにマージンmを加えることで、これまで境界付近にあった対象にもロスが計算されるため、クラス中心に集まるように学習が進みます。

ArcFace: Additive Angular Margin Loss for Deep Face Recognition

上記の図は、各手法別の境界を点線でマージンをグレーで表していますが、ArcFaceではクラス1に対する境界条件(クラス1の確率=クラス2の確率)は下記のように書けます

θは入力と各クラスに対応する重みの角度となっており、等号が成り立つためにはマージン分クラス1に対応する重みと入力の角度を小さくする必要があり、上図の一番右の図のようにクラス1と判定される領域がθ1の軸方向にm小さくなるため、クラス内の分散が小さくなり、クラス間の乖離が大きくなるように学習が進みます。

論文では、logitをスケールさせるsが掛けられていますが、これは、マージンが大きいとcos(θ+m)の値が小さくなり、softmax関数を通した時の値も小さくなることを防ぐために使われているようです。

実際に使ってみる

今回はNumeraiのデータを使って、単純なL2ノルムで正規化した場合とArcfaceを使用した場合で比較を行います。arcfaceの実装はこちらにあります。

Numeraiのデータは、学習データが50万件程、特徴量が310個あり、時系列データで月毎にラベルがついています(era1~212)。今回は、eraをターゲットとして分類を行い中間表現を得ます。

評価方法ですが、評価を行うテストデータのeraの一部が分類モデルの学習データに含まれてしまうと、学習時に評価対象のテストeraを離すように学習してしまう(leakageとなるので)、学習とテストでeraを分ける必要があるかと思います。

今回は、era120までのデータを学習データとして、8:2に分割してモデルを作り、その後学習に使用していないera120~212のデータから中間表現を得て乖離具合を可視化してみます。

まずは、L2-constrained softmax lossを用いた場合です。ネットワークは下記のように実装しています。

論文からL2ノルムの長さを決めるパラメータαが16だったので同じように設定しています。推論時はextract前までの出力を中間表現として使用します。その他のパラメータは適当に設定。

学習後、評価データであるera121~211の各レコードに対して中間表現を計算して、rapidsのUMAPを用いて可視化を行った結果が下記です。

紫色→黄色の並びで期間が後になっています。狙いはera毎の特徴を捉えて欲しいので、色が異なる対象がより分離していると良い中間表現と言えます。

金融データでノイズが多いことを考えると思ったより分けられているように見えます。

余談ですが、UMAPを使用する場合は、rapidsが爆速なのでお勧めです。(5秒くらいで可視化できる)

続いて、ArcFaceを用いた場合です。実装は下記のようにしています。

パラメータとしてマージンmとスケールsがあるので、これらを変えたパターンを可視化しています。まずはマージンの違いです。

yaaku

微妙だ…

マージンを上げるとhard sampleに対してlossが上がるため、より重点的に学習が進みますが、分類がほぼ不可な難しい対象も多く含まれると考えられるため、乖離ができない中間表現になっていることがわかります(一番下の図)。

続いて、sを変えた場合です。

こちらもsを上げた場合に乖離が上手くできておらず、マージンと同様に難しい対象が多いのでそれらのロスを上げすぎると学習が進まないようです。これは、sを大きくするとsoftmax関数の分子より、出力値が指数関数的に大きくなります。つまり正解クラスとの出力の差が大きくなり、分類が無理な対象がある場合、それらのロスが下がらないので学習が進まなくなってしまうのではないかと考えています。(正しいのかは保証できません)

テストデータの中間表現

Numeraiをやられている方にしか伝わらないかもしれませんが、上記の評価データはera121~212でtargetが既にわかっているデータとなります。

しかし、予測精度を上げたいのはtarget が分かっていない対象(era212以降)となるので、乖離が良くできていそうなL2-constrained softmax lossを使用して中間表現を計算してみます。

データが重いのでera毎に3割ランダムでサンプリングして図示すると下記のようになります。

特徴量には株価の移動平均なども含まれているので、当然かもしれませんが期間毎に分布が集中していることがわかります。また、ターゲットが既知のデータ(era212以下)の中には、212以降のデータと乖離していない部分(紫色しかない領域)が存在することがわかります。

これらは、例えば上げ相場と下げ相場のデータの違いなど、テストデータを予測する上でノイズになっている可能性も十分考えられるため、学習データから抜けば予測精度が上がることも考えられます。

除外した場合の精度の確認

knnでユークリット距離を求めて評価対象と分布が遠い対象を除外して学習することで、精度が上がるのかを実験します。

実験は同じようにNumeraiのデータを使用して、学習データとユークリット距離が近い対象top1000を求めた後、top1000に評価対象のデータが何件入っていたかをカウントして0個の場合は除外します。

knnの近傍探索は下記のように実装しました。

knnはfacebookが開発したfaissを使用するとGPUを使用して高速に計算できます。こちらからwhlファイルを取得してライブラリーをインストールできます。

除外有無の結果が下記の通りとなります。

評価対象を2つの期間(横軸 validation_1(予測が容易な期間) 縦軸 validation_2(予測が困難な期間))で分けていて(学習には使用していないデータ)、downsample有無(有:1 無:0)で色を分けています。seedによりバラツキの影響があるので、それぞれ10seedずつ結果を比較しています。

validation 1 に関してはダウンサンプリングを行った方が精度が良い傾向がありそうで、validation 2に関しては逆の結果となっています。悩ましいところですが、テストで結果を確認しても良さそうに思えます。(down sample の影響よりseedの影響のデカさに驚いている)

まとめ

当たり前ですがkaggleで使われているモダンな手法が必ずしも効くわけではなく、今回のようなhard sampleが多く含まれるデータに対しては、L2-constrainedの方がよく分離できるケースもあることがわかりました(検証不足の可能性もありますが)

また、knnによる近傍探索で評価対象と遠いデータを除外することで予測精度向上に寄与できそうな結果も確認できました。他にも重みを付けて学習するなど応用が考えられるので、また実験していきたいと思います( ゚Д゚)