perlでLDA(Latent Dirichlet Allocation)を書いてみた

感覚を掴むためにperlでLDAを書いてみた。130行くらい。あくまで練習なので効率のよさとかは考えてない。とりあえず動いたよ的な。実装はBlei論文に書かれている変分ベイズ版を使った。

Latent Dirichlet Allocation D.M.Blei et al, 2003

なおディガンマ関数の実装は持橋さんのものを参考にさせていただいた。感謝。(問題ありそうだったら消します。)

http://chasen.org/~daiti-m/dist/lda/

#!/usr/lcoa/bin/perl
use strict;
use warnings;
use Data::Dumper;

my $filename = shift @ARGV;
my $v        = shift @ARGV;
my $k        = shift @ARGV;
my $t        = shift @ARGV;
my @beta;
# initialize
for (my $ki = 0; $ki < $k; $ki++) {
  my $Z = 0;
  for (my $vi = 0; $vi < $v; $vi++) {
    $beta[$ki][$vi] = rand();
    $Z += $beta[$ki][$vi];
  }
  for (my $vi = 0; $vi < $v; $vi++) {
    $beta[$ki][$vi] /= $Z;
  }
}

# parameter estimation
for (my $ti = 0; $ti < $t; $ti++) {
  my @beta_next;

  open(my $F, "<$filename") or die "$filename cannot open.\n";
  while (<$F>) {
    chomp;
    my @a = split(/,/);

    my @doc;
    my $N = 0;
    foreach (@a) {
      my ($w, $f) = split(/:/);
      push(@doc, {word => $w, freq => $f});
      $N += $f;
    }
    my $phi_ref = get_phi(\@doc, $N, \@beta, $k, $t);

    my $n = scalar(@doc);
    for (my $ni = 0; $ni < $n; $ni++) {
      my $word = $doc[$ni]{word};
      my $freq = $doc[$ni]{freq};
      for (my $ki = 0; $ki < $k; $ki++) {
        $beta_next[$ki][$word] += ($phi_ref->[$ki][$ni] * $freq);
      }
    }
  }
  close($F);

  for (my $ki = 0; $ki < $k; $ki++) {
    my $Z = 0;
    for (my $vi = 0; $vi < $v; $vi++) {
      $Z += $beta_next[$ki][$vi];
    }
    for (my $vi = 0; $vi < $v; $vi++) {
      $beta[$ki][$vi] = $beta_next[$ki][$vi] / $Z;
    }
  }
}
print "[beta]\n";
print Dumper(\@beta);


sub get_phi {
  my $doc_ref  = shift;
  my $N        = shift;
  my $beta_ref = shift;
  my $k        = shift;
  my $t        = shift;

  # initialize
  my $n = scalar(@$doc_ref);
  my @phi;
  my @gamma;
  for (my $ki = 0; $ki < $k; $ki++) {
    $gamma[$ki] = 1 + $N / $k;
    for (my $ni = 0; $ni < $n; $ni++) {
      $phi[$ki][$ni] = 1 / $k;
    }
  }

  # variational inference
  for (my $ti = 0; $ti < $t; $ti++) {
    my @gamma_next;
    for (my $ni = 0; $ni < $n; $ni++) {
      my $word = $doc_ref->[$ni]{word};
      my $freq = $doc_ref->[$ni]{freq};
      my $Z = 0;
      for (my $ki = 0; $ki < $k; $ki++) {
        $phi[$ki][$ni] = $beta_ref->[$ki][$word]
                         * exp(digamma($gamma[$ki]));
        $Z += $phi[$ki][$ni];
      }
      for (my $ki = 0; $ki < $k; $ki++) {
        $phi[$ki][$ni] /= $Z;
        $gamma_next[$ki] += ($phi[$ki][$ni] * $freq);
      }
    }
    for (my $ki = 0; $ki < $k; $ki++) {
      $gamma[$ki] = $gamma_next[$ki] + 1;
    }
  }
  return \@phi;
}

sub digamma {
  my $x = shift;

  my $s3 = 1.0 / 12;
  my $s4 = 1.0 / 120;
  my $s5 = 1.0 / 252;
  my $s6 = 1.0 / 240;
  my $s7 = 1.0 / 132;

  my $y = 0;
  while($x < 12) {
    $y -= (1 / $x);
    $x++;
  }
  my $r = 1 / $x;
  $y += (log($x) - (0.5 * $r));
  $r *= $r;
  $y -= $r * ($s3 - $r * ($s4 - $r * ($s5 - $r * ($s6 - $r * $s7))));
  return $y;
}

で早速動かしてみる。入力ファイルは1行1文書で「単語ID:頻度,...」となっている。

$$ cat input.txt
0:5,1:4,2:1
0:1,1:2,2:5
0:4,1:5,2:2

実行する。引数は入力ファイル、単語の語彙数、トピック数、イテレーション数。

$$ ./lda.pl input.txt 3 2 10
[beta]
$VAR1 = [
          [
            '0.122957106393658',
            '0.318122459019145',
            '0.558920434587196'
          ],
          [
            '0.498310181601847',
            '0.421638082742583',
            '0.0800517356555694'
          ]
        ];

ひとつめのトピックは単語2が出やすく、ふたつめのトピックは単語0と単語1がでやすいことがわかる。