機械学習

【深層距離学習】Center Lossを徹底解説 -Pytorchによる実践あり-

CenterLossのアルゴリズム

こんにちは。

現役エンジニアの”はやぶさ”@Cpp_Learningです。最近、距離学習を楽しく勉強しています。

今回は、損失関数のCenter Lossについて勉強したので、備忘録も兼ねて本記事を書きます。

深層距離学習(Deep Metric Learning)とは

深層距離学習の概要については、以下の記事で説明済みなので割愛します。

【深層距離学習】Siamese NetworkとContrastive Lossを徹底解説
Siamese
【深層距離学習】Siamese NetworkとContrastive Lossを徹底解説こんにちは。 現役エンジニアの”はやぶさ”@Cpp_Learningです。前回、距離学習の記事を書きました。 htt...

また、以降の説明でSiamese Network(以下 Siamese)との比較もするので、先にSiamese記事を読むと、本記事の内容を効率よく理解できます。

分類問題の課題

クラスタリング・異常検出・画像分類などを実践すると簡単なサンプル難しいサンプルがあることに気づくと思います。

感覚的にも、人と犬の分類は簡単フクロウと猫の分類は難しいと思いませんか?

ONNX RuntimeとSSDでリアルタイム物体検出出典:TensorFlowの学習済みモデルを変換してONNXRuntimeで物体検出|はやぶさの技術ノート

くるる
くるる
ヒゲや猫耳そっくりの羽角(うかく)が猫の特徴と似てるかな?
いやいや!さすがにフクロウと猫は見間違えないよ

という声が聞こえてきそうですが、人の顔を分類(見分ける)って難しいと思いませんか?

くるる
くるる
双子や六つ子を見分けるのは特に難しい…

簡単なサンプルと難しいサンプル

ここまでの説明で感覚的に簡単なサンプル難しいサンプルがあり、特に顔認識(Face Recognition)は難しいことを理解して頂けたら嬉しいです。

ここからは、なぜ顔認識が難しいのかを説明します。

距離学習(Metric Learning)入門から実践まで
距離学習(Metric Learning)
距離学習(Metric Learning)入門から実践までこんにちは。 現役エンジニアの”はやぶさ”@Cpp_Learningです。距離学習 (metric learning)について勉強...

上記の距離学習(Metric Learning)記事で、特徴量を空間に埋め込めると説明しました。埋め込み手法(次元削減)については後で説明するとして、画像も空間に埋め込むことができます。

ここでいう画像をAさん・Bさん・Cさんの顔写真と考え、埋め込み空間(Embedding Space)を生成したものが下図だとします。

Embedding Space

同じAさんでも髪型や化粧で雰囲気(特徴)が変わったり、横顔がCさんに似ているなどのケースが考えられます。

つまり、Aさんの顔写真を複数準備して空間に埋め込むと、全てが同じ座標というわけではなく、上図の青色プロットのような分布ができます。

同様にBさん・Cさんも分布ができ、AさんとCさんの横顔を埋め込んだ座標は距離が近くなります。「距離が近い=特徴が似ている(同じ)」という意味なので、分類困難です。

改めて図を見てみると、Aさん・Cさんの分布は距離が近いため分類が困難といえます。一方、AさんとBさんの分布は距離が離れているため、分類が簡単といえます(サンプル数を増やすとAさんとBさんの分類も困難になる可能性があります)

  • 分類が難しいサンプル=距離が近い
  • 分類が簡単なサンプル=距離が遠い

距離学習の課題

距離学習を活用すれば、最適な距離(あるいは最適な埋め込み空間)を学習できます。

下図が最適な距離(最適な埋め込み空間)のイメージです。

距離学習(Metric Learning)

上図のように綺麗に分離できれば、精度の良い分類ができます。ただし、以下のようなサンプルではどうでしょう?

Softmax loss

