機械学習 PR

【NNabla】Neural Network Librariesと学習済みモデルによる推論 -HDF5ファイル編-

NNablaと学習済みモデルで推論
記事内に商品プロモーションを含む場合があります

こんにちは。

ディープラーニングお兄さんの”はやぶさ”@Cpp_Learningです。

前回、Sony製の深層学習フレームワーク “Neural Network Libraries(NNabla)” のPython APIを使って学習から推論までを実践するチュートリアル記事を書きました↓

Neural Network LibrariesによるDeep Learningチュートリアル
【NNabla】実践!Neural Network Librariesで学習から推論まで直観的にニューラルネットワークの実装ができるソニー製の深層学習フレームワーク”NNL”によるDeep Learning(ディープラーニング)入門チュートリアルを書きました。学習から推論まで行うPython APIを使ったソースコードも公開しています。勉強にお役立て下さい。...

↑記事の終盤で学習済みモデル”MyChain.h5”を保存しました。

今回は、この学習済みモデル”MyChain.h5”を使って推論を実践していきます。

学習済みモデルとは

深層学習の一連の流れを簡単に説明すると大体こんな感じです↓

【深層学習フロー】

  1. データ収集
  2. ニューラルネットワークの構造(アーキテクチャ)設計
  3. 学習によりニューラルネットワークの重み(パラメータ)を調整
  4. 学習済みのニューラルネットワークを使って推論

言葉の定義は書籍などにより異なりますが、”学習済みニューラルネットワーク”のことを”学習済みモデル”と呼ぶことが多いです。

  • 深層学習では、”重み”の自動調整のことを”学習”と呼ぶ
  • “学習済みニューラルネットワーク”=”学習済みモデル”と考えて良い

ニューラルネットワークの保存形式 -構造と重み-

深層学習フレームワークで実装したニューラルネットワークは以下の形式で保存します。

ファイル 内容
HDF5ファイル ニューラルネットワークの構造(アーキテクチャ)
JSON / YAMLファイル ニューラルネットワークの重み(パラメータ)
HDF5ファイル ニューラルネットワークの構造と重み

ファイルについては使用するフレームワークに依存するため、あくまで一例ですが…

ここでは言いたいのは、ニューラルネットワークが構造/重み/構造と重みの3パターンで保存されるという点です。

チームメンバーが以下の言葉をどういう意図で使っているのかに注意しましょう

  • ニューラルネットワーク
  • モデル
  • 学習済みニューラルネットワーク
  • 学習済みモデル

Neural Network Libraries(NNabla)と学習済みモデル

“Neural Network Libraries”の場合、ニューラルネットワークは以下の形式で保存します。

ファイル 内容
net.nntxt ニューラルネットワークの構造(アーキテクチャ)
parameters.h5 ニューラルネットワークの重み(パラメータ)
net_param.nnp ニューラルネットワークの構造と重み

学習により自動調整されるのは”重み”なので、「学習済みモデル」または「学習済みニューラルネットワーク」といえば「学習済みの”重み”」または「構造と学習済みの”重み”」を保存したものと考えて問題ありません。

Neural Network LibrariesによるDeep Learningチュートリアル
【NNabla】実践!Neural Network Librariesで学習から推論まで直観的にニューラルネットワークの実装ができるソニー製の深層学習フレームワーク”NNL”によるDeep Learning(ディープラーニング)入門チュートリアルを書きました。学習から推論まで行うPython APIを使ったソースコードも公開しています。勉強にお役立て下さい。...

↑の記事で保存した学習済みモデルは”MyChain.h5”でした。

つまり、”重み”の情報のみが保存してあり、その重みを格納するニューラルネットワークの構造については「.nntx」を使うか、再度ソースコード上で定義する必要があります。

何を言っているのかよく分からない…

という人も、ソースコード見ればスッキリ理解できるかもしれないので、以降からはソースコードを交えて説明します。

【実践】学習済みモデル(HDF5ファイル)による推論

NNableのPython APIと学習済みモデル(HDF5ファイル)を使って推論を行います。

import

nnablaの各モジュールをimportします。推論では最適化アルゴリズム(Solvers)を使いません。

