機械学習

【NNabla】実践!Neural Network Librariesで学習から推論まで

Neural Network LibrariesによるDeep Learningチュートリアル

こんにちは。

ディープラーニングお兄さんの”はやぶさ”@Cpp_Learningです。

前回、Sony製の深層学習フレームワーク “Neural Network Libraries(NNabla)” の環境構築について書きました↓

Neural Network Libraries
【NNabla】Windows(WSL)にNeural Network Librariesをインストールするこんにちは。 ディープラーニングお兄さんの”はやぶさ”@Cpp_Learningです。 深層学習フレームワークのChaine...

今回は、NNablaのPython APIを使って学習から推論までを実践していきます!

深層学習で解決する課題

本ブログ”はやぶさの技術ノート”では、機械学習のチュートリアル記事を無料公開しています。

公開している記事の中に「深層学習入門のチュートリアル記事」もあります↓

深層学習でシステム解析
【深層学習入門】超実践!Chainerと深層学習でシステム解析する方法ディープラーニング入門の記事を書きました。難易度は入門レベルですが『深層学習フレームワークChainerを使ってシステム解析する』という実践的な内容に仕上げました。制御・解析・分析などの課題解決に深層学習を使いたい人や、深層学習をビジネスや研究で使い人にオススメの記事です。...

この記事の一部を引用したもの↓

本記事では深層学習フレームワーク”Chainer”を使って、非線形回帰モデルの生成を行いますが、重要なのは、どのフレームワークを使うかではなく、『深層学習でシステム解析が行える』という考え方です。

はい。この記事では”Chainer”を使って課題解決のソースコードを作成しました。

諸事情あって(?)社内で使用するフレームワークを選択できない人がいるかもしれませんが…

あまり何を使うかに拘らず、色々なフレームワークを使えると良い事があるかも(*・ω・)ノ♪

というわけで、本記事では”NNabla”を使って課題解決のソースコードを作成します。

本記事では『深層学習でシステム解析する方法』についての説明は割愛し、NNablaの説明に注力します。なので、本記事を読む前に、上記した”深層学習入門の記事”を読んでおくことをオススメします!

Neural Network Librariesの特徴

“Neural Network Libraries(NNabla)” の大部分(コア)はC++で実装してあり、Python APIとC++ APIを使うことで、直観的にニューラルネットワークを実装することができます。

Neural Network LibrariesとC++

つまり、Python・C++の両言語で学習と推論ができます。

はやぶさ
はやぶさ
組込み機器で深層学習(推論)するならC++が良さそう!

と考えており、”Neural Network Libraries”はとても魅力的だと感じています!

スポンサーリンク

【実践】NNablaと深層学習 -学習フェーズ-

C++ APIが気になりますが、今回はPython APIを使って学習と推論を行います。

import

nnablaの各モジュールをimportします。

その他、今回使用する数値演算モジュールなどもimportします。

ダミーの実験データ生成

訓練データ…つまり、ダミーの実験データを生成します。

NNabla用の訓練データ

乱数で実験データは生成しましたが、これまた解析が難しそうなデータですね…

”あなた”はxとyの相関”y=f(x)”の数式モデル”f(x)”を算出できますか?

非線形で非常に難しいですよね…?

なので、深層学習を使って数式モデル”f(x)”を算出してもらいます(*・ω・)ノ♪

list ⇒ numpy変換

訓練データをnumpy配列に変換します。

NNabla用の変数定義

NNablaで学習および推論を行うために専用の変数を定義します。

ニューラルネットワーク設計

『1入力1出力で中間層のノード数が50 ⇒ 100』のニューラルネットワークを設計します。また、活性化関数(Activation Functions)には“relu”を採用します。

たった5行で独自ニューラルネットワーク”MyChain”を設計できました。

NNablaによるニューラルネットワークの定義には、いくつか書き方があります。好みもありますが、↑の関数を使うコードがシンプルで扱いやすいと思います。

ニューラルネットワーク(NN)モデルの宣言

”MyChain(x)”を呼び出します。

引数に『NNabla用の変数”xt”を使います。

【ニューラルネットワーク”MyChain”の特徴】

  • 入力値はNNabla用の変数”xt”
  • 戻り値”t”はニューラルネットワーク”MyChain”の推論結果

損失関数と最適化アルゴリズムを定義(+パラメータ登録)

学習で使う”損失関数(Loss Functions)””最適化アルゴリズム(Solvers)”を定義します。

※”t”が推論値・”yt”が実験データ(教師データ)

  • 損失関数に”MSE”を採用
  • 最適化アルゴリズムに”Adam”を採用

