多変量(多次元)正規分布のKLダイバージェンスの求め方

機械学習界隈では多変量正規分布のKLダイバージェンスの導出は自明らしく、とくに説明もなく「はいこうなりますね〜簡単ですね〜ははは〜」みたいな感じで軽く流されて死にそうになる。
軽く流されると私のように死んでしまう人もいるかもしれないので導出方法をメモしておく。


前準備

KLダイバージェンスは分布Pに対して分布Qがどれだけ近いかを表し、定義は以下のとおり。

KL(P(x) || Q(x))
= ∫P(x) log(P(x) / Q(x)) dx
= ∫P(x) log(P(x)) dx - ∫P(x) log(Q(x)) dx

また多変量正規分布の定義は以下のとおり。

P(x | μ, Σ)
= ((2π)^d * |Σ|)^(-1/2) * exp(-1/2 * (x - μ)T Σ^-1 (x - μ))

μ: 平均(d次元(縦)ベクトル)
Σ: 共分散行列(d次正方行列)
x: データ点(d次元(縦)ベクトル)

(x - μ)T: (x-μ)を転置させたもの(横ベクトル)
(x - μ)T Σ^-1 (x - μ): 二次形式。スカラー

この多変量正規分布の式をKLダイバージェンスの式に突っ込んで計算すればいいだけなのだが、そのまま突っ込むと式がごちゃごちゃして見難くなる。なのでまずは多変量正規分布の式を整理しておく。

P(x | μ, Σ) = Z(Σ) * exp(S(x, μ, Σ))

Z(Σ) = ((2π)^d * |Σ|)^(-1/2)
S(x, μ, Σ) = -1/2 * (x - μ)T Σ^-1 (x - μ)

基本的にはこのP=Z*exp(S)の状態で計算をして必要な時だけ展開する。また対称行列Aの二次形式(xT A x)は(A x xT)のトレース(対角成分の和)に等しいので、以下が成り立つ。

xT A x = Tr(A x xT)より

S(x, μ, Σ)
= -1/2 * (x - μ)T Σ^-1 (x - μ) ・・・(1)
= -1/2 * Tr(Σ^-1 (x - μ)(x - μ)T) ・・・(2)

S(x, μ, Σ)を(2)の形式にするとΣに依存する部分とx,μに依存する部分を切り分けることができるので必要に応じて利用する。


KLダイバージェンスの導出(前編)
多変量正規分布のKLダイバージェンス、つまり以下を計算する。

KL(P(x | μ1, Σ1) || Q(x | μ2, Σ2))

