u-netアーキテクチャには、カスタムのウェイトマップを使用して精度を高めるための有名なトリックがあります。詳細は次のとおりです。
さて、ここや他の複数の場所で質問することで、2つのアプローチについて知ることができます。
1)最初はtorch.nn.Functional
トレーニングループでメソッドを使用することです-
loss = torch.nn.functional.cross_entropy(output, target, w)
ここで、wは計算されたカスタムの重みです。
2)2つ目はreduction='none'
、トレーニングループ外の損失関数の呼び出しで使用することです。
criterion = torch.nn.CrossEntropy(reduction='none')
そして、トレーニングループでカスタムウェイトを掛けます-
gt # Ground truth, format torch.long
pd # Network output
W # per-element weighting based on the distance map from UNet
loss = criterion(pd, gt)
loss = W*loss # Ensure that weights are scaled appropriately
loss = torch.sum(loss.flatten(start_dim=1), axis=0) # Sums the loss per image
loss = torch.mean(loss) # Average across a batch
さて、どちらが正しいのか、他に方法があるのか、どちらが正しいのか、ちょっと混乱しています。