学習

前回、色々と試してみて”学習回数:8万回”で上手くいくことが分かっていたので、今回も8万回学習させます。

あと学習回数:1,000回ごとにlossのログをとります。

【学習回数とloss】

0 0.00045301687
1000 3.0827425e-06
2000 1.8982263e-08
3000 4.5757124e-06
4000 2.4628455e-06
5000 0.0015066643

75000 6.3268146e-10
76000 3.8131259e-06
77000 9.984251e-07
78000 1.1723699e-07
79000 2.4714078e-07

ターミナルに↑のような出力をします。以下のコードで学習過程を可視化します。

深層学習による学習過程

徐々にlossが減少していき、”0”付近で収束しました。良いモデルを生成できた気がします。

以上で深層学習による学習フェーズ終了です。

【実践】NNablaと深層学習 -推論フェーズ-

次は、学習済みモデルを使って推論を行います。

NNabla用の変数定義とモデルの宣言

学習時と同様に「NNabla用の変数定義」および「モデルの宣言」を行います。

推論したいデータ(テストデータ)用意

適当に学習で使用していない未知のデータ(少数)を用意します。

今回は、最初からnumpy配列にしました。

推論

以下のコードで推論します。

推論結果を”ym_list”に格納するコードがありますが、そのコードを除けば、たった3行で推論できます。

推論結果

“ym_list”の中身をターミナルに出力しても良いのですが、推論値だけで推論精度の良し/悪しを判断するのは難しいので、可視化します。

用意したテストデータ”xe”に合わせて推論値”ym_list”の形状を変換します。あとは、いつものように”matplotlib”でグラフ化します。

NNablaで深層学習

青色の真値(実験データ)グラフ上未知データに対する推論結果の赤プロットが乗っています。(少しずれているプロットもあるけど…)

良い精度で推論できてますね!つまり、良い学習済みモデルが生成できたということです。

学習済みモデル保存

満足のいく学習済みモデルを生成できたら、そのモデルを保存しましょう(*・ω・)ノ♪

たった1行で学習済みモデルの”重み”を保存できました。

これで次回からは学習フェーズを飛ばして学習済みモデル”MyChain.h5”による推論ができます。

はやぶさの技術ノート

本記事を書く前に、Jupyter Notebookで検証を行った『メモ付きソースコード(技術ノート)』があるので、公開します。

はやぶさの技術ノート:DL_System_Analysis

ソースコード見るだけなら”技術ノート”の方が便利だと思います。

スポンサーリンク

まとめ

Sony製の深層学習フレームワーク “Neural Network Libraries(NNabla)” を使って学習から推論までを実践してみました。

同じ課題に対して、別のツール/ライブラリ/言語を使うと理解が深まる場合があります。

以前書いた”深層学習入門の記事”と本記事を読み比べて…

「深層学習の理解が深まった!」・「深層学習についてをもっと勉強したい!」という人が現れたら、ディープラーニングお兄さんはすごく嬉しい!

はやぶさ
はやぶさ
理系応援ブロガー”はやぶさ”@Cpp_Learningは頑張る理系を応援します!

おまけ -勉強方法とほしいものリスト-

時間に余裕のある人は読んでみてね↓

勉強方法 -教材紹介-

「もっと勉強したい!」というモチベーションの高い人のために以下の記事を紹介!

勉強で得られるもの
機械学習の勉強法!キカガク流 Udemy講座の感想など -エンジニア目線-こんにちは。 現役エンジニアの”はやぶさ”@Cpp_Learningです。 【深層学習チュートリアル】などの技術ブログを書い...

私が今までに使用した教材を紹介しているので、参考になるかも(*・ω・)ノ♪

ほしいものリスト

Neural Network Librariesで学習したモデルは、Sony製のスマートセンシングプロセッサ搭載ボード”Spresense”で実行できるそうです。ほしい…!

Spresenseメインボード

Arduino IDE や専用のSDKを使って簡単にIoTシステムを実現できるようです

Spresense拡張ボード

Arduino UNO 互換のピンソケット

Spresenseカメラボード

Sony製CMOSイメージセンサー

SONY SPRESENSE カメラモジュール CXD5602PWBCAM1
LINEスタンプ配信中!

フクロウのLINEスタンプ

当サイトのマスコットキャラクター

「フクロウのくるる」が

LINEスタンプになりました!

勉強で疲れたあなたに癒しをお届け☆

お迎え待ってます(*・ω・)ノ♪

今すぐお迎えする

40個セットがたったの50コインとお得です