Matlabでの最適なトランスポートワーピングの実装


11

私は「登録とワーピングに最適なマストランスポート」というペーパーを実装しています。私の目標は、オイラーマストランスポートコードをオンラインで見つけることができないため、オンラインにすることです。これは、少なくとも画像処理の研究コミュニティにとって興味深いものです。

この論文は次のように要約できます。
- x座標とy座標に沿った1Dヒストグラムマッチングを使用して初期マップを見つける の固定点を、ここでは反時計回りの90度回転を表し、はディリクレ境界条件(= 0)のポアソン方程式の解を表します。そして、ヤコビ行列の行列です。 -タイムステップ安定性が保証されていますu t = 1u
ut=1μ0Du1div(u)u1Du
dt<min|1μ01div(u)|

数値シミュレーション(通常のグリッドで実行)の場合、ポアソン方程式を解くためにmatlabのpoicalcを使用することを示し、風上スキームを使用して計算されるDuを除いて、空間微分に中心有限差分を使用します。

私のコードを使用すると、エネルギー関数とマッピングのカールは、2、3回の反復(タイムステップに応じて数十から数千)に対して適切に減少しています。しかし、その後、シミュレーションは爆発します。非常に少ない反復でエネルギーが増加し、NANに到達します。私は微分と積分(ここで cumptrapzのより高次の置換がここにあります)といくつかの補間スキームに対していくつかの次数を試してみましが、常に同じ問題が発生します(非常に滑らかな画像、どこでも0でないなど)。
誰かがコードや私が直面している理論上の問題に興味がありますか?コードはかなり短いです。

最後のgradient2()をgradient()に置き換えてください。これはより高次の勾配でしたが、問題も解決しません。

今のところ、紙の最適な輸送部分にのみ関心があり、追加の正則化用語には関心がありません。

よろしくお願いします!

回答:


5

私の親友のPascalが数年前にこれを作成しました(ほとんど Matlabにあります):

#! /usr/bin/env python

#from scipy.interpolate import interpolate
from pylab import *
from numpy import *


def GaussianFilter(sigma,f):
    """Apply Gaussian filter to an image"""
    if sigma > 0:
        n = ceil(4*sigma)
        g = exp(-arange(-n,n+1)**2/(2*sigma**2))
        g = g/g.sum()

        fg = zeros(f.shape)

        for i in range(f.shape[0]):
            fg[i,:] = convolve(f[i,:],g,'same')
        for i in range(f.shape[1]):
            fg[:,i] = convolve(fg[:,i],g,'same')
    else:
        fg = f

    return fg


def clamp(x,xmin,xmax):
    """Clamp values between xmin and xmax"""
    return minimum(maximum(x,xmin),xmax)


def myinterp(f,xi,yi):
    """My bilinear interpolator (scipy's has a segfault)"""
    M,N = f.shape
    ix0 = clamp(floor(xi),0,N-2).astype(int)
    iy0 = clamp(floor(yi),0,M-2).astype(int)
    wx = xi - ix0
    wy = yi - iy0
    return ( (1-wy)*((1-wx)*f[iy0,ix0] + wx*f[iy0,ix0+1]) +
        wy*((1-wx)*f[iy0+1,ix0] + wx*f[iy0+1,ix0+1]) )


