Pytorch:unetアーキテクチャでカスタムウェイトマップを使用する正しい方法


8

u-netアーキテクチャには、カスタムのウェイトマップを使用して精度を高めるための有名なトリックがあります。詳細は次のとおりです。

ここに画像の説明を入力してください

さて、ここや他の複数の場所で質問することで、2つのアプローチについて知ることができます。

1)最初はtorch.nn.Functionalトレーニングループでメソッドを使用することです-

loss = torch.nn.functional.cross_entropy(output, target, w) ここで、wは計算されたカスタムの重みです。

2)2つ目はreduction='none'、トレーニングループ外の損失関数の呼び出しで使用することです。 criterion = torch.nn.CrossEntropy(reduction='none')

そして、トレーニングループでカスタムウェイトを掛けます-

gt # Ground truth, format torch.long
pd # Network output
W # per-element weighting based on the distance map from UNet
loss = criterion(pd, gt)
loss = W*loss # Ensure that weights are scaled appropriately
loss = torch.sum(loss.flatten(start_dim=1), axis=0) # Sums the loss per image
loss = torch.mean(loss) # Average across a batch

さて、どちらが正しいのか、他に方法があるのか​​、どちらが正しいのか、ちょっと混乱しています。

回答:


3

重み付け部分は、クラスの数(以下の例では2)に対してこのように実行される単純に重み付けされたクロスエントロピーのように見えます。

weights = torch.FloatTensor([.3, .7])
loss_func = nn.CrossEntropyLoss(weight=weights)

編集:

パトリック・ブラックからこの実装を見たことがありますか?

# Set properties
batch_size = 10
out_channels = 2
W = 10
H = 10

# Initialize logits etc. with random
logits = torch.FloatTensor(batch_size, out_channels, H, W).normal_()
target = torch.LongTensor(batch_size, H, W).random_(0, out_channels)
weights = torch.FloatTensor(batch_size, 1, H, W).random_(1, 3)

# Calculate log probabilities
logp = F.log_softmax(logits)

# Gather log probabilities with respect to target
logp = logp.gather(1, target.view(batch_size, 1, H, W))

# Multiply with weights
weighted_logp = (logp * weights).view(batch_size, -1)

# Rescale so that loss is in approx. same interval
weighted_loss = weighted_logp.sum(1) / weights.view(batch_size, -1).sum(1)

# Average over mini-batch
weighted_loss = -1. * weighted_loss.mean()

ここで重要なのは、重みは特定の関数によってここで計算され、目立たないということです。詳細については、ここに論文があります-arxiv.org/abs/1505.04597
Mark

1
@Markああ、私は今見ます。したがって、ピクセル単位の損失出力です。そして、境界は何らかのライブラリなどを使用して事前に計算されopencv、それらのピクセル位置は各画像に対して保存され、後でトレーニング中に損失テンソルが乗算されるため、アルゴリズムはこれらの領域の損失の低減に焦点を合わせます。
jchaykow

ありがとうございます。この合法的なものは回答のように見えます。検証と実装をさらに試み、その後回答を受け入れます。
マーク

あなたはこのラインの後ろの直観説明することができますlogp = logp.gather(1, target.view(batch_size, 1, H, W))
マーク・

0

torch.nn.CrossEntropyLoss()はtorch.nn.functionalを呼び出すクラスであることに注意してください。https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html#CrossEntropyLossを参照してください

基準を定義するときに重みを使用できます。それらを機能的に比較すると、どちらの方法も同じです。

さて、メソッド1のトレーニングループの内側とメソッド2のトレーニングループの外側の損失を計算するという考えは理解できません。ループの外側で損失を計算する場合、どのように逆伝播しますか?


私が使用しての間で混乱はなかったtorch.nn.CrossEntropyLoss() torch.nn.functional.cross_entropy(output, target, w)- 、私はこの論文を参照loss.Pleaseにマップするカスタムの重みを使用する方法を混同してarxiv.org/abs/1505.04597を、あなたはまだ私は何を把握することができないなら、私は、知っています質問
マーク

1
正しく理解できれば、方法2が正しいと思います。損失torch.nn.functional.cross_entropy(output、target、w)内の重み(w)は、式のw(x)ではないクラスの重みです。小さなスクリプトで簡単にテストできます。
Devansh Bisla

ええ、私も同じ結論に達しています。ネットワークが期待どおりに動作し、回答が承認された場合は、元に戻ります。
マーク

大丈夫、そのないworking.Iは取得していますgrad can be implicitly created only for scalar outputs、私は方法ワット損失=損失*を実行したときに
マーク・

それらを合計しているのか、それとも平均値を取っているのですか?
Devansh Bisla
弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.