深層距離学習のSiameseは、伸びた”ばね”が縮むように同じクラスの距離を近づけ、縮んだ”ばね”が伸びるように違うクラスの距離を遠くすることで最適な距離を実現します。

しかし、上図のように各サンプルの距離が近い場合、最適な距離で”ばね”が平衡状態にならない可能性があります。

また、Contrastive Loss関数を使用するには、ペア画像のラベル付けが必要です。そのため、分類したいクラスが多い場合、ペアの組み合わせが非常に多くなるという問題があります。

このような問題をCenter Loss関数で解決します。

スポンサーリンク

深層距離学習(Deep Metric Learning)の効果

Center Lossの詳細説明をする前に、深層距離学習(Deep Metric Learning)の効果について説明します。ただし、前知識として深層学習と次元削減について理解していた方が良いので、1つずつ順番に説明していきます。

深層学習と次元削減

多次元データを可視化する場合、PCAt-SNEにより、データを2次元に圧縮することで2D空間に埋め込み可視化することができます。

ニューラルネットワークも入力データの次元削減(データ圧縮)をしています。

例えば、深層学習のCNNを用いた画像分類(データセットMNIST)の場合、28×28サイズの画像をCNNに入力し、10次元ベクトルが出力されます。これは各層で次元削減した結果、最終的には28×28サイズのデータを10次元ベクトルに圧縮したことになります。

もし、ニューラルネットワークのどこかで2次元の特徴ベクトルを出力する層があったとします。その特徴ベクトルを2D空間に埋め込めば、可視化できます。

深層学習と次元削減

上図の例では、CNNから出力された2次元の特徴ベクトルを空間に埋め込み可視化しています。

この2次元特徴ベクトルはFC層(Fully Connected layer)に入力され、推定結果が10次元ベクトルで出力されます。さらに、10次元ベクトルはLoss関数で演算・評価され、ニューラルネットワークの重みを更新(学習)することで、分類精度を向上させていきます。

ニューラルネットワーク構造の工夫で分類精度が向上できる点については一旦置いておいて…

ここで注目したいのは、可視化した2次元特徴ベクトルの距離が最適な空間(綺麗に各クラスが分離され、クラスタリング容易な空間)の方が、精度の良い分類を行える点です。

Softmax Loss と Center Loss

つまり、上図の左よりも右の方が精度の良い分類ができます。

深層距離学習(Deep Metric Learning)の狙い

深層学習により、次元削減が行われ、埋め込み空間を生成できることを説明しました。

深層学習と次元削減

しかし、必ずしも分類が容易な埋め込み空間が生成できるわけではありません。そこで、深層距離学習(Deep Metric Learning)を使い、距離が最適な埋め込み空間(Embedding Space)を生成することで、分類精度を向上させる狙いがあります。

深層距離学習のSiameseでは、ペア画像を入力するネットワーク構造とContrastive Loss関数により、最適な埋め込み空間を生成していました。

CNNとCenterLoss出典:A Discriminative Feature Learning Approach for Deep Face Recognition

今回は、ペア画像用のラベル付けは必要なく、従来のCNNとSoftmax LossにCenter Lossを組み込むだけで、最適な埋め込み空間を生成できます。

Center Lossとは

既に説明済みですが、Center Lossを使うことで、各クラスの距離が最適な埋め込み空間(各クラスの分類が容易な埋め込み空間)を生成できます。

以降からCenter Lossが「どうやって距離が最適な埋め込み空間を実現しているか?」のアルゴリズム詳細を説明します。

従来のSoftmax Loss関数を使った分類

深層学習による画像分類では、Softmax Loss関数が使われます。

【Softmax Loss関数】

Softmax Loss関数

m:ミニバッチサイズ, W:重み, b:バイアス, Xi:CNNの出力特徴ベクトル

よく使われるのでSoftmaxの説明は不要だと思いますが、分類誤差を算出しています。

Softmax Loss と Center Loss

