畳み込みはどのように行列乗算(行列形式)として表現できますか?


11

この質問はプログラミングにはあまり関係がないかもしれませんが、画像処理の背後にある理論を理解しなければ、実際に何かを実装することはできません。

ガウスフィルターは、ピクセルの近傍の加重平均を計算し、エッジ検出に非常に役立ちます。これは、ぼかしを適用して画像を同時に導出できるためです。単にガウス関数の導関数とたたみ込みます。

しかし、誰かが私を説明したり、それらがどのように計算されたかについていくつかの参照を私に与えたりできますか?

たとえば、Cannyのエッジ検出器は5x5ガウスフィルターについて話しますが、それらはどのようにしてこれらの特定の数値を取得しましたか?そして、それらはどのようにして、継続的な畳み込みから行列の乗算に移行しましたか?



画像コンボリューション用の行列を生成するための完全なコードを含む回答追加しました
Royi

回答:


3

この操作を機能させるには、画像がベクトルとして再形成されることを想像する必要があります。次に、このベクトルの左側に畳み込み行列を掛けて、ぼやけた画像を取得します。結果は、入力と同じサイズのベクトル、つまり同じサイズの画像でもあることに注意してください。

畳み込み行列の各行は、入力画像の1ピクセルに対応します。これには、対象となるピクセルのぼやけた対応物に対する、画像内の他のすべてのピクセルの寄与の重みが含まれます。

例を挙げましょう:サイズ6 × 6ピクセルの画像上のサイズピクセルのボックスぼかし。再形成された画像は36の列からなり、ぼかし行列のサイズは36 × 36です。3×36×636×36

  • この行列をどこでも0に初期化しましょう。
  • (i,j)1/9(i1,j1);(i1,j),(i1,j+1),,(i+1,j+1)
  • (i,j)6i+j1/9(6i+j)
  • 他のすべてのピクセルでも同じようにします。

密接に関連するプロセス(畳み込み+減算)の視覚的な説明は、このブログ投稿(私の個人ブログから)にあります。


リンクが死んでいる。
gauteh

2

画像または畳み込みネットワークへのアプリケーションの場合、最新のGPUで行列乗算器をより効率的に使用するために、入力は通常、一度に複数のフィルター/カーネルで乗算できるアクティブ化行列の列に再形成されます。

詳細については、スタンフォードのCS231nからこのリンクを確認し、「行列乗算としての実装」のセクションまでスクロールしてください。

このプロセスは、入力イメージまたはアクティブ化マップ上のすべてのローカルパッチ(カーネルで乗算されるパッチ)を取得し、一般にim2colと呼ばれる操作によって新しい行列Xの列にそれらを引き伸ばすことによって機能します。カーネルはまた、重み行列Wの行を埋めるように引き伸ばされるため、行列演算W * Xを実行すると、結果の行列Yには畳み込みのすべての結果が含まれます。最後に、一般にcal2imと呼ばれる操作によって列を画像に変換することにより、Yマトリックスを再度整形する必要があります。


1
これはとても良いリンクです、ありがとう!ただし、リンクからの重要な抽出を回答に追加することをお勧めします。これにより、リンクが壊れた場合でも回答が有効になります。回答を編集するには、回答を編集することを検討してください。
マッテオ

1

時間領域でのたたみ込みは、周波数領域での行列乗算に等しく、その逆も同様です。

フィルタリングは、時間領域でのたたみ込み、つまり周波数領域での行列乗算に相当します。

5x5マップまたはマスクに関しては、それらはcanny / sobelオペレーターを離散化することから得られます。


2
フィルタリングが周波数領域でのたたみ込みであるという事実には同意しません。ここで説明する種類のフィルターは、空間領域でのたたみ込み(つまり、周波数領域でのフィルター応答による要素ごとの乗算)です。
ピシェネット2013年

私が書いたものを訂正してくれてありがとう。その後の編集を行いました。投稿する前に自分の回答を再確認する必要があると思います。しかし、私の答えの大部分はまだ残っています。
Naresh 2013年

フーリエ変換は実際に畳み込みを乗算に変換します(逆も同様です)。ただし、問題は画像の形状を変更することで得られる行列とベクトルの乗算についてですが、これらは正確な乗算です。
sansuiso 2013年

canny / sobel演算子で5x5行列が得られる理由は、演算子の離散化がいかにかについて言及しました。
Naresh

1

StackOverflow Q2080835 GitHubリポジトリでこれを解決する関数を書きました(を見てくださいCreateImageConvMtx())。
実際、この関数は、任意の畳み込み形状をサポートできます- fullsameおよびvalid

コードは次のとおりです。

function [ mK ] = CreateImageConvMtx( mH, numRows, numCols, convShape )

CONVOLUTION_SHAPE_FULL  = 1;
CONVOLUTION_SHAPE_SAME  = 2;
CONVOLUTION_SHAPE_VALID = 3;

