機械学習

PytorchでRAdamを使う -最適化アルゴリズムまとめ-

pytorch optimizer

こんにちは。

現役エンジニアの”はやぶさ”@Cpp_Learningです。最近は Pytorch を使って深層学習を楽しんでいます。

今回は以下の内容について勉強したので、備忘録も兼ねて本記事を書きます。

本記事に含まれる内容
  • 最適化アルゴリズム(Optimizer)の情報整理
  • Pytorch(Pytorch Lightning)から様々なOptimizerを使う方法
  • Pytorch LightningでRAdamを使う

最適化アルゴリズム(Optimizer)とは

最適化アルゴリズム?なにそれ美味しいの?

という人は、以下の書籍でDeep Learningについて勉強すると良いですよ(*・ω・)ノ♪

以下のブログ記事で勉強するのも良いと思います。

最適化アルゴリズム(Optimizer)選定

どの最適化アルゴリズムを使うべきか分からない

という人は、各手法を比較している以下の記事が参考になると思います。

SGD Momentumは良いぞ~

個人的には、最適化アルゴリズムに悩むときは SGD Momentum 使っておけば良いと考えています(以下の資料がとても勉強になります)

RAdamとは

ただ最近、RAdamが良いぞ~という情報をキャッチしたので気になっています。RAdamについては論文や以下の解説記事を読むと良いです。

Pytorchで様々な最適化アルゴリズム(Optimizer)を使う

Pytorch公式は様々な 最適化アルゴリズム(Optimizer)をサポートしていますが、その中にRAdam はありません(2020/03/08時点)

そのため、RAdamを試す場合は自作する必要があります。論文の筆者が公開しているソースコード を使っても良いのですが…

RAdam含め多数の最適化アルゴリズムを実装しているリポジトリを見つけました(下記GitHub参照)。

しかもシンプルで使い易い!ありがとうございます。使わせていただきます。

pytorch-optimizerの使い方

README に 丁寧な説明があるし、examples(サンプルコード)も用意してあるので、pytorchでoptimizerを使ったことのある人なら、直ぐに試せると思います。ただ、個人的に…

はやぶさ
はやぶさ
学習用のループ処理を書きたくない…スッキリとした学習コードを書きたい…

と思ったので、pytorch-lightningを使うことにしました(pytorch-lightningの基本的な使い方は、以下の記事にまとめておきました)。

PyTorch Lightning入門から実践まで
PyTorch Lightning
PyTorch Lightning入門から実践まで -自前データセットで学習し画像分類モデルを生成-ディープラーニングフレームワークPytorchの軽量ラッパー”pytorch-lightning”の入門から実践までのチュートリアル記事を書きました。自前データセットを学習して画像分類モデルを生成し、そのモデルを使って推論するところまでソースコード付で解説しています。...
スポンサーリンク

実践!Pytorch LightningでRAdamを使う

Pytorch LightningからRAdamを使ってみます。

Optimizer処理については ピュアPytorchPytorch Lightning どちらも同じコードを使えます(なので、Pytorch Lightnigを使わない人も参考になるかと)

インストール

使用するライブラリ一覧は以下の通りです。

Requirements
  • torch 1.4.0
  • torchvision 0.5.0
  • torch-optimizer 0.0.1a9
  • pytorch-lightning 0.7.1

※Google Colaboratoryで動作確認しました(2020/03/08)

Google Colaboratoryなら、以下のコマンドでインストールするだけで環境構築完了です。

pip install torch_optimizer pytorch-lightning

以降からソースコードを書いていきます。

Import

最初はimportから

CNN設計

Pytorchの公式チュートリアル を参考にCNNを設計します。

Config(パラメータ設定)

最近、Data Classesの存在に気づいたので、config(パラメータ設定)に使ってみます。

※本記事では未使用のパラメータも定義しています

Data Classesの使い方については、以下の記事が参考になります。

System設計

PyTorch Lightningでは、以下のような学習システムを設計します。

ポイントは以下の通りです。

  • パラメータを定義した conf があること以外は、Pytorch Lightningの基本的なコードと同じ
  • Optimizerのコードは configure_optimizers() に書きます(ピュアPyTorchと同じコードを書けばOKです)
  • あとからSGD MomentumやAdamと比較するため、コードを書いておく(30~31行目参照)

学習

以下のステップで学習します。

  1. Config()のインスタンスを生成
  2. ハイパーパラメータチューニング
  3. CoolSystemのインスタンスを生成
  4. Trainerを定義
  5. fitで学習

今回は以下のようにしました。

パラメータについては、デフォルト設定を採用(Configクラスに定義したパラメータを変更せずに使用)し、Trainerのepochs(学習回数)を指定して学習します。

以上でPytorch LightningからRAdamを使って学習ができます。

ベストプラクティスを求めて

RAdamだけでなく SGD MomentumやAdamも試し、lossを比較してみるとRAdamが最も良い結果になりました。

しかし、データセットがMNISTなので割といい加減でもlossが下がるし、採用するOptimizerに応じてパラメータチューニングしないと、フェアな比較にはならないですね。

Optuna などを活用し、自動チューニングすればフェアな比較ができそうです(*・ω・)ノ♪

Optunaでハイパーパラメータの自動チューニング -Pytorch Lightning編-
optunaでハイパーパラメータ最適化 pytorch-ligthning編
Optunaでハイパーパラメータの自動チューニング -Pytorch Lightning編-ハイパーパラメータ自動最適化フレームワークOptunaについて、入門から実践まで学べる記事を書きました。基本的な使い方からpytorch-lightningへの適用例までソースコード付きで公開しています。ご参考までに。...

最適化アルゴリズム(Optimizer)の情報を整理し、Pytorch LightningからRAdamを使う方法をソースコード付きで解説しました。

pytorch-optimizerを使えば、簡単にRAdamなどのOptimizerを試せるのが嬉しいですね。

ただし、タスク次第でベストなOptimizerやパラメータが異なると思うので、最適なチューニングが必要ですね。

転移学習するなら、Adamより○○がオススメです
くるる
くるる
OptunaでOptimizer比較したよー

など 情報共有して頂けたら嬉しいです。ブログやGitHubなどでコードを公開している人は、Twitterでこっそり教えて下さい ⇒ アカウント:@Cpp_Learning

はやぶさ
はやぶさ
情報共有しながら、一緒に成長できると嬉しいです。よろしくお願いします。

以下 本の紹介

PyTorchの実践的なコードを多数収録
実践的なデータ分析技術を学べる良書
ディープラーニング参考書のベストセラー本
LINEスタンプ配信中!

フクロウのLINEスタンプ

当サイトのマスコットキャラクター

「フクロウのくるる」が

LINEスタンプになりました!

勉強で疲れたあなたに癒しをお届け☆

お迎え待ってます(*・ω・)ノ♪

今すぐお迎えする

40個セットがたったの50コインとお得です