機械学習 PR

角度を用いた深層距離学習(deep metric learning)を徹底解説 -PytorchによるAdaCos実践あり-

角度と距離
記事内に商品プロモーションを含む場合があります

こんにちは。

現役エンジニアの”はやぶさ”@Cpp_Learningです。最近、距離学習を楽しく勉強しています。

今回は、角度を用いた深層距離学習のSphereFace・CosFace・ArcFace・AdaCosについて勉強したので、備忘録も兼ねて本記事を書きます。

深層距離学習とは

深層距離学習については、以下の二つの記事で説明済みなので、本記事では簡単なイメージのみを説明します。

【深層距離学習】Siamese NetworkとContrastive Lossを徹底解説
Siamese
【深層距離学習】Siamese NetworkとContrastive Lossを徹底解説こんにちは。 現役エンジニアの”はやぶさ”@Cpp_Learningです。前回、距離学習の記事を書きました。 htt...
【深層距離学習】Center Lossを徹底解説 -Pytorchによる実践あり-
CenterLossのアルゴリズム
【深層距離学習】Center Lossを徹底解説 -Pytorchによる実践あり-こんにちは。 現役エンジニアの”はやぶさ”@Cpp_Learningです。最近、距離学習を楽しく勉強しています。 今回は、損...

距離学習(深層距離学習含む)とは、空間に埋め込んだデータに対し、クラスが同じものは距離が近く・クラスが違うものは距離が遠くなるように学習する手法です。

深層距離学習

最適な距離(最適な埋め込み空間の生成)により、精度の良い分類を実現するのが狙いです。

  • 分類が難しいサンプル=距離が近い
  • 分類が簡単なサンプル=距離が遠い
  • 距離を最適化することで分類が簡単になる(分類精度が向上)

角度と距離学習

本記事のタイトルにも入っている「角度」が今回のキーワードです。ここでは角度と距離の関係について説明します。

角度を用いない距離学習の概要

最初に角度を用いない距離学習の概要を説明します。下図の点A・点Bの距離を遠ざけたい場合、「距離Dを大きくすれば良い」というのは直観的にも理解しやすいと思います。

ユークリッド距離

もし、Dがユークリッド距離ならば、D1・D2のどちらか、あるいは両方を大きくすれば、Dも大きくなります。

角度を用いない深層距離学習の多くはユークリッド距離関数を採用しており、最適な距離Dを実現するために距離学習用の損失関数(Center Lossなど)を使います。

角度を用いた距離学習の概要

次に角度を用いた距離学習の概要を説明します。下図の点A・点Bの距離を遠ざけたい場合、「角度θを大きくすれば良い」というのが基本的なアイデアです。

角度と距離

空間に埋め込んだデータ同士の角度が小さければ同じクラス、角度が大きければ違うクラスと分類できます。

角度を用いた深層距離学習の多くは角度の大小でデータ間距離を調整し、かつ最適な角度を実現するために距離学習用の損失関数(Arcfaceなど)を使います。

角度とSoftmax Loss

角度を用いた深層距離学習を理解するには、深層学習でお馴染みのSoftmax Loss関数の理解が必要不可欠です。

【Softmax Loss関数】

softmax loss 関数

N:クラス数, W:重みベクトル, b:バイアス, Xi:CNNの出力特徴ベクトル

この式のWとXに注目します。

【内積】

softmax関数の一部

b=0としたとき、この式は内積で表現できます。

内積

以上から、Softmax Lossにキーワードの「角度」が隠れていたことが分かります。この内積を図にしたものが以下です。

内積

重みW はクラス数N と同じ数だけ存在し、なす角θ の大小でデータXi を分類することができます。例えば、二分類(N=2, X1~X6)なら下図のイメージです。

深層距離学習で二値分類

W1との角度が小さいデータをクラスA、Wとの角度が小さいデータをクラスBと分類できます。

  • Softmax Lossに角度(内積)が隠れていた
  • データXi と重みWj とのなす角でクラスを分類
スポンサーリンク

角度を用いた深層距離学習のアルゴリズム概要

ここまでの説明で、埋め込み空間に存在するデータが角度によって分類可能なことが分かると嬉しいです。続いて、角度を用いた深層距離学習のアルゴリズムを説明します。

深層距離学習の課題

まずは下図の埋め込み空間があるとします。

角度による深層距離学習データXはW1,W2とのなす角θが等しい場所に埋め込まれています。これはクラスAとクラスBの境界線上にデータXが存在するということです。

例えるなら、データXは「猫みたいなフクロウ or フクロウみたいな猫」という分類困難なデータということです。

フクロウを猫と間違えて分類

この問題を解決するのに距離学習を使います。

境界線を移動、あるいはデータXと重みW1とのなす角が小さく(W2とのなす角が大きく)なるような最適な角度を学習をすれば解決できそうです(「角度を用いた距離学習の概要」のところでアイデアの説明あり)。

角度による深層距離学習

ここで改めてSoftmax Loss関数と内積の式を確認します。

【Softmax Loss関数】

softmax loss 関数

【内積】

内積

このSoftmax Lossをベースに内積を組み合わせ、かつ||Wj||=1, θを大きくしたい(θ+mなど)と考えながら式変形すると SphereFace, CosFace, ArcFace を作ることができます。