def mkwarp(f1,f2,sigma,phi,showplot=0):
    """Image warping by solving the Monge-Kantorovich problem"""
    M,N = f1.shape[:2]

    alpha = 1
    f1 = GaussianFilter(sigma,f1)
    f2 = GaussianFilter(sigma,f2)

    # Shift indices for going from vertices to cell centers
    iUv = arange(M)             # Up
    iDv = arange(1,M+1)         # Down
    iLv = arange(N)             # Left
    iRv = arange(1,N+1)         # Right
    # Shift indices for cell centers (to cell centers)
    iUc = r_[0,arange(M-1)]
    iDc = r_[arange(1,M),M-1]
    iLc = r_[0,arange(N-1)]
    iRc = r_[arange(1,N),N-1]
    # Shifts for going from centers to vertices
    iUi = r_[0,arange(M)]
    iDi = r_[arange(M),M-1]
    iLi = r_[0,arange(N)]
    iRi = r_[arange(N),N-1]


    ### The main gradient descent loop ###      
    for iter in range(0,30):
        ### Approximate derivatives ###
        # Compute gradient phix and phiy at pixel centers.  Array phi has values
        # at the pixel vertices.
        phix = (phi[iUv,:][:,iRv] - phi[iUv,:][:,iLv] + 
            phi[iDv,:][:,iRv] - phi[iDv,:][:,iLv])/2
        phiy = (phi[iDv,:][:,iLv] - phi[iUv,:][:,iLv] + 
            phi[iDv,:][:,iRv] - phi[iUv,:][:,iRv])/2
        # Compute second derivatives at pixel centers using central differences.
        phixx = (phix[:,iRc] - phix[:,iLc])/2
        phixy = (phix[iDc,:] - phix[iUc,:])/2
        phiyy = (phiy[iDc,:] - phiy[iUc,:])/2
        # Hessian determinant
        detD2 = phixx*phiyy - phixy*phixy

        # Interpolate f2 at (phix,phiy) with bilinear interpolation
        f2gphi = myinterp(f2,phix,phiy)

        ### Update phi ###
        # Compute M'(phi) at pixel centers
        dM = alpha*(f1 - f2gphi*detD2)
        # Interpolate to pixel vertices
        phi = phi - (dM[iUi,:][:,iLi] + 
            dM[iDi,:][:,iLi] + 
            dM[iUi,:][:,iRi] + 
            dM[iDi,:][:,iRi])/4


    ### Plot stuff ###      
    if showplot:
        pad = 2
        x,y = meshgrid(arange(N),arange(M))
        x = x[pad:-pad,:][:,pad:-pad]
        y = y[pad:-pad,:][:,pad:-pad]
        phix = phix[pad:-pad,:][:,pad:-pad]
        phiy = phiy[pad:-pad,:][:,pad:-pad]

        # Vector plot of the mapping
        subplot(1,2,1)
        quiver(x,y,flipud(phix-x),-flipud(phiy-y))
        axis('image')
        axis('off')
        title('Mapping')

        # Grayscale plot of mapping divergence
        subplot(1,2,2)  
        divs = phixx + phiyy # Divergence of mapping s(x)
        imshow(divs[pad:-pad,pad:-pad],cmap=cm.gray)
        axis('off')
        title('Divergence of Mapping')
        show()

    return phi


if __name__ == "__main__":  # Demo
    from pylab import *
    from numpy import * 

    f1 = imread('brain-tumor.png')
    f2 = imread('brain-healthy.png')
    f1 = f1[:,:,1]
    f2 = f2[:,:,1]

    # Initialize phi as the identity map
    M,N = f1.shape
    n,m = meshgrid(arange(N+1),arange(M+1))
    phi = ((m-0.5)**2 + (n-0.5)**2)/2

    sigma = 3
    phi = mkwarp(f1,f2,sigma,phi)
    phi = mkwarp(f1,f2,sigma/2,phi,1)
#   phi = mkwarp(f1,f2,sigma/4,phi,1)

テスト実行には約2秒かかります。

ここでは、勾配降下法について説明します:people.clarkson.edu/~ebollt/Papers/quadcost.pdf


すばらしい..どうもありがとう!私はこのコードを試して、私のものと比較してエラーをチェックします。このアプローチは、実際にはHakerらによる論文のローカルバージョンのようです。私が言及したこと-つまり、ラプラシアンを解くことなく。再度、感謝します !
WhitAngl

私は最終的にこのコードでいくつかの問題に遭遇しています...:を計算する場合、ガウスを削除するときでも、(でヘッセ行列)からかなり離れていますぼかし。また、反復回数を数千に増やすだけの場合、コードが爆発してNaN値が表示されます(クラッシュします)。何か案が ?よろしくお願いします!f2(ϕ)detHϕf1H
WhitAngl 2012年

さらにぼかすとNaNの問題が解決しますか?
dranxo

はい、確かに、より多くのテストの後、それはより多くのぼかしを助けます-ありがとう!。私は現在、1ピクセルのstdev(まだ計算中)まで、標準偏差140ピクセルのぼかしを開始する8つのステップを試しています。しかし、最後の結果にはかなりの量の元の画像が残っています(64pxのぼかしを使用)。残っているカールもチェックします。ϕ
WhitAngl 2012年

いいわ。おもう。画像が自然に連続的ではないため(エッジ)、ぼかしがあり、グラデーションが問題になります。うまくいけば、あまりぼやけさせることなく、良い答えを得ることができます。
dranxo
弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.