機械学習 PR

【Hummingbird, ONNX Runtime】sklearnで学習した機械学習モデルの推論処理を高速化する

HummingbirdとONNX Runtime
記事内に商品プロモーションを含む場合があります

こんにちは。

現役エンジニアの”はやぶさ”@Cpp_Learningです。仕事でもプライベートでも機械学習で色々やってます。

今回は機械学習モデルの推論処理を高速化する方法について勉強したので、備忘録も兼ねて本記事を書きます。

CPUマシンで検証しました(GPUによる高速化はしません)

実践する内容と対象読者

最初に本記事で実践する内容を紹介します。

本記事のストーリー
  1. sklearnで機械学習モデルを学習(学習済みモデル生成)
  2. 学習済みモデルを高速処理に適したフォーマットに変換
  3. 変換後のモデルを使って推論処理を高速化

機械学習モデルの高速化については、実に様々なテクニックが存在し、高度な専門スキルを保有していないと実践できないケースもあります。

本記事の内容は sklearnで学習・推論を実践した人を対象としますが、それ以上のスキルは求めません。

sklearnの機械学習モデルを手軽に高速化する方法について紹介しますので、興味のある人は続きをどうぞ。

sklearnによる機械学習モデルの学習

今回は機械学習チュートリアルでお馴染みのirisデータセットを使い、ランダムフォレストによる分類モデルを生成します。

Import

まずはimportから

データセット

irisデータセットを取得し、訓練用と評価用に分割します。

ランダムフォレストの分類モデルを学習

今回はランダムフォレスト(RandomForestClassifier)を採用して学習します。

※ハイパーパラメータのチューニングはせず、デフォルト設定を採用

sklearn Ver0.22 以降 n_estimators(木の数)のデフォルト値が10から100に変更したそうです。

推論

学習済みモデルを使い、テストデータのラベルなどを推論します。

100 loops, best of 3: 14 ms per loop

推論にかかった時間は約14msでした。この処理時間をどこまで短くできるかが、本記事の本題になります。

Jupyter Notebook(Google Colabも可)のマジックコマンド %%timeit で処理時間を計測しました

推論結果

推論結果は以下のコードで確認できます。

出力表示は割愛するので、実際に手を動かして確認してみてください。

(学習フェーズ 完)

スポンサーリンク

機械学習モデルの推論処理を高速化する2つの方法

以降からは推論処理の高速化にトライします。今回は主に以下の2つを活用していきます。

ONNXONNX Runtime については以下の記事で紹介済みなので割愛します。

ONNX RuntimeとYoloV3でリアルタイム物体検出
ONNX RuntimeとYoloV3でリアルタイム物体検出Microsoft社製OSS”ONNX Runtime”の入門から実践まで学べる記事です。ONNXおよびONNX Runtimeの概要から、YoloV3モデルによる物体検出(ソースコード付)まで説明します。深層学習や画像処理に興味のある人にオススメの内容です。...

Humingbird については、公式が以下のように説明しています。

Hummingbird is a library for compiling trained traditional ML models into tensor computations. Hummingbird allows users to seamlessly leverage neural network frameworks (such as PyTorch) to accelerate traditional ML models.

(中略)

Currently, you can use Hummingbird to convert your trained traditional ML models into PyTorchTorchScriptONNX, and TVM). Hummingbird supports a variety of ML models and featurizers. These models include scikit-learn Decision Trees and Random Forest, and also LightGBM and XGBoost Classifiers/Regressors.

引用元:Hummingbird|GitHub

上記の説明がよく分からない人でも、以降で紹介するソースコードを読みながら実践をすれば理解できると思います。

頭で理解できなくても、手を動かすことで理解できるケースがあります。

ONNX Runtimeによる高速化

最初に ONNX Runtime による高速化を実践します。まずは以下のコマンドで各種のインストールをします。

pip install skl2onnx
pip install onnxruntime

以降からコード書いていきます。

Import

まずはimportから

sklearnのモデルをONNXフォーマットに変換

sklearnのモデルをONNXフォーマットに変換するには、sklearn-onnx を使います。

わずか数行で変換から”rf_iris.onnx”の保存まで完了です。