switch(convShape)
    case(CONVOLUTION_SHAPE_FULL)
        % Code for the 'full' case
        convShapeString = 'full';
    case(CONVOLUTION_SHAPE_SAME)
        % Code for the 'same' case
        convShapeString = 'same';
    case(CONVOLUTION_SHAPE_VALID)
        % Code for the 'valid' case
        convShapeString = 'valid';
end

mImpulse = zeros(numRows, numCols);

for ii = numel(mImpulse):-1:1
    mImpulse(ii)    = 1; %<! Create impulse image corresponding to i-th output matrix column
    mTmp            = sparse(conv2(mImpulse, mH, convShapeString)); %<! The impulse response
    cColumn{ii}     = mTmp(:);
    mImpulse(ii)    = 0;
end

mK = cell2mat(cColumn);


end

画像フィルタリング用の行列を作成する関数も作成しました(MATLABと同様の考え方imfilter()):

function [ mK ] = CreateImageFilterMtx( mH, numRows, numCols, operationMode, boundaryMode )
%UNTITLED6 Summary of this function goes here
%   Detailed explanation goes here

OPERATION_MODE_CONVOLUTION = 1;
OPERATION_MODE_CORRELATION = 2;

BOUNDARY_MODE_ZEROS         = 1;
BOUNDARY_MODE_SYMMETRIC     = 2;
BOUNDARY_MODE_REPLICATE     = 3;
BOUNDARY_MODE_CIRCULAR      = 4;

switch(operationMode)
    case(OPERATION_MODE_CONVOLUTION)
        mH = mH(end:-1:1, end:-1:1);
    case(OPERATION_MODE_CORRELATION)
        % mH = mH; %<! Default Code is correlation
end

switch(boundaryMode)
    case(BOUNDARY_MODE_ZEROS)
        mK = CreateConvMtxZeros(mH, numRows, numCols);
    case(BOUNDARY_MODE_SYMMETRIC)
        mK = CreateConvMtxSymmetric(mH, numRows, numCols);
    case(BOUNDARY_MODE_REPLICATE)
        mK = CreateConvMtxReplicate(mH, numRows, numCols);
    case(BOUNDARY_MODE_CIRCULAR)
        mK = CreateConvMtxCircular(mH, numRows, numCols);
end


end


function [ mK ] = CreateConvMtxZeros( mH, numRows, numCols )
%UNTITLED6 Summary of this function goes here
%   Detailed explanation goes here

numElementsImage    = numRows * numCols;
numRowsKernel       = size(mH, 1);
numColsKernel       = size(mH, 2);
numElementsKernel   = numRowsKernel * numColsKernel;

vRows = reshape(repmat(1:numElementsImage, numElementsKernel, 1), numElementsImage * numElementsKernel, 1);
vCols = zeros(numElementsImage * numElementsKernel, 1);
vVals = zeros(numElementsImage * numElementsKernel, 1);

kernelRadiusV = floor(numRowsKernel / 2);
kernelRadiusH = floor(numColsKernel / 2);

pxIdx       = 0;
elmntIdx    = 0;

for jj = 1:numCols
    for ii = 1:numRows
        pxIdx = pxIdx + 1;
        for ll = -kernelRadiusH:kernelRadiusH
            for kk = -kernelRadiusV:kernelRadiusV
                elmntIdx = elmntIdx + 1;

                pxShift = (ll * numCols) + kk;

                if((ii + kk <= numRows) && (ii + kk >= 1) && (jj + ll <= numCols) && (jj + ll >= 1))
                    vCols(elmntIdx) = pxIdx + pxShift;
                    vVals(elmntIdx) = mH(kk + kernelRadiusV + 1, ll + kernelRadiusH + 1);
                else
                    vCols(elmntIdx) = pxIdx;
                    vVals(elmntIdx) = 0; % See the accumulation property of 'sparse()'.
                end
            end
        end
    end
end

mK = sparse(vRows, vCols, vVals, numElementsImage, numElementsImage);


end


function [ mK ] = CreateConvMtxSymmetric( mH, numRows, numCols )
%UNTITLED6 Summary of this function goes here
%   Detailed explanation goes here

numElementsImage    = numRows * numCols;
numRowsKernel       = size(mH, 1);
numColsKernel       = size(mH, 2);
numElementsKernel   = numRowsKernel * numColsKernel;

vRows = reshape(repmat(1:numElementsImage, numElementsKernel, 1), numElementsImage * numElementsKernel, 1);
vCols = zeros(numElementsImage * numElementsKernel, 1);
vVals = zeros(numElementsImage * numElementsKernel, 1);

kernelRadiusV = floor(numRowsKernel / 2);
kernelRadiusH = floor(numColsKernel / 2);

pxIdx       = 0;
elmntIdx    = 0;