【SphereFace】

Sphere Face

 

【CosFace】

CosFace

 

【Arcface】

ArcFace

もっと詳しく知りたい人のために、以下の資料をオススメしておきます。

AdaCosとは

ArcFaceやCosFaceなどには、スケール:s, マージン:mといったハイパーパラメータがあります。

AdaCosでは、それらのハイパーパラメータを自動で設定してくれます。詳細なアルゴリズムを知りたい人は、以下の記事を参考にすると良いです。

実践!深層距離学習で画像分類 -AdaCos編-

理論の説明はここまでにして、次は実践しましょう!”ググると”角度を用いた深層距離学習に関するソースコードを見つけることができました。

Pytorchを使いたかったので、pytorch-adacos をベースにソースコードを改良し、AdaCosを実践しました。以降で解説します。

以降で解説するソースコードはGoogle Colaboratoryで動作確認しました。

Import

最初はimportから

いくつかのコードをコメントアウトしているのは、本記事のコードをGoogle Colaboratoryにコピペすれば簡単に実践できる形に修正したためです。

GPU/CPU設定

GPUが使用可能な環境ならGPUを使用し、そうでない場合はCPUを使用するように設定します。

前処理

以下の処理を行うコードを作成します。

  • データセット(MNIST)のダウンロード
  • transforms(画像の前処理)
  • データセット作成(train_set/val_set)
  • データローダ生成(train_loader/val_loader)

CNN設計

Optunaでパイパーパラメータの自動チューニングしたときに結果を参考に、CNNを設計します。

FC3層の出力ベクトルがAdaCosに入力されます(後述します)。

深層距離学習(AdaCos)クラス

本来は、metrics.pyをimportするだけで、SphereFace・CosFace・ArcFace・AdaCosを利用できますが、Google Colab上で”サクッ”と動かすために、AdaCosクラスを写経します。

平均値算出クラス(ログ用)とaccuracy関数

同様に、utils.pyををimportするだけで、平均値算出クラスやaccuracy関数を利用できるのですが、今回は写経します。

train関数

train関数を作成します。

validate関数

validate関数を作成します。

モデルのインスタンス生成

先ほど作成したNetクラスとAdaCosクラスのインスタンスを生成し、接続します。

下図がイメージです。

AdaCos

epocs/オプティマイザ/スケジューラ/クロスエントロピー定義

epocs/オプティマイザ(今回はMomentum SGD)/スケジューラ(必須ではない)/CrossEntropyLossを設定します。

学習

以下のコードで学習します。

val_lossが最も小さかったモデルを保存します。

私が試したときは、71epocsでベストモデルを生成できました(下記のように表示されます)。

loss 0.0892 – acc1 99.8300 – val_loss 0.1400 – val_acc 98.7700 => saved best model Epoch [71/100]

以上で実践も完了です。

まとめ

角度を用いた深層距離学習の理論から実践まで徹底解説してみました。

理論については、既に良い資料があったのですが、本記事を読んでから各資料や論文を読むと、よりスッキリ理解できると思います。

また、PytorchでAdaCosを実践している記事は少ないと思いますので、本記事が参考になれば嬉しいです。

AdaCos

最後にAdaCosの良かった点をまとめます。

  • CNNの最終層(FC層)を改良するだけで使える
  • ArcFaceと違いハイパーパラメータを自動で設定してくれる

扱い易くて精度も良い、とても優秀な手法だと感じました。

はやぶさ
はやぶさ
本記事で間違っている点や気になる点(ソースコードの改良ポイントなど)があれば、情報共有して頂けると嬉しいです。

よろしくお願いします。

PyTorchの実践的なコードを多数収録
距離学習(Metric Learning)
距離学習(Metric Learning)入門から実践までこんにちは。 現役エンジニアの”はやぶさ”@Cpp_Learningです。距離学習 (metric learning)について勉強...
PyTorch Lightning
PyTorch Lightning入門から実践まで -自前データセットで学習し画像分類モデルを生成-ディープラーニングフレームワークPytorchの軽量ラッパー”pytorch-lightning”の入門から実践までのチュートリアル記事を書きました。自前データセットを学習して画像分類モデルを生成し、そのモデルを使って推論するところまでソースコード付で解説しています。...
optunaでハイパーパラメータ最適化 pytorch-ligthning編
Optunaでハイパーパラメータの自動チューニング -Pytorch Lightning編-ハイパーパラメータ自動最適化フレームワークOptunaについて、入門から実践まで学べる記事を書きました。基本的な使い方からpytorch-lightningへの適用例までソースコード付きで公開しています。ご参考までに。...
PytorchでMobileNet SSDによるリアルタイム物体検出
PyTorchでMobileNet SSDによるリアルタイム物体検出深層学習フレームワークPytorchを使い、ディープラーニングによる物体検出の記事を書きました。物体検出手法にはいくつか種類がありますが、今回はMobileNetベースSSDによる『リアルタイム物体検出』を行いました。...

PICK UP BOOKS

  • 数理モデル入門
    数理モデル
  • Jetoson Nano 超入門
    Jetoson Nano
  • 図解速習DEEP LEARNING
    DEEP LEARNING
  • Pythonによる因果分析
    Python