ONNX(拡張子が .onnx, .pb, .pbtxt)のモデルは Netron で可視化できます(下図参照)。

Netronで機械学習モデルを可視化

ONNX Runtimeによる高速な推論

ONNX Runtime を使うことで、高速な推論を実現します。

最初に先ほど保存した”rf_iris.onnx”をロードして、セッションを作成します。

次に各種の名前を取得後、ONNX Runtimeによる推論を実行します。

10000 loops, best of 3: 131 µs per loop

推論にかかった時間は約131μsでした。桁違いの速さですね

Humingbirdによる高速化

続いて Hummingbird による高速化を実践します。

Import

まずはimportから

sklearnモデルをPyTorchモデルに変換

以下のコードでsklearnモデルをPyTorchモデルに変換します。

※バックエンドがPyTorchになる

PyTorchモデル(Humingbird製)で推論

以下のようにsklearnと同じ使い心地で簡単に推論できます。

1000 loops, best of 3: 1.83 ms per loop

推論にかかった時間は約1.83msでした。ONNX Runtimeほどではありませんが、高速化できました。

ONNXモデルをPyTorchモデルに変換

Hummingbird を使えば、ONNXモデルをPyTorchモデルに変換することもできます。

推論は以下の通りです。

1000 loops, best of 3: 1.7 ms per loop

推論にかかった時間は約1.7msでした。変換元のモデルは違えど、変換後はPyTorchモデルなので、先程とほとんど同じ処理時間ですね。

ONNXモデルをONNXモデル(Hummingbird製)に変換

Hummingbirdで以下のような変換もできます。

推論は以下の通りです。

100 loops, best of 3: 5.25 ms per loop

推論にかかった時間は約5.25msでした。ちょっと物足りない気もしますが、ONNX Runtimeは使いたくないけど、ONNXモデルを使いたいケースには良いかと。

くるる
くるる
そんなケースあるのー?

うーん。分からん(笑)

はやぶさ
はやぶさ
ユースケースに心当たりのある人は教えてください

まとめ -Hummingbird, ONNX Runtimeによる推論高速化-

ONNX Runtime, sklearn-onnx および Hummingbird を活用し、sklearnモデルの変換推論処理の高速化を実践しました。

結果を以下の表にまとめます。

モデル 推論時間
sklearn(オリジナル) 14ms
sklearn to ONNX(sklearn-onnx製) 131μs(ONNX Runtime使用)
sklearn to PyTorch(Humingbird製) 1.83ms
ONNX to PyTorch(Humingbird製) 1.7ms
ONNX to ONNX(Humingbird製) 5.25ms

推論時間は検証するマシンのスペック次第なので参考程度ですが、どんなマシンでも高速化はできると思います。

ケースバイケースですが、以下のような使い分けができます。

  • 高速化ファーストなら ONNX Runtime
  • sklearnの使い心地で高速化したいなら Humingbird

本記事では説明しませんが、もっと高速化したい場合はGPUマシンを使うも良いと思います。

本記事が参考になれば嬉しいです。

はやぶさ
はやぶさ
機械学習の推論処理を高速化したい人は本記事も含め色々とトライしてみて下さいな

以下 ONNX関連の記事

ONNX RuntimeとYoloV3でリアルタイム物体検出
ONNX RuntimeとYoloV3でリアルタイム物体検出Microsoft社製OSS”ONNX Runtime”の入門から実践まで学べる記事です。ONNXおよびONNX Runtimeの概要から、YoloV3モデルによる物体検出(ソースコード付)まで説明します。深層学習や画像処理に興味のある人にオススメの内容です。...
ONNX RuntimeとSSDでリアルタイム物体検出
TensorFlowの学習済みモデルを変換してONNXRuntimeで物体検出こんにちは。 コンピュータビジョン(『ロボットの眼』開発)が専門の”はやぶさ”@Cpp_Learningです。 前回『ONN...

PICK UP BOOKS

  • 数理モデル入門
    数理モデル
  • Jetoson Nano 超入門
    Jetoson Nano
  • 図解速習DEEP LEARNING
    DEEP LEARNING
  • Pythonによる因果分析
    Python