「正月にMetropolis Procedural Modelingを基礎から理解してみる」その3。
あらかじめp(x)の形が分かってるんじゃ詰まらないので、 もうちょい答えが自明でないような問題をやってみよう。 なんか適当な2次元の点の集合が与えられた時に、フィットするカーブを見つける。 たとえば入力はこんなの。
今度のコードはmcmc3.scm。
今回のサンプル空間は、前回のような(x,y)ではなくて、 フィットすべき多項式の係数。1次式なら ax+b = y の [a,b]、 2次式なら ax^2+bx+c = y の [a,b,c] ってな具合。
例えば1次式でフィッティングするとして、[a,b]の分布ってのはどう考えるべきだろう。 実は、今回は「最もフィットするサンプルを見つける」ことが目的なんで、全体の分布の 形はあまり重要ではなくて、単に最もフィットする[a,b]が分布上でも極大になっていることさえ 保証できればいい。
そこで、pには誤差が正規分布になるとして得られる確率を使うことにしよう。Dは入力の 点群、θがサンプルである (前回までxと言ってたもの)。 異なる次数の多項式に対応できるようにしている。
;; Sample points D = {(x,y)} is given in a format ((x0 . y0) (x1 . y1) ...) ;; Within an area 0 <= x <= 10 and 0 <= y <= 10 ;; Calculate likelihood p(D|θ). We use gaussian distribution. (define (p D θ) (/ (exp (- (/ (err D (polynomial θ)) (* 2 4)))) 2)) ;σ = 2 (define (err D f) (fold (^[d s] (+ s (expt (- (cdr d) (f (car d))) 2))) 0 D)) (define (polynomial θ) (match θ [(a b) (^x (+ (* a x) (* 5 b)))] [(a b c) (^x (+ (* a x x) (* 5 b x) (* 5 5 c)))] [(a b c d) (^x (+ (* a x x x) (* 5 b x x) (* 5 5 c x) (* 5 5 5 d)))] [(a b c d e) (^x (+ (* a x x x x) (* 5 b x x x) (* 5 5 c x x) (* 5 5 5 d x) (* 5 5 5 5 e)))] [(a b c d e f) (^x (+ (* a x x x x x) (* 5 b x x x x) (* 5 5 c x x x) (* 5 5 5 d x x) (* 5 5 5 5 e x) (* 5 5 5 5 5 f)))]))
polynomial
で 5 という係数が入ってるのは、xの範囲を 0<=x<=10にしてるんで、
その期待値。こうしておくと各パラメータを動かす範囲をだいたい同じにできる。
この係数を入れないと、例えば5次式ではaをわずかに動かすだけでグラフが
大きく変動するのに、fの方は大きく動かさないと影響が見えない、なんてことになる。
(と、試行錯誤の経験からの後付け。最初にxを正規化しておいた方が綺麗だったかも。)
今回のマルコフ過程も、各パラメータを現在の位置から正規分布で動かす。 ただ、今回は離散的である必要がないので、連続関数からのサンプリングを行う cpickを使う。cpickの定義はmcmc-util.scmに。
(define-constant σ 0.002) ;; choose kernel ;; Pick a sample θ* in n-dimensional space (define (pickN θ) (map (^t (cpick t σ)) θ))
Metropolis-Hastingsのステップはこれまでとほぼ同じ。 ただ、誤差の変化を見たいので、stepは更新された状態とともにその誤差を返す。
;; Calculate posteriori probability q(θ*|θ) (define (q θ* θ) (apply * (map (^[t* t] (/ (φ (/ (- t* t) σ)) σ)) θ* θ))) ;; Single MH step. Returns next θ and error^2 value. (define (step D θ) (let* ([u (random-real)] [θ* (pickN θ)] [θ1 (if (< u (min 1 (/ (* (p D θ*) (q θ* θ)) (* (p D θ) (q θ θ*))))) θ* θ)] [e (err D (polynomial θ1))]) (values θ1 e)))
で、これがドライバ。最もエラーが少なかった状態をθ-bestに保存。
(define (run D θ0 N) (receive (θ e) (step D θ0) (let loop ([i 0] [θ θ] [θ-best θ] [e-best e]) (if (= i N) (values θ-best e-best) (receive (θ1 e1) (step D θ) (print i "\t" e1) (if (< e1 e-best) (loop (+ i 1) θ1 θ1 e1) (loop (+ i 1) θ1 θ-best e-best))))))) (define (mcmc3 θ0 :key (samples 50000) (datafile "mcmc3.input") (errfile "/dev/null")) (let1 D (map (cut apply cons <>) (slices (file->sexp-list datafile) 2)) (with-output-to-file errfile (^[] (run D θ0 samples)))))
ではちょっと走らせてみる。まずは1次式。これはわりと収束が速いので、 10000サンプルくらい走らせてみる。
gosh> (mcmc3 '(0 0) :samples 10000 :errfile "mcmc3.err") (0.6350394180428015 0.34573476504861994) 95.0749613122097 gosh> (plot'result '(0.6350394180428015 0.34573476504861994)) #<undef>
なんとなくフィットしてますな。誤差の遷移はこのとおり。
gosh> (plot'error :xmin 0 :ymax 500) #<undef>
1700サンプルくらいで最適値の近くに落ち着いて、後はその周囲をうろつきながら少し良くなればラッキー、という感じか。
2次式いってみる。入力はいかにも3次式っぽいのだが、MHは果たしてどのような答えを出すか。
gosh> (mcmc3 '(0 0 0) :samples 10000 :errfile "mcmc3.err") (0.004929486265444075 0.1183081453945932 0.07132720430159654) 94.93203226951134 gosh> (plot'result '(0.004929486265444075 0.1183081453945932 0.07132720430159654)) #<undef> gosh> (plot'error :xmin 0 :ymax 500) #<undef>
あーなるほど。中心をずらして直線に近い部分を当ててきた。面白いことに 収束は1次式よりかなり速い。
3次式。これも結構収束は速いんだけど、状態の遷移によりセンシティブな感じで、 誤差のギザギザが多い。ちょっと違う方向に外れてからアプローチしなおす、みたいなことを 繰り返してるような。長く走らせてるともうちょい誤差の少ないのを見つけることも できる (実験中は27くらいまで見たかな。)
gosh> (mcmc3 '(0 0 0 0) :samples 10000 :errfile "mcmc3.err") (0.04019606976415797 -0.12295913378494573 0.12721511338017866 -0.0036730173096711726) 30.917123842960937 gosh> (plot'result '(0.04019606976415797 -0.12295913378494573 0.12721511338017866 -0.0036730173096711726)) #<undef> gosh> (plot'error :xmin 0 :ymax 500) #<undef>
では4次式だとどうなるか。とりあえず50kサンプル走らせて誤差を見る。
gosh> (mcmc3 '(0 0 0 0 0) :samples 50000 :errfile "mcmc3.err") (0.0028206266808124903 -0.0037032990192696134 -0.00803881338184116 0.016731836398429796 1.8249063733267665e-4) 33.72910050098417 gosh> (plot'error :xmin 0 :ymax 500) #<undef>
これまでと大部違うグラフになった。一ヶ所に留まることが多いのは、 遷移候補が却下されることが格段に多くなったってことだ。 それだけ動くのが難しくなってるってことで、十分に空間が探索されてない可能性がある。 ちょっとサンプルを増やしてみよう。
gosh> (mcmc3 '(0 0 0 0 0) :samples 200000 :errfile "mcmc3.err") (0.0034482639171228313 -0.006011645024975547 -0.005733385212459884 0.016593561731897424 -1.7863372331037718e-4) 34.527811540078076 gosh> (plot'result '(0.0034482639171228313 -0.006011645024975547 -0.005733385212459884 0.016593561731897424 -1.7863372331037718e-4)) #<undef> gosh> (plot'error :xmin 0 :ymax 500) #<undef>
んー、とりたてて改善されなかったなあ。見た目ではそれなりにフィットしてるけど。 今のコードではaが0になることを禁止してないので、実は4次式でもa=0にして 3次式に落とした方がよくフィットする。その解を見つけるには至らなかったようだ。
では5次式。かなーり動きづらくなってるので、500kサンプルいってみる。 コーヒーでも淹れて結果を待とう。
gosh> (mcmc3 '(0 0 0 0 0 0) :samples 500000 :errfile "mcmc3.err") (-0.001570251751036906 0.007335991498906258 -0.009942979882494226 0.0017548736436514207 0.004426355897388467 -3.672877738783138e-4) 33.86391421988126 gosh> (plot'result '(-0.001570251751036906 0.007335991498906258 -0.009942979882494226 0.0017548736436514207 0.004426355897388467 -3.672877738783138e-4)) #<undef> gosh> (plot'error :xmin 0 :ymax 500) #<undef>
お、誤差33。わりとよく頑張った。誤差履歴を見ると38万サンプル目くらいでこの解を見つけた様子。
なんてことがわかった。
さて、今回は「次数がnならどうなるか」って具合に実験を進めたが、 もしあらかじめ最適な次数がわかっていない場合はどうしたらいいだろう。
と、そこでRJMCMCの出番となるのであった。→ Gauche:MetropolisProceduralModeling:RJMCMC-fit