NMFについてまとめ

ことの発端はこれ→NMFを勉強中 - シコウサクゴ()で、知ってる人に教えてもらったりしてやっと理解できたので、まとめ。今日はちゃんと書く・・・つもり。ベースは[1]の論文だが、使っている記号は全然それに則ってないのであしからず。

NMFとは

NMF(Non-negative Matrix Factorization: 非負値行列分解)では、ある非負行列Vを2つの行列の積WHに分解する。つまり、

V \approx WH

となるような、W, Hを決めることである。どの世界でどのように使われているかはよく分からないけど、恐らく画像処理、音声処理、そして自然言語処理で使われているはず。一般に、Wは特徴的なパターンを表し、HWのパターンの出現回数と見ることができる。

コスト関数

どうやって求めるかというと、コスト関数を最小化するいわゆる最適化の方法で求める。コスト関数は、二乗誤差バージョンとダイバージェンスバージョンがあって、二乗誤差バージョンは
 J(W,H) = \sum_{i,j} [ v_{ij} - \sum_k (w_{ik} h_{kj}) ]^2 = \sum_{i,j} [v_{ij}^2 - 2v_{ij}\sum_k (w_{ik} h_{kj}) + (\sum_k w_{ik} h_{kj})^2]
 v_{ij}^2の項は観測Vのみに依存する定数なので最適化に直接関係ないから
 J(W,H) = \sum_{i,j} [- 2v_{ij}\sum_k (w_{ik} h_{kj}) + (\sum_k w_{ik} h_{kj})^2]
と書き変えられる。また、ダイバージェンスバージョンは
 J(V||WH) = v_{ij}\log\frac{v_{ij}}{\sum_k w_{ik}h_{kj}} - v_{ij} + \sum_k w_{ik} h_{kj}
である。どちらを使うかはタスクによるとのこと。二乗誤差の場合は誤差平面が対称で、ダイバージェンスは非対称なので、例えば自然言語処理でトピック分類として使う場合はダイバージェンスを使ったほうが直感的に合っていていいらしい。ここでは二乗誤差の場合のみを考えていく。ダイバージェンスはまた今度やるかも。

高速な解法:補助関数法

直接J(W,H)を最小化することはやめて、補助関数A(W,H,R)を導入する。突然出てきたRという変数は、J(W,H)A(W,H,R)の間でイエンセン(Jensen)の不等式が成り立つように決めている。つまり、
 J(W,H) \leq A(W,H,R)
 J(W,H) = \min_R A(W,H,R)
を満たすようにRを決める。もし、この2つの条件を満たすようにRが決められれば、以下の2つのステップを繰り返すことで誤差を最小にすることができる。

  1. A(W,H,R)Rについて最小化
  2. A(W,H,R)W,Hについて最小化

補助関数の導出

それでA(W,H,R)をどうやって決めるのか。上で出てきた目的関数
 J(W,H) = \sum_{i,j} [- 2v_{ij}\sum_k (w_{ik} h_{kj}) + (\sum_k w_{ik} h_{kj})^2]
のi,jの総和についての第二項
(\sum_k w_{ik} h_{kj})^2
に着目すると、イエンセンの不等式から
(\sum_k w_{ik} h_{kj})^2 \leq \sum_k \frac{(w_{ik}h_{kj})^2}{r_{ijk}}
をつくることが出来る*1。これを使って、
A(W,H,R) = \sum_{i,j} [- 2v_{ij}\sum_k (w_{ik} h_{kj}) + \sum_k \frac{(w_{ik}h_{kj})^2}{r_{ijk}} ]
とする。

A(W,H,R)Rについて最小化

Rについて最小化することは、(\sum_k w_{ik} h_{kj})^2 \leq \sum_k \frac{(w_{ik}h_{kj})^2}{r_{ijk}}で等式になるようなRを求めることと等しい。一般に、イエンセンの不等式で等号が成り立つ場合というのは、
 f(\sum_k p_k x_k) = \sum_k p_k f(x_k) ただし、 p_k \geq 0, \sum_k p_k = 1で、x_1=\cdots=x_K
なので、A(W,H,R)の場合では、
 \frac{w_{i1}h_{1j}}{r_{ij1}} = \cdots = \frac{w_{iK}h_{Kj}}{r_{ijK}} = const だから、 r_{ijk} = \frac{w_{ik}h_{kj}}{const}
 \sum_k r_{ijk} = 1という条件を使って、
 \sum_k r_{ijk} = \sum_k \frac{w_{ik}h_{kj}}{const} = 1 \Longleftrightarrow const = \sum_k w_{ik} h_{kj}となり、r_{ijk} = \frac{w_{ik} h_{kj}}{\sum_k w_{ik} h_{kj}}の点で最小となる。

A(W,H,R)W,Hについて最小化

A(W,H,R)w_{ik}h_{kj}偏微分して0と置いて解いてみると、
\frac{\partial A}{\partial w_{ik}} = \sum_j [ -2v_{ij}h_{kj} + 2\frac{(w_{ik}h_{kj})h_{kj}}{r_{ijk}} ] = 0 \Longrightarrow w_{ik} = \frac{\sum_j v_{ij} h_{kj}}{\sum_j \frac{h^2_{kj}}{r_{ijk}}}
\frac{\partial A}{\partial h_{kj}} = \sum_i [ -2v_{ij}w_{ik} + 2\frac{(w_{ik}h_{kj})w_{ik}}{r_{ijk}} ] = 0 \Longrightarrow h_{kj} = \frac{\sum_i v_{ij} w_{ik}}{\sum_i \frac{w^2_{ik}}{r_{ijk}}}

以上をひとつにまとめて、NMFアルゴリズムの完成

上の(2)で求めた w_{ik} h_{kj} r_{ijk}の部分に(1)で求めたr_{ijk}を代入してまとめれば更新式の完成。
 t_{ik} \leftarrow t_{ik}\frac{\sum_j v_{ij} h_{kj}}{\sum_j\sum_l(w_{il}h_{lj})h_{kj}}     h_{kj} \leftarrow h_{kj}\frac{\sum_i v_{ij} w_{ik}}{\sum_i \sum_l(w_{il} h_{lj})w_{ik} }
この更新式のことを[1]では、"multiplicative update rules"と呼んでいる。この更新式をくるくる回して、二乗誤差の値の変化がとても小さくなったら終了。

参考文献

[1] Lee, D. D., & Seung, H. S. (2001). Algorithms for non-negative matrix factorization. Advances in neural information processing systems, 13(1), V-621-V-624.

*1:だいぶ省いた。。。