機械学習

【skorch】Pytorchをscikit-learnのような使い心地にするライブラリ

skorch

こんにちは。

現役エンジニアの”はやぶさ”@Cpp_Learningです。最近は Pytorch を使って深層学習を楽しんでいます。

今回はPyTorchをscikit-learnのような使い心地にするライブラリ skorch を紹介します。

skorchとは

skorchとはPyTorchをラップしてscikit-learnと完全互換させた深層学習ライブラリです。

Ecosystem | PyTorch では以下のように紹介しています。

skorch is a high-level library for PyTorch that provides full scikit-learn compatibility.

skorchの公式GitHub だと以下のように説明しています。

A scikit-learn compatible neural network library that wraps PyTorch.

skorchを使うモチベーション

scikit-learnやKerasユーザーがPytorchを使うとき、以下のような不満をもつ(戸惑う)人がいると思います。

  • 学習コードが冗長的
  • 推論コードも冗長的

Pytorchの魅力の1つは学習用のループ処理などを柔軟に書けることです。しかし、柔軟ゆえにコードがユニークになりがちです。

またPytorchではモデルに対し、学習/推論を切り替える必要があります(下記コード参照)。

なので…

くるる
くるる
Pytorchちょっとメンドクサイのよね~

とフクロウの”くるる”@kururu_owl は感じているようです。

skorchを使えば、Pytorchで設計したニューラルネットワークに対し、scikit-learnのようにmodel.fit()で学習, model.predict()で推論という単純明快なコードになります。

実践!skorch

色々と言葉で説明するより、コード見た方が理解しやすいので、以降からコードを交えてskorchの基本的な使い方を紹介します。

インストール

使用するライブラリ一覧は以下の通りです。

Requirements
  • scikit-learn 0.22.1
  • torch 1.4.0
  • skorch 0.7.0

※Google Colaboratoryで動作確認しました(2020/03/14)

Google Colaboratoryなら、以下のコマンドでインストールすれば環境構築完了です。

pip install skorch

以降からソースコードを書いていきます。

Import

最初はimportから

データセット作成

make_classificationを使って適当な2値分類用のデータセットを作成します。

データセット概要
  • サンプル数:300
  • クラス数:2
  • 特徴量(パラメータ数):5
  • 乱数シード:固定

このデータセットをtrain用/test用で分割(train:test = 7:3)するコードが以下です。

make_classificationの使い方については、以下が参考になります。

t-SNEで次元圧縮して可視化

trainデータを可視化してみます。ただし、特徴量が5次元なので、t-SNEで2次元に圧縮してから可視化します。

t-SNEで次元圧縮して可視化

ニューラルネットワーク設計

ニューラルネットワーク設計では、Pytorchのコードがそのまま使えます。

学習

ハイパーパラメータをセットしたmodelを定義し、model.fit()で学習します。簡単ですね!

skorchで学習※自動で学習時のloss/valの変化が表示されます(上図はmax_epochs=10のとき)

check point や early stop も設定できます。

今回は分類なので NeuralNetClassifier を使いましたが、回帰の場合は NeuralNetRegressor を使います

推論と評価

model.predict()で推論します。簡単ですね!

classification_report

modelの評価には classification_report を使いました(スコアだけなら accuracy_score でもOK)。

skorchを使えば、Pytorchで設計したニューラルネットワークに対し、scikit-learnのように単純明快なコードで学習/推論ができます

モデルの保存と読込み

以下のコードでmodelの保存と読込みができます。

ロードしたモデルで推論

学習にも評価にも未使用なデータに対し、ロードしたモデルによる分類をしてみます。

※2値分類なので、結果は ”0” または ”1” で出力されます。

以上がskorchの基本的な使い方です。

【補足】パイプラインやグリッドサーチもできる

skorchを使えば、scikit-learnのpipeline(パイプライン)やGridSearch(グリッドサーチ)のコードを書くこともできます。

最後にskorchを紹介している日本語の記事を紹介します。

スポンサーリンク

まとめ

skorch の基本的な使い方をソースコード付きで説明しました。

PyTorchで”サッと”ネットワーク設計してscikit-learnライクな学習/推論ができるので、とても楽だと感じました。

くるる
くるる
PyTorchに不満がある人は skorch 使えば解決するかも♪

PyTorchの学習コードをスッキリ書ける pytorch-lightning などもあります。

PyTorch Lightning入門から実践まで
PyTorch Lightning
PyTorch Lightning入門から実践まで -自前データセットで学習し画像分類モデルを生成-ディープラーニングフレームワークPytorchの軽量ラッパー”pytorch-lightning”の入門から実践までのチュートリアル記事を書きました。自前データセットを学習して画像分類モデルを生成し、そのモデルを使って推論するところまでソースコード付で解説しています。...

タスクに応じて、使い分けるのも良いかもしれませんね。

はやぶさ
はやぶさ
便利なものは積極的に活用して効率を上げちゃいましょう
現役エンジニアおすすめの仕事道具
現役エンジニアおすすめの仕事道具
【仕事効率化】現役エンジニアおすすめの仕事道具(キーボード・文房具など)現役エンジニアの”はやぶさ”が仕事効率UPのために使っている仕事道具(キーボード・マウス・文房具など)を紹介します。自分に合う道具を使って気持ちよく仕事をしましょう!...
LINEスタンプ配信中!

フクロウのLINEスタンプ

当サイトのマスコットキャラクター

「フクロウのくるる」が

LINEスタンプになりました!

勉強で疲れたあなたに癒しをお届け☆

お迎え待ってます(*・ω・)ノ♪

今すぐお迎えする

40個セットがたったの50コインとお得です