機械学習

PyTorch Lightning入門から実践まで -自前データセットで学習し画像分類モデルを生成-

PyTorch Lightning

こんにちは。

現役エンジニアの”はやぶさ”@Cpp_Learningです。色んな深層学習フレームワークを触ってきましたが…

Chainer開発停止のニュースを聞いて以来、PyTorchをメインで使おうと考えています。

今回は、スッキリした学習コードを書けるpytorch-lightningについて勉強したので、備忘録および情報共有のために本記事を書きます。

PyTorch Lightningとは

PyTorch Lightningの概要は下図の通りです。

pytorch-lightning

引用元:pytorch-lightning|GitHub

PyTorch Lightningを活用すれば、ピュアPyTorchで書いていた学習用ループ処理などの大部分(上図の青:Trainer)を分離・自動化できるため、ユーザーは研究などで注力したい部分(上図の赤:Lightning Module)のコード作成に専念できます。

PyTorch Lightningのライバルたち

PyTorch Lightningの仲間(ライバル?)に catalystfastaiigniteがあり、いずれもEcosystem | PyTorch で紹介されています。

選択肢が多いので、どれを使うべきか悩む…

という人は、以下の資料が参考になると思います。

今回、Pytorch Lightningを採用した理由は…

はやぶさ
はやぶさ
私が新しいもの好き、かつ使い易そうだと思ったからです
スポンサーリンク

PyTorch Lightningを使うモチベーション

PyTorchは深層学習用のフレームワークなので、ある程度は同じ形式でコードを書けます。しかし、自由度が高いため、学習用のループ処理などがユニークになりがちです。

1人で使う書き捨てコードなら問題ありませんが、チームでコードを共有したり、過去に自作したコードを使いまわす場合は、フレームワークなどの力を借りてコードを形式化する方が効率的です。

(夜のテンションで可読性の悪いコードを書いてしまうことありますよね?)

PyTorch Lightningで研究に注力…みたいなカッコイイこと(公式もそう説明してる)を書きましたが…

個人的にはスッキリした学習コードを書きたいというのが本音です。

PyTorch Lightningの活用により、可読性の良い学習コードを書いて、チームで共有したり、過去に書いたコードを使い回したいと考えています。

公式チュートリアルでPyTorch Lightning入門

PyTorch Lightning公式ドキュメント公式リポジトリのチュートリアルコードを見ると、以下の2Stepで学習できることが分かります。

  1. LightningModuleを継承したクラスを作成
  2. trainer.fit(model)で学習

Google Colaboratoryで直ぐに動かせるデモも用意してあるので、スッキリとしたコードで学習できることに感動してください。

Google Colab上でpytorch-lightningをインストールした時にRESTART RUTIMEという表示がでる場合は、それをクリックしないとデモが動かない可能性があります(2019/12/22時点)

実践!PyTorch Lightningと自前データセットで学習 -画像分類編-

上記の公式チュートリアルを動かしたら、次は自前データで実践してみましょう!今回は、フクロウの種類を分類するモデルの生成にPyTorch Lightningを活用してみます。

  • 以降で紹介するソースコードは、Google Colabで動作確認済み
  • 2019/12/22時点の最新バージョンをインストール

! pip install pytorch-lightning==’0.5.3.2′

課題とデータ収集

アナホリフクロウアフリカオオコノハズクを分類するモデルを生成します。

アナホリフクロウ
アフリカオオコノハズク

本サイトのマスコットキャラクター”くるる”@kururu_owl もアフリカオオコノハズク 縮めてアフコノです。くるる可愛いフクロウ画像を集めたら、以下の場所に保存します。

くるる
くるる
ahukonoディレクトリに保存した画像(*.jpg)は”くるる”多めなので、今から作るモデルは”くるる”分類器といっても過言ではない!
はやぶさ
はやぶさ
くるるちゃん”そういうこと”言っちゃダメ―!笑

データ収集が完了したら、学習用ソースコードを作成します。

import

まずはimportから

GPU/CPU設定

PyTorchでは、GPU/CPUどちらを使用するかの設定が必要です。以下のコードでGPUが使用可能な環境ならGPUを使用し、そうでない場合はCPUを使用するようにします。

transforms/Dataset/DataLoaderについて

PyTorchで自前データを学習する場合、transforms/Dataset/DataLoaderモジュールをカスタムして使うことが多いです。各モジュールの役割は以下の通り。

  • transforms
    • データの前処理を担当するモジュール
  • Dataset
    • データとそれに対応するラベルを1組返すモジュール
    • データを返すときにtransformsを使って前処理したものを返す。
  • DataLoader
    • データセットからデータをバッチサイズに固めて返すモジュール

