機械学習

Pyroと変分推論でベイズ回帰モデルをつくる

Pyroと変分推論でベイズ回帰モデルをつくる

こんにちは。

現役エンジニアの”はやぶさ”@Cpp_Learningです。最近は統計モデリングを勉強しています。

今回は確率的プログラミング言語Pyroによるベイズ統計モデリングに挑戦したので、備忘録も兼ねて本記事を書きます。

まえおき

確率的プログラミング言語については、以下の記事で説明済みなので、本記事では割愛します。

NumPyroでベイズ回帰モデルをつくる
【ベイズ統計モデリング入門】NumPyroで回帰モデルをつくるこんにちは。 現役エンジニアの”はやぶさ”@Cpp_Learningです。最近は統計モデリングを勉強しています。 今回は確率...

また上の記事では事後分布の算出にMCMCを使いましたが、本記事では変分推論(variational inference)を使います。

変分推論ではなく、変分近似(variational approximation)変分ベイズ(variational Bayes)と呼ぶこともあります(以下 参考記事)。

変分近似(Variational Approximation)の基本(1)

変分推論とは

機械学習とは『モデルの出力結果真値との距離(loss)を最小化するアルゴリズムです』と説明するなら…

変分推論とは『モデルの確率分布データから得られた真の確率分布間距離(KLダイバージェンス)を最小化するアルゴリズムです』という説明ができます。

また機械学習目的関数(MSE:平均二乗誤差 など)からの出力結果(loss)を評価指標に使うのに対し、変分推論ではELBOなどを評価指標に使います。

もっと詳しく知りたい人のために、以下の本をオススメしておきます。

スポンサーリンク

実践!Pyroでベイズ回帰モデルをつくる -変分推論-

MCMCと比較できるように、以下の記事と同じ問題設定会計総額からチップ額を予測ベイズ回帰モデル + 変分推論を活用します。

NumPyroでベイズ回帰モデルをつくる
【ベイズ統計モデリング入門】NumPyroで回帰モデルをつくるこんにちは。 現役エンジニアの”はやぶさ”@Cpp_Learningです。最近は統計モデリングを勉強しています。 今回は確率...

インストール

公式サイトに従い、先にPyTorchをインストールしてください。

PyTorchがインストール済みなら、以下のコマンドで Pyro をインストールするだけでOKです。

pip install pyro-ppl

以降からコード書いていきます。

Import

まずはimportから

データ可視化

使用するデータを可視化します。

可視化

説明変数と目的変数 -torch.tensorに変換

説明変数:X と 目的変数:Y を定義します。

※PyroのバックエンドがPyTorchなので、torch.tensorを扱います

ベイズ統計モデリング -線形回帰モデル-

この記事と同じモデルを設計します。

近似事後分布

データから得られる複雑な確率分布(以降 真の事後分布 と呼ぶ)が、簡単な確率分布(以降 近似事後分布 と呼ぶ)でも近似できると考え、以下のように近似事後分布を設計します。

以下のように既に準備してあるものを使うこともできます。

※今回は AutoDiagonalNormal を使います

各種設定

変分推論のインターフェースを作ります。

※Adamは深層学習でもお馴染みですね

変分推論

以下のコードで指定した回数だけ推論(機械学習でいう学習)を実行します。

[iteration 0001] loss: 4904.6740
[iteration 0101] loss: 3820.1017
[iteration 0201] loss: 4124.2043

[iteration 19701] loss: 1.5044
[iteration 19801] loss: 1.4919
[iteration 19901] loss: 1.4918

途中経過を割愛してますが、上記のようにlossが下がっている様子を確認できます。

これは近似事後分布真の事後分布との距離が近くなる変分パラメータを推論できている、あるいは 真の事後分布簡単な確率分布近似できる変分パラメータを推論できている…といえます。

変分パラメータ取得

変分パラメータは pyro.get_param_store() に辞書型で格納されます。なので以下のコードでkeyとvalueを確認しておきます。

今回は以下のコードで変分パラメータを取得できます。

変分パラメータの格納形式について、以下の記事が参考になりました。

事後分布の可視化

変分パラメータを使って事後分布を可視化します。

変分近似

推論

以下のコードで X=0, 1, …, 49, 50 のときの Y を予測します。

以下のコードで予測結果を可視化します。

変分ベイズ

今回はMCMCとほとんど同じ結果が得られました。

【おまけ】ベイズ回帰やるなら、どの手法を使う?

本記事を書く前に以下のアンケートを行いました。

回答してくれた人ありがとうございました。MAP推定を使っている人や…

くるる
くるる
ベイズ深層学習は手法なのか?

と考えて、回答に悩んだ人がいたら、ごめんなさい。

変分推論 ユーザーは少ないようですが、本記事が参考になれば嬉しいです。

まとめ -変分ベイズ入門-

ベイズ統計モデリングについて勉強し、Pyroで実践(変分推論)した内容をまとめました。

本記事の内容について間違っている箇所があれば、教えて頂けると嬉しいです。

はやぶさ
はやぶさ
統計モデリングに関する情報共有をしながら、一緒に成長できると嬉しいです。よろしくお願いします。

以下 本記事の前半で紹介した書籍を改めてオススメしておきます。

 

Amazonギフト券チャージで最大2.5%ポイント還元
Amazonはチャージがお得

Amazonプライム会員なら、Amazonギフト券を現金でチャージ(コンビニ・銀行払い)すると最大2.5%ポイント還元!

クレジットカード払いでもキャンペーンエントリー0.5%ポイント還元中です。

Amazonでお得に買い物をするならまずはチャージから。

\チャージがお得/

詳細をチェックする

Amazonプライム無料体験中でもOK!

PICK UP BOOKS

  • 数理モデル入門
    数理モデル
  • Jetoson Nano 超入門
    Jetoson Nano
  • 図解速習DEEP LEARNING
    DEEP LEARNING
  • Pythonによる因果分析
    Python