その他、今回使用する数値演算モジュールなどもimportします。

NNabla用の変数定義

NNablaで学習および推論を行うために専用の変数を定義します。

ニューラルネットワーク設計

繰り返しますが、学習済みモデル”MyChain.h5”には”重み”の情報のみ保存してあります。

つまり、ニューラルネットワークの構造(アーキテクチャ)の情報は保存していないので、改めて”重み”を格納するためのニューラルネットワークを定義する必要があります。

前回設計した”独自ニューラルネットワーク”を関数化しておいたので、関数”MyChain(x)”をコピペするだけでOKです。

繰り返し実装するコードは関数化しておくと移植性が向上するのでオススメ!

学習済みモデル(重み)を読込む

学習済みモデル”MyChain.h5”をロードします(”重み”を読込みます)。

たった1行で学習済みモデル(HDF5ファイル)の読込みができました。

推論から推論結果のグラフ化まで

以降からは、前回書いた「NNablaによる学習から推論まで」の記事で説明した推論フェーズのコードをコピペすればOKです。

NNablaと学習済みモデルで推論青線が学習で使用したデータ(x=0~9の整数)の推論結果赤プロットが未知のデータ(x=0~9の少数)に対する推論結果です。

前回「NNablaによる学習から推論まで」の記事で描いた推論結果のグラフと比較してみます。

NNablaで深層学習

赤プロットの推論結果は完全に一致しています。つまり、前回と同じ学習済みモデルで推論できた!ということです。

(こちらの青線は推論結果ではなく実験データなので、僅かに不一致ですね)

スポンサーリンク

学習済みモデル保存

学習済みモデル”MyChain.h5”は”重み”のみを保存していました。

ファイル 内容
net.nntxt ニューラルネットワークの構造(アーキテクチャ)
parameters.h5 ニューラルネットワークの重み(パラメータ)
MyChain.nnp ニューラルネットワークの構造と重み

今回は、ニューラルネットワークの構造と重みを”MyChain.nnp”で保存します。

↑のようにJSONフォーマットで情報を記述して、ニューラルネットワークの重みと構造を”MyChaine.nnp”に保存します。

これで次回からは学習済みモデル”MyChaine.nnp”のみ(ニューラルネットワークの関数MyChain(x)なし)で推論ができます。

  • *.nnpには学習済みモデル(重み+構造)が保存してあります。
  • *.nnpをロードすれば、構造の情報も取得できるため、ソースコード上でニューラルネットワークを定義しないで推論ができます。

はやぶさの技術ノート

本記事を書く前に、Jupyter Notebookで検証を行った『メモ付きソースコード(技術ノート)』があるので、公開します。

はやぶさの技術ノート:DL_Predict_HDF5

ちょっと散らかってけど、ソースコード見るだけなら”技術ノート”の方が便利だと思います。

まとめ

Sony製の深層学習フレームワーク “Neural Network Libraries(NNabla)” と学習済みモデル(HDF5)を使った推論を実践しました。

深層学習の学習フェーズではマシンパワーを必要とする場合がほとんどです。

しかし、組込み機器(エッジディバイス)は基本的に低コスト/低消費電力/小型化が求められるため、マシンパワーが小さいものが多いです。

そのため、組込み機器では”学習済みモデル”を使い推論フェーズのみ実行するのがスマートです。

本記事を読んで「学習済みモデルの使い方が分かった」・「組込み機器で深層学習やってみたい!」という人が現れたら、ディープラーニングお兄さんはすごく嬉しい!

はやぶさ
はやぶさ
理系応援ブロガー”はやぶさ”@Cpp_Learningは頑張る理系を応援します!

おまけ -組み込み深層学習-

Neural Network Librariesで生成した学習済みモデルとSony製のスマートセンシングプロセッサ搭載ボード”Spresense”を使って推論ができるそうです。ほしい…!

SONY SPRESENSE カメラモジュール CXD5602PWBCAM1

PICK UP BOOKS

  • 数理モデル入門
    数理モデル
  • Jetoson Nano 超入門
    Jetoson Nano
  • 図解速習DEEP LEARNING
    DEEP LEARNING
  • Pythonによる因果分析
    Python