以前、Chainerによる深層学習入門チュートリアルの記事を書いています↓
【深層学習入門】超実践!Chainerと深層学習でシステム解析する方法|はやぶさの技術ノート
これと同じことをSony製の深層学習フレームワーク "Neural Network Libraries(NNabla)" で実践してみます。
nnablaの各モジュールをimportします。
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
その他、今回使用する(定番)モジュールもimportします。
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# 実験データ用の配列
x = []
y = []
get_values = 0
for i in range(10):
get_values = random.random()
x.append([i])
y.append([get_values])
# データフレーム生成
df = pd.DataFrame({'X': x,
'Y': y})
# グラフ出力
plt.plot(x, y)
plt.title("Training Data")
plt.xlabel("input_x")
plt.ylabel("output_y")
plt.grid(True)
df
batch_size = 1
xt = nn.Variable((batch_size, 1))
yt = nn.Variable((batch_size, 1))
x = np.array(x)
y = np.array(y)
def MyChain(x):
h1 = F.relu(PF.affine(x, 50, name = "l1"))
h2 = F.relu(PF.affine(h1, 100, name = "l2"))
y = PF.affine(h2, 1, name = "l3")
return y
t = MyChain(xt)
loss = F.mean(F.squared_error(t, yt))
solver = S.Adam()
solver.set_parameters(nn.get_parameters())
loss_list = []
step = []
for i in range(0, 80000):
n = random.randrange(10)
xt.d = x[n]
yt.d = y[n]
loss.forward()
solver.zero_grad()
loss.backward()
solver.update()
loss_list.append(loss.d.copy())
step.append(i)
if i % 1000 == 0: # Print for each 1000 iterations
print(i, loss.d)
# 学習過程のグラフ
plt.plot(step, loss_list)
plt.title("Training Data")
plt.xlabel("step")
plt.ylabel("loss")
plt.grid(True)
plt.show()
'''
# 学習過程グラフを一部拡大
plt.plot(step, loss_list)
plt.xlim([50000,70000])
plt.ylim([0,0.002])
plt.title("Training Data")
plt.xlabel("step")
plt.ylabel("loss")
plt.grid(True)
plt.show()
'''
学習済みモデルを使って推論を行います。基本的には学習のとき同様の手順です。
xm = nn.Variable((batch_size, 1))
ym = MyChain(xm)
本来はセンサ値を使うが、今回は適当な学習で使用していない未知のデータを生成した。
xe = np.array([[0.5], [1.8], [2.3], [3.3], [4.5], [5.4], [6.3], [6.7], [7.4], [8.2]])
ym_list = []
for i in xe:
xm.d = i
ym.forward()
ym_list.append(ym.d.copy())
推論結果を確認します。
print(xe)
ym_list
グラフ化したいのでxmに合わせてymを形状を変換します。
ym_test = np.reshape(ym_list, [10, 1])
ym_test
plt.plot(x, y)
plt.plot(xe, ym_test, "ro")
plt.title("comparison")
plt.xlabel("input")
plt.ylabel("output")
plt.grid(True)
plt.show()
青グラフが解析対象・赤プロットがx(少数)をMyChainに入力したときの推論結果
うん!イイ感じですね!!
良いモデルを生成できたので、このモデルを保存しておきます。
そうすれば、次回からは学習せずに、このモデル使って推論できます。
nn.save_parameters("MyChain.h5")