MNISTの分類にSoftmax Loss関数を使い、埋め込み空間を可視化したものが上図(左)です。綺麗に分離できているように見えますが、Center Loss関数を使った方がより良い埋め込み空間を生成できています。

Center Loss関数による最適な埋め込み空間の生成

以下の式がCenter Loss関数です。

【Center Loss関数】

CenterLoss

m:ミニバッチサイズ, Cyi:各クラスyの中心, Xi:CNN出力の特徴ベクトル

MSE(平均二乗誤差)に似ていますが、1/mではなく1/2を使います(条件分岐のないHuber Loss関数という説明の方が適切かな?)

1/2を使う理由は微分すると綺麗に消えて都合が良いからでしょう(下記参照)

【Center Loss関数の微分】

CenterLossの微分

Cyi:各クラスyの中心, Xi:CNN出力の特徴ベクトル

深層学習では、Lossを最小化(最適化)するために学習を行います。つまり、各クラスの中心Cyiと各特徴ベクトルXiの距離を最小化(最適化)するために学習を行います。

CenterLoss

出典:A Discriminative Feature Learning Approach for Deep Face Recognition

くるる
くるる
MSEやHuber Lossに似ているCenter Lossに新規性はないのでは?

と思ったかもしれません。Center Lossの面白いのところは、下記の式で中心Cjを更新する点です。

【Updata Center関数】

center更新

δ(yi=j) = 1 または δ(yi=j) = 0

center更新

j:クラスid, α:ハイパーパラメータ, Cj:ミニバッチ毎の各クラスの中心

MNISTの場合、クラスidが0~9まであります。つまり、j=0~9まであります。もし、j=2の中心移動量ΔC2を算出するなら、Yi=j=2の条件を満たすときのみ演算します。

例えば、ミニバッチの入力データの内、4つのデータがYi=j=2の条件を満たす場合、空間に埋め込んだ4つの特徴ベクトルX1~X4と中心C2のみから移動量ΔC2を算出します。

CenterLossのアルゴリズム

上図の中心Cですが、少し左下に移動させた方が、より最適だと思いませんか?この移動を実現するアルゴリズムについて説明します。

中心更新アルゴリズム

まず最初に、各特徴ベクトルと中心までの距離ベクトルD1~D4を算出します。

CenterLossの中心点更新アルゴリズム

上図を見てみると距離ベクトルD3が比較的大きいことが分かります。つまり、中心Cを左下に移動(X3に近づく方向に移動)させた方が、良いと考えられます。

CenterLossの中心点更新アルゴリズム

各距離ベクトルD1~D4を合計し、5(=1+距離ベクトル数)で割ったものが移動量ΔCです。中心更新の関数を噛砕いて書き直したものが以下です。

【Updata Center関数】

center更新

移動後の中心C = 移動前の中心C – α*ΔC

ハイパーパラメータαで移動量を調整できます。例えばα=1としたとき、各軸で-0.6移動(X3に近づく方向に移動)できることが分かります。

以上のアルゴリズムでミニバッチ毎に各クラスの中心を更新していきます。

Softmax LossとCenter Loss関数を使った分類

改めて、ネットワーク構造と損失関数(Loss関数)を確認してみます。

CNNとCenterLoss出典:A Discriminative Feature Learning Approach for Deep Face Recognition

SoftmaxLossとCenterLoss

λ:ハイパーパラメータ

深層学習により、Softmax Lossで分離誤差を最小化(最適化)しつつ、Center Lossで最適な埋め込み空間を生成しています。

また、先ほど説明した通り、中心はミニバッチ毎に更新しています。

最後にハイパーパラメータのλについてですが、Center Lossの影響を調整するのに使います(下図参照)

CenterLossの影響調整出典:A Discriminative Feature Learning Approach for Deep Face Recognition

実践!深層距離学習 -Center Loss編-

理論の説明はここまでにして、次は実践しましょう!”ググると”以下のソースコードを見つけることができました。