= ∫P(x | μ1, Σ1) * log (P(x | μ1, Σ1) dx
- ∫P(x | μ1, Σ1) * log (Q(x | μ2, Σ2) dx

まずは多変量正規分布P=Z*exp(S)をこの式に代入する。

= ∫P(x | μ1, Σ1) * log (Z(Σ1) * exp(S(x, μ1, Σ1))) dx
- ∫P(x | μ1, Σ1) * log (Z(Σ2) * exp(S(x, μ2, Σ2))) dx

ここでlog(Z * exp(S))をlog(Z) + Sと展開できる。

= ∫P(x | μ1, Σ1) * (log(Z(Σ1)) + S(x, μ1, Σ1)) dx
- ∫P(x | μ1, Σ1) * (log(Z(Σ2)) + S(x, μ2, Σ2)) dx

さらにP(log(Z) + S) = P*logZ + P*Sとなるので

= ∫P(x | μ1, Σ1) * log (Z(Σ1) dx
+ ∫P(x | μ1, Σ1) * S(x, μ1, Σ1) dx
- ∫P(x | μ1, Σ1) * log (Z(Σ2) dx
- ∫P(x | μ1, Σ1) * S(x, μ2, Σ2) dx

∫ P*log(Z) dxについて、log(Z)はxに依存していないので積分の外に出せてlog(Z) ∫P dxとなる。∫P dxはとりうる全xの確率を足しているので=1になる。よって∫ P*log(Z) dx = log(Z)となるから

= log(Z(Σ1)) - log(Z(Σ2)) 
+ ∫P(x | μ1, Σ1) * S(x, μ1, Σ1) dx
- ∫P(x | μ1, Σ1) * S(x, μ2, Σ2) dx

= log(Z(Σ1) / Z(Σ2)) 
+ ∫P(x | μ1, Σ1) * S(x, μ1, Σ1) dx
- ∫P(x | μ1, Σ1) * S(x, μ2, Σ2) dx


KLダイバージェンスの導出(後編)
だいぶ整理されたので、ここから多変量正規分布P=Z*exp(S)のZとSを展開していく。まずはZから。

log(Z(Σ1) / Z(Σ2))
= log(
((2π)^d * |Σ1|)^(-1/2)
/ ((2π)^d * |Σ2|)^(-1/2)
)

= log( (|Σ1| / |Σ2|)^(-1/2) )
= log( (|Σ2| / |Σ1|)^(1/2) )
= 1/2 * log(|Σ2| / |Σ1|)

これを前編の最後の式に入れるとめでたくZが式から消える。

= log(Z(Σ1) / Z(Σ2)) 
+ ∫P(x | μ1, Σ1) * S(x, μ1, Σ1) dx
- ∫P(x | μ1, Σ1) * S(x, μ2, Σ2) dx

= 1/2 * log(|Σ2| / |Σ1|)
+ ∫P(x | μ1, Σ1) * S(x, μ1, Σ1) dx
- ∫P(x | μ1, Σ1) * S(x, μ2, Σ2) dx

次にSについて。Sについては(1)と(2)で2つの形式があったが、ここでは(2)を使う。

1/2 * log(|Σ2| / |Σ1|)
+ ∫P(x | μ1, Σ1) * S(x, μ1, Σ1) dx
- ∫P(x | μ1, Σ1) * S(x, μ2, Σ2) dx

= 1/2 * log(|Σ2| / |Σ1|)
+ ∫P(x | μ1, Σ1) * -1/2 * Tr(Σ1^-1 (x - μ1)(x - μ1)T) dx
- ∫P(x | μ1, Σ1) * -1/2 * Tr(Σ2^-1 (x - μ2)(x - μ2)T) dx

代入したらxに依存しない部分を積分の外に出す。∫Tr(A x xT) dx = Tr(A ∫ x xT dx)なので

= 1/2 * log(|Σ2| / |Σ1|)
- 1/2 * Tr(Σ1^-1 * ∫(x - μ1)(x - μ1)T * P(x | μ1, Σ1)dx)
+1/2 * Tr(Σ2^-1 * ∫(x - μ2)(x - μ2)T * P(x | μ1, Σ1)dx)

ここで∫ f(x)P(x) dxはf(x)の期待値(E[f(x)])であることを思い出すと、∫ (x-μ)(x-μ)T P(x) dxは(x-μ)(x-μ)Tの期待値を表していることがわかる。
また縦ベクトルxについて行列(x xT)はij成分がx_i * x_jとなる行列なので(x-μ)(x-μ)Tはij成分が(x_i - μ_i) * (x_j - μ_j)となる行列である。この行列の期待値は共分散行列の定義そのものなので=Σとなる。つまり∫ (x-μ)(x-μ)T P(x) dx = Σとなる。これを用いると2番目の項を簡単にできる。

= 1/2 * log(|Σ2| / |Σ1|)
- 1/2 * Tr(Σ1^-1 * Σ1)
+1/2 * Tr(Σ2^-1 * ∫(x - μ2)(x - μ2)T * P(x | μ1, Σ1)dx)

= 1/2 * log(|Σ2| / |Σ1|)
- 1/2 * Tr(I)  # Σ1にΣ1の逆行列を掛けたので単位行列Iになった
+ 1/2 * Tr(Σ2^-1 * ∫(x - μ2)(x - μ2)T * P(x | μ1, Σ1)dx)

= 1/2 * log(|Σ2| / |Σ1|)
- 1/2 * d  # d次単位行列の対角成分の和はdなので
+ 1/2 * Tr(Σ2^-1 * ∫(x - μ2)(x - μ2)T * P(x | μ1, Σ1)dx)

3番目の項についてはμ1とμ2が混在しているので2番目の項のように簡単にすることができない。よって邪魔なμ2を取り除くことを考える。
Tr((x - μ)(x - μ)T) = Tr(x xT) - Tr(2 x μT) + Tr(μ μT)なので

= 1/2 * log(|Σ2| / |Σ1|)
- 1/2 * d
+ 1/2 * Tr(Σ2^-1 * ∫x xT * P(x | μ1, Σ1)dx)
- 1/2 * Tr(Σ2^-1 * ∫2x μ2T * P(x | μ1, Σ1)dx)
+ 1/2 * Tr(Σ2^-1 * ∫μ2 μ2T * P(x | μ1, Σ1)dx)

ここで∫x xT * P dxはx xTの期待値、∫x * P dxはxの期待値=μ、∫P dx=1なので

= 1/2 * log(|Σ2| / |Σ1|)
- 1/2 * d
+ 1/2 * Tr(Σ2^-1 * E[x xT])
- 1/2 * Tr(Σ2^-1 * 2μ1 μ2T)
+ 1/2 * Tr(Σ2^-1 * μ2 μ2T)

またTr(E[x xT]) - Tr(μ μT) = Tr(Σ)なので

= 1/2 * log(|Σ2| / |Σ1|)
- 1/2 * d
+ 1/2 * Tr(Σ2^-1 * E[x xT])
- 1/2 * Tr(Σ2^-1 * μ1 μ1T)
+ 1/2 * Tr(Σ2^-1 * μ1 μ1T)
- 1/2 * Tr(Σ2^-1 * 2μ1 μ2T)
+ 1/2 * Tr(Σ2^-1 * μ2 μ2T)

= 1/2 * log(|Σ2| / |Σ1|)
- 1/2 * d
+ 1/2 * Tr(Σ2^-1 * Σ1)
+ 1/2 * Tr(Σ2^-1 * (μ2 - μ1) (μ2 - μ1)T)

またTr(A x xT) = xT A xだったので

= 1/2 * log(|Σ2| / |Σ1|)
- 1/2 * d
+1/2 * Tr(Σ2^-1 * Σ1)
+1/2 * (μ2 - μ1)T Σ2^-1 (μ2 - μ1)

=(log(|Σ2| / |Σ1|) + Tr(Σ2^-1 * Σ1) + (μ2 - μ1)T Σ2^-1 (μ2 - μ1) - d) / 2

これが多変量正規分布のKLダイバージェンスとなる。これだけの計算を「自明ですよね」の一言ですませる機械学習のプロたちは恐ろしいと言わざるを得ない。怖い。


結論

KL(P(x | μ1, Σ1) || Q(x | μ2, Σ2))
=(log(|Σ2| / |Σ1|) + Tr(Σ2^-1 * Σ1) + (μ2 - μ1)T Σ2^-1 (μ2 - μ1) - d) / 2