for jj = 1:numCols
    for ii = 1:numRows
        pxIdx = pxIdx + 1;
        for ll = -kernelRadiusH:kernelRadiusH
            for kk = -kernelRadiusV:kernelRadiusV
                elmntIdx = elmntIdx + 1;

                pxShift = (ll * numCols) + kk;

                if(ii + kk > numRows)
                    pxShift = pxShift - (2 * (ii + kk - numRows) - 1);
                end

                if(ii + kk < 1)
                    pxShift = pxShift + (2 * (1 -(ii + kk)) - 1);
                end

                if(jj + ll > numCols)
                    pxShift = pxShift - ((2 * (jj + ll - numCols) - 1) * numCols);
                end

                if(jj + ll < 1)
                    pxShift = pxShift + ((2 * (1 - (jj + ll)) - 1) * numCols);
                end

                vCols(elmntIdx) = pxIdx + pxShift;
                vVals(elmntIdx) = mH(kk + kernelRadiusV + 1, ll + kernelRadiusH + 1);

            end
        end
    end
end

mK = sparse(vRows, vCols, vVals, numElementsImage, numElementsImage);


end


function [ mK ] = CreateConvMtxReplicate( mH, numRows, numCols )
%UNTITLED6 Summary of this function goes here
%   Detailed explanation goes here

numElementsImage    = numRows * numCols;
numRowsKernel       = size(mH, 1);
numColsKernel       = size(mH, 2);
numElementsKernel   = numRowsKernel * numColsKernel;

vRows = reshape(repmat(1:numElementsImage, numElementsKernel, 1), numElementsImage * numElementsKernel, 1);
vCols = zeros(numElementsImage * numElementsKernel, 1);
vVals = zeros(numElementsImage * numElementsKernel, 1);

kernelRadiusV = floor(numRowsKernel / 2);
kernelRadiusH = floor(numColsKernel / 2);

pxIdx       = 0;
elmntIdx    = 0;

for jj = 1:numCols
    for ii = 1:numRows
        pxIdx = pxIdx + 1;
        for ll = -kernelRadiusH:kernelRadiusH
            for kk = -kernelRadiusV:kernelRadiusV
                elmntIdx = elmntIdx + 1;

                pxShift = (ll * numCols) + kk;

                if(ii + kk > numRows)
                    pxShift = pxShift - (ii + kk - numRows);
                end

                if(ii + kk < 1)
                    pxShift = pxShift + (1 -(ii + kk));
                end

                if(jj + ll > numCols)
                    pxShift = pxShift - ((jj + ll - numCols) * numCols);
                end

                if(jj + ll < 1)
                    pxShift = pxShift + ((1 - (jj + ll)) * numCols);
                end

                vCols(elmntIdx) = pxIdx + pxShift;
                vVals(elmntIdx) = mH(kk + kernelRadiusV + 1, ll + kernelRadiusH + 1);

            end
        end
    end
end

mK = sparse(vRows, vCols, vVals, numElementsImage, numElementsImage);


end


function [ mK ] = CreateConvMtxCircular( mH, numRows, numCols )
%UNTITLED6 Summary of this function goes here
%   Detailed explanation goes here

numElementsImage    = numRows * numCols;
numRowsKernel       = size(mH, 1);
numColsKernel       = size(mH, 2);
numElementsKernel   = numRowsKernel * numColsKernel;

vRows = reshape(repmat(1:numElementsImage, numElementsKernel, 1), numElementsImage * numElementsKernel, 1);
vCols = zeros(numElementsImage * numElementsKernel, 1);
vVals = zeros(numElementsImage * numElementsKernel, 1);

kernelRadiusV = floor(numRowsKernel / 2);
kernelRadiusH = floor(numColsKernel / 2);

pxIdx       = 0;
elmntIdx    = 0;

for jj = 1:numCols
    for ii = 1:numRows
        pxIdx = pxIdx + 1;
        for ll = -kernelRadiusH:kernelRadiusH
            for kk = -kernelRadiusV:kernelRadiusV
                elmntIdx = elmntIdx + 1;

                pxShift = (ll * numCols) + kk;

                if(ii + kk > numRows)
                    pxShift = pxShift - numRows;
                end

                if(ii + kk < 1)
                    pxShift = pxShift + numRows;
                end

                if(jj + ll > numCols)
                    pxShift = pxShift - (numCols * numCols);
                end

                if(jj + ll < 1)
                    pxShift = pxShift + (numCols * numCols);
                end

                vCols(elmntIdx) = pxIdx + pxShift;
                vVals(elmntIdx) = mH(kk + kernelRadiusV + 1, ll + kernelRadiusH + 1);

            end
        end
    end
end

mK = sparse(vRows, vCols, vVals, numElementsImage, numElementsImage);


end

コードはMATLABに対して検証されましたimfilter()

StackOverflow Q2080835 GitHubリポジトリで完全なコードを入手できます

弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.