論文に忠実で綺麗なコードはChainer実装ですが、Chainerは開発停止しましたね…ありがとうChainer(´;ω;`)

論文のCenter Lossを少しアレンジしていますが、今回はMNIST_center_loss_pytorchのコードを解説しながら実践していきます。

以降で説明するソースコードはGoogle Colaboratoryで動作確認しました。

import

最初はimportから

8行目をコメントアウトしているのは、本記事のコードをGoogle Colaboratoryに写経すれば簡単に実践できる形に修正したためです。

ニューラルネットワーク設計

次にCNNを設計します。

ip1が空間に埋め込まれ、かつCenter Lossにより演算される特徴ベクトルです。

Center Loss関数

Center Loss関数を自作します。

kerasも良いですが、Loss関数を自作するならPytorchの方が柔軟だと感じています。

埋め込み空間を可視化

空間に埋め込まれた特徴ベクトルip1を可視化する関数も作成します。

可視化は推論フェーズでは不要になりますが、Center Lossの評価やハイパーパラメータを調整するのに便利なので、可視化をオススメしておきます。

学習用の関数

学習用の関数まで作成したら、前準備完了です。

epoch毎に空間に埋め込んだ特徴ベクトルを可視化するので、train()からvisualize()を呼びます。

GPU/CPU設定

Pytorchでは、GPU/CPUどちらを使用するかの設定が必要です。なので、GPUが使用可能な環境ならGPUを使用し、そうでない場合はCPUを使用するようにします。

MNISTデータセットをダウンロード

MNISTデータセットをダウンロードします。

今回は割愛しますが、Siamese記事と同じように訓練データを可視化するのも良いかと。

学習

モデル/損失関数(Loss関数)/オプティマイザ(今回はSGD)/スケジューラ(必須ではない)を設定して、学習します。

※ハイパーパラメータλをloss_weightと定義しています

深層距離学習

本コード実行すると、CNNから出力された特徴ベクトルip1が可視化され、距離が最適な埋め込み空間を生成する様子を確認できます(改めて、概要図を以下に置いておきます)

CenterLossのアルゴリズム

以上で実践も完了です。

まとめ

Center Lossの理論から実践まで徹底解説してみました。

Center LossはSiameseやTriplet Loss(本サイトで紹介してない)と違い、以下の点がスマートだと感じました。

  • ペア画像のラベル付けが不要
  • 従来のCNNとSoftmax LossにCenter Lossを組み込むだけで使える
  • 中心(Center)の更新も簡単
  • ハイパーパラメータλでCenter Lossの影響を調整可能

アルゴリズムが面白いのに加え、シンプルで扱いやすいのが、とても良いですね。

数式を交えた理論の説明は、難しくなりがちですが、図を多く使った丁寧な説明により、”スッと”理解できることを期待して書き上げました。

まとめるの大変でしたが、本記事を読んでくれた多くの人が…

距離学習おもしろい!
深層距離学習で使われる損失関数に興味をもった!

など、モチベーションが向上するような感想を抱いてくれたら嬉しく思います。

はやぶさ
はやぶさ
はやぶさの技術ノート著者:はやぶさ@Cpp_Learningは頑張っている全ての人を応援します!
本記事でPytorch使ったので、関連書籍を紹介
距離学習(Metric Learning)
距離学習(Metric Learning)入門から実践までこんにちは。 現役エンジニアの”はやぶさ”@Cpp_Learningです。距離学習 (metric learning)について勉強...
Siamese
【深層距離学習】Siamese NetworkとContrastive Lossを徹底解説こんにちは。 現役エンジニアの”はやぶさ”@Cpp_Learningです。前回、距離学習の記事を書きました。 htt...
LINEスタンプ配信中!

フクロウのLINEスタンプ

当サイトのマスコットキャラクター

「フクロウのくるる」が

LINEスタンプになりました!

勉強で疲れたあなたに癒しをお届け☆

お迎え待ってます(*・ω・)ノ♪

今すぐお迎えする

40個セットがたったの50コインとお得です