引用元:PyTorch transforms/Dataset/DataLoaderの基本動作を確認する|Qiita

各モジュールのカスタム方法についてはPyTorch公式チュートリアルや以下の本が参考になります。

PyTorchの実践的なコードを多数収録

参考書やサイトの情報を吸収し、自分が扱いやすいコードを設計できると良いですね(*・ω・)ノ♪

はやぶさ
はやぶさ
以降で私の設計したソースコードを公開するので、情報交換して頂けると嬉しいです!

transforms実装

今回は画像用の前処理クラス MyTransforms を自作します。

trainデータとvalデータ両方の前処理をサポートする設計にしました。

他に入れた方が良い前処理があれば教えて下さい!

Dataset実装

次にDateset生成クラス MyDataset を自作します。

MyDatasetをゼロからつくることも可能ですが、Datasetを継承し、課題に合わせてカスタムするのが良いと思います。

今回は「path=データセットの保存場所」・「key=train/val」をセットすれば、trainデータセット/valデータセットを生成できる設計にしました。

2分類なので16~22行目をべた書きしましたが、ラベルを定義したJSONファイルなどを活用すればもっとスマートに書けたかも。。

DataLoader実装

DataLoaderは用意してあるものを活用するので、自作しません。以下のコード1行書くだけでOKです(既にソースコード冒頭でimport済み)

ここまでは、ピュアPyTorchでも自作するコード です。以降からが PyTorch Lightning特有のコード になります。

システム設計 -PyTorch Lightning編-

PyTorch Lightningでは、ニューラルネットワーク・損失関数・オプティマイザ・train/valデータローダを1つのクラス(LightningModuleを継承したクラス)にまとめて定義します。

今回は、ニューラルネットワーク(CNN)を自作せず、学習済みモデル(Resnet)をファインチューニングします。

また、課題が画像分類なので、損失関数にはcross_entropy・オプティマイザにはSGDを採用しました(定番ですね)。

transforms/Datasetは自作した My○○ を呼ぶだけです。スッキリしたコードですね!

CNN・損失関数を自作する場合は、transforms/Dataset同様に外で定義したクラスや関数を呼ぶ方がスッキリしそうです。

学習

最後に「CoolSystemのインスタンスを生成⇒Trainerを定義⇒fitで学習」します。

Trainer関連の公式ドキュメントを参照し、細かい設定も可能でしたが…今回は公式チュートリアルのコードに書かれた以下のコメントを信じて、デフォルトにしました。

# most basic trainer, uses good defaults (1 gpu)

  • デフォルトのこの部分がイマイチだよ
  • ○○の意図で△△な設定してます

などの情報があれば共有して頂けると嬉しいです(*・ω・)ノ♪

PyTorch Lightningと推論

今回のモチベーションがスッキリした学習コードを書きたいなので、以上の内容で終了しても良かったのですが…アドバイスもらえたら嬉しいので、推論用ソースコードも公開します。

フクロウ分類器の推論結果
フクロウ分類器の推論結果

くるる
くるる
アナホリとアフコノ(くるる)をちゃんと分類できてる!

推論コードの補足説明

PyTorch Lightningで学習すると、checkpoint(重み)が以下の場所に保存されます。

※Trainerがデフォルト設定のとき

この重み(*.ckpt)をロードし、‘state_dict’model.load_state_dict の引数に渡せば、モデル(CNN)に重みがセットされます。

あとは、ピュアPyTorchと同じです。

スポンサーリンク

おわりに

今まで、ごちゃごちゃしたコードを書いてたので…PyTorch Lightningを活用し、スッキリした学習コードを設計してみました。

はやぶさ
はやぶさ
Pytorch Lightningなかなか良いぞ!

と思う一方で、catalystignite も良さそうなので、何をメインで使うか未だに悩んでいます。

もっと、こんな感じで書くとスマートですよ
くるる
くるる
MYBESTコードは、こんな感じだよ

など情報共有して頂けたら嬉しいです。ブログやGitHubなどでコードを公開している人は、Twitterでこっそり教えて下さい ⇒ アカウント:@Cpp_Learning

また、本記事をきかっけに…

深層学習のソースコードを見直したい!

という人が現れたら、とても嬉しいです。

はやぶさ
はやぶさ
情報共有しながら、一緒に成長できると嬉しいです。よろしくお願いします。
PyTorchの実践的なコードを多数収録
LINEスタンプ配信中!

フクロウのLINEスタンプ

当サイトのマスコットキャラクター

「フクロウのくるる」が

LINEスタンプになりました!

勉強で疲れたあなたに癒しをお届け☆

お迎え待ってます(*・ω・)ノ♪

今すぐお迎えする

40個セットがたったの50コインとお得です