Python PR

Pyroでガウス混合モデル(GMM)をつくる【ベイズ統計モデリング入門】

pyroでGMM
記事内に商品プロモーションを含む場合があります

こんにちは。

現役エンジニアの”はやぶさ”@Cpp_Learningです。最近は統計モデリングを勉強しています。

今回はガウス混合モデル(GMM:Gaussian Mixture Model)について勉強したので、備忘録も兼ねて本記事を書きます。

GMMとは

あるデータが複数のガウス分布から発生していると考えたとき、GMMにより各ガウス分布の平均・分散を予測することができます。

また各データが、どのガウス分布から発生したかも予測できるため、クラスタリングにも活用できます。

pyroでGMM

例えば上図は、クラスタ数が3のときのGMMによる推論のイメージです。

混合モデル(GMM含む)について、もっと勉強したい人のために、本記事の最後でオススメの本を紹介します

実践!Pyroでガウス混合モデル(GMM)をつくる

sklearnを活用することで、簡単にGMMによる推論を実現できますが、今回はPyroでゼロからガウス混合モデルをつくってみます。

Import

まずはimportから

サンプルデータ作成

GMMの説明とは逆の手続きで、サンプルデータを作成します(下図参照)。

ガウス混合

つまり3つのガウス分布からサンプルを生成し、そのサンプルを結合したデータを作成します。

Pyroを活用すれば、簡単に任意の確率分布からサンプルを生成することができます。今回の場合、以下のコードを使います。

ガウス混合

ベイズ統計モデリング

今回は以下のモデルを設計します。

モデルの概要は以下の通りです。

GMM
  • K個の2次元ガウス分布から平均値(locs)を生成(※1)
  • 共分散行列(scales)は単位行列で決め打ち(※2)
  • ディリクレ分布とカテゴリ分布の組み合せで、データ割り当て(クラスタ z = 0 or 1 or 2)
  •  各データが2次元ガウス分布(μ=locs[z], σ=scales)に従って発生(クラスタごとに μ と σ が存在する)

(※1) 2次元ガウス分布なので、円の中心=平均値=[x, y]
(※2) 分散を固定しているが、何らかの確率分布から発生させても良い

今回はシンプルに2次元ガウス分布の平均値(円の中心)のみを推論し、分散(円の大きさ)については、単位行列(単位円)で決め打ちします。

MCMC実行

以下のコードでMCMCによる推論を実行します。

推論結果の可視化

以下のコードでMCMCで算出された各パラメータを取得できます。

以下のコードでクラスタごとの平均値を抽出し、さらに平均化したものを各2次元ガウス分布の平均値(円の中心)とします。

以下のコードで推論結果を可視化します。

GMMによる平均値の推論

MCMC実行で num_samples=100 と設定したので、上図(左)の赤プロット(平均値)は100×3個ありますが、平均化することで上図(右)のように1点に絞り込みました。

平均値の真値と推論結果を比較したものが以下です。

真値:平均値(x, y) 推論結果:平均値(x, y)
Sample1 (0, 5) (0.0230, 4.9318)
Sample2 (5, 0) (4.8584, 0.1352)
Sample3 (8, 8) (7.8621, 7.9924)

綺麗に予測できました。

スポンサーリンク

まとめ

確率的プログラミング言語のPyroを活用し、ガウス混合モデル(GMM)を自作しました。

sklearnなどのライブラリを使えば、簡単にGMMを活用できますが『ゼロからモデリング』することで理解が深まったと感じています。

本記事の内容について間違っている箇所やアドバイスなどあれば、教えて頂けると嬉しいです。

はやぶさ
はやぶさ
統計モデリングに関する情報共有をしながら、一緒に成長できると嬉しいです。よろしくお願いします。

以下 オススメの本と関連記事の紹介

PICK UP BOOKS

  • 数理モデル入門
    数理モデル
  • Jetoson Nano 超入門
    Jetoson Nano
  • 図解速習DEEP LEARNING
    DEEP LEARNING
  • Pythonによる因果分析
    Python