(PyTorchを使用して)不均衡なクラスに使用する損失関数は何ですか?


16

次のアイテムを含む3つのクラスのデータセットがあります。

  • クラス1:900要素
  • クラス2:15000要素
  • クラス3:800要素

規範からの重要な逸脱を示すクラス1とクラス3を予測する必要があります。クラス2はデフォルトの「通常の」ケースで、私は気にしません。

ここではどのような損失関数を使用しますか?CrossEntropyLossの使用を考えていましたが、クラスの不均衡があるため、重み付けする必要があると思いますか?実際にはどのように機能しますか?このように(PyTorchを使用)?

summed = 900 + 15000 + 800
weight = torch.tensor([900, 15000, 800]) / summed
crit = nn.CrossEntropyLoss(weight=weight)

または、重量を逆にする必要がありますか?つまり、1 /重量?

これは最初から正しいアプローチですか、それとも私が使用できる他の/より良い方法がありますか?

ありがとう

回答:


11

ここではどのような損失関数を使用しますか?

クロスエントロピーは、平衡または不平衡のいずれかの分類タスクの主要な損失関数です。これは、ドメインの知識からまだ好みが構築されていない場合の最初の選択です。

これは私が思うに重み付けする必要があるでしょうか?実際にはどのように機能しますか?

cc

たとえば、クラス1に900、クラス2に15000、クラス3に800のサンプルがある場合、それらの重みはそれぞれ16.67、1.0、18.75になります。

最小クラスを分母として使用することもできます。これにより、それぞれ0.889、0.053、および1.0が得られます。これは単なる再スケーリングであり、相対的な重みは同じです。

これは最初から正しいアプローチですか、それとも私が使用できる他の/より良い方法がありますか?

はい、これは正しいアプローチです。

編集

@Muppetのおかげで、クラスのオーバーサンプリングを使用することもできます。これは、クラスの重みを使用すること同等です。これはWeightedRandomSampler、前述の同じ重みを使用して、PyTorchで実行されます。


1
他の誰かがこれを見ている場合に備えて、PyTorchのWeightedRandomSamplerを使用することも役立つことを追加したかっただけです。
マペット

0

あなたが言うとき:あなたはまた、それぞれ0.889、0.053、および1.0を与える分母として最小クラスを使用することができます。これは単なる再スケーリングであり、相対的な重みは同じです。

しかし、この解決策はあなたが最初に与えた解決策とは矛盾しています。

弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.