簡単な方法でこれを行うことができます。1つ目は、コーディングが簡単で、理解しやすく、かなり高速です。2番目の方法は少しトリッキーですが、このサイズの問題では、最初の方法やここで述べた他のアプローチよりもはるかに効率的です。
方法1:すばやく汚れています。
各行の確率分布から単一の観測値を取得するには、次のようにするだけです。
# Q is the cumulative distribution of each row.
Q <- t(apply(P,1,cumsum))
# Get a sample with one observation from the distribution of each row.
X <- rowSums(runif(N) > Q) + 1
これにより、各行の累積分布が生成され、各分布から1つの観測値がサンプリングされます。を再利用できる場合は、一度計算して、後で使用するために保存できることに注意してください。ただし、質問には、反復ごとに異なるに対して機能するものが必要です。P Q PP PQP
各行から複数()の観測が必要な場合は、最後の行を次の行に置き換えます。n
# Returns an N x n matrix
X <- replicate(n, rowSums(runif(N) > Q)+1)
これは一般的にこれを実行するための非常に効率的な方法ではありませんR
が、通常は実行速度の主要な決定要因であるベクトル化機能を十分に活用しています。理解するのも簡単です。
方法2:cdfを連結します。
2つのベクトルを取る関数があり、2番目のベクトルが単調非減少順にソートされ、最初の各要素の最大下限の2番目のベクトルのインデックスが見つかったとします。次に、この関数と巧妙なトリックを使用できます。すべての行の累積分布関数の累積合計を作成するだけです。これにより、範囲の要素を持つ単調増加するベクトルが得られます。[0,N]
これがコードです。
i <- 0:(N-1)
# Cumulative function of the cdfs of each row of P.
Q <- cumsum(t(P))
# Find the interval and then back adjust
findInterval(runif(N)+i, Q)-i*K+1
最後の行が何をしているのかに注意してください、それは分布するランダム変数を作成し、次に呼び出して各エントリの最大の下限のインデックスを見つけます。したがって、これは、の最初の要素がインデックス1とインデックスにあり、2番目の要素はインデックスと間にあることを示しています。それぞれ、対応する行の分布に従っています。次に、各インデックスをの範囲に戻すために、逆変換を行う必要があります。K K + 1 2 K P { 1 、... 、K }(0,1),(1,2),…,(N−1,N)findInterval
runif(N)+i
KK+12KP{1,…,K}
findInterval
アルゴリズム的にも実装的にも高速であるため、この方法は非常に効率的です。
ベンチマーク
私の古いラップトップ(MacBook Pro、2.66 GHz、8 GB RAM)で、およびこれを試し、更新された質問で提案されたとおり、サイズ 5000サンプルを生成し、合計5,000万のランダム変量。K = 100 NN=10000K=100N
方法1のコードの実行にはほぼ正確に15分、つまり1秒あたり約55Kのランダム変量がかかりました。方法2のコードの実行には約4分30分、つまり毎秒約183Kのランダム変量がかかりました。
ここに再現性のためのコードがあります。(コメントに示されているように、OPの状況をシミュレートするために、5000回の反復ごとにが再計算されることに注意してください。)Q
# Benchmark code
N <- 10000
K <- 100
set.seed(17)
P <- matrix(runif(N*K),N,K)
P <- P / rowSums(P)
method.one <- function(P)
{
Q <- t(apply(P,1,cumsum))
X <- rowSums(runif(nrow(P)) > Q) + 1
}
method.two <- function(P)
{
n <- nrow(P)
i <- 0:(n-1)
Q <- cumsum(t(P))
findInterval(runif(n)+i, Q)-i*ncol(P)+1
}
これが出力です。
# Method 1: Timing
> system.time(replicate(5e3, method.one(P)))
user system elapsed
691.693 195.812 899.246
# Method 2: Timing
> system.time(replicate(5e3, method.two(P)))
user system elapsed
182.325 82.430 273.021
追記:のコードを見ると、エントリがあるかどうか、または2番目の引数がソートされfindInterval
ていないかどうかを確認するために、入力のチェックが行われていることがわかりNA
ます。したがって、これからより多くのパフォーマンスを引き出したい場合は、findInterval
これらのチェックを取り除いた独自の変更バージョンを作成できます。