sklearn、3クラス分類のランダムフォレストの適切なOobスコアは何ですか?[重複]


8

約45,000のサンプルで構成される学習データがあり、それぞれ21の機能があります。3つのクラス(-1、0、1)のラベルが付けられたこのデータでランダムフォレスト分類器をトレーニングしようとしています。クラスのサイズはほぼ同じです。

私のランダムフォレスト分類子モデルはgini、その分割品質基準として使用しています。木の数は10であり、木の深さを制限していません。

ほとんどの機能は無視できるほどの重要性を示しています。平均は約5%、それらの3分の1は重要度0、それらの3分の1は平均より上に重要です。

ただし、おそらく最も印象的な事実は、oob(out-of-bag)スコア(1%未満)です。それはモデルが失敗したと私に思わせました、そして実際に、サイズ〜40kの新しい独立したセットでモデルをテストしたところ、63%(これまでのところ良い音)のスコアを得ましたが、混同行列をより詳しく調べると、モデルはクラス0でのみ成功し、1と-1の間で決定する場合、約50%のケースで失敗します。

添付されたPythonの出力:

array([[ 7732,   185,  6259],
       [  390, 11506,   256],
       [ 7442,   161,  6378]])

これは当然のことですが、0クラスには予測をはるかに容易にする特別なプロパティがあるためです。しかし、私が見つけたOobスコアがすでにモデルが良くない兆候であるというのは本当ですか?ランダムフォレストのOobスコアはいくつですか?モデルが「良好」であるか、oobスコアのみを使用するか、またはモデルの他の結果と組み合わせて使用​​するかを決定するのに役立つ経験則はありますか?


編集:不正なデータ(データの約3分の1)を削除した後、ラベルは0の場合は2%程度、-1 / + 1の場合は49%でした。oobスコアは0.011で、テストデータのスコアは0.49であり、混同行列はクラス1(予測の約3/4)にほとんど偏っていません。


2
明確にするために。scikit learnを使用していますか?そしてそれはoobスコア<.001を報告していますか?次に、.63を取得する新しいデータで.score関数を使用しますか?一般に、クロス検証スコアを反映するか、わずかに過小評価するOOBスコアを見つけました。scikit学習分類のスコアはクラス全体の平均精度であると思います(ドキュメントを正しく読んでいる場合)ので、全体/平均以外の精度と直接比較すべきではありませんが、これは実装に依存し、これを引き起こしてはなりません。大きな不一致。
ライアンブレスラー2014

はい、scikit学習を使用しています。oobスコアは0.01より少し低く、テストデータのスコアは約.63でした。
バッハ、

行は独立していますか、それとも同じケース(またはその他の階層/クラスターデータ)を繰り返し測定しましたか?また、明確にしてください:あなたのoobの「スコア」は、エラーの測定値ですか、それとも合意の測定値ですか?
cbeleitesはSXに不満2014

行が繰り返されていませんが、依存している可能性があります。私は信じているscikits「はoob_score、合意の尺度でスコア、です。しかし、文書化されていませんでした。
バッハ

簡単な検索でランダムフォレストのmanページが表示され、「oob_score:boolアウトオブバッグサンプルを使用して汎化エラーを推定するかどうか」と表示されているため、これはエラーメジャーのように見えます。これが本当である場合、あなたのoob推定は非常に過度に楽観的です-これは依存する行の予期される「症状」です。
cbeleitesはSXに不満2014

回答:


4

sklearnのoob_score_ドキュメントとソースコードを読んだ後、sklearnのRF (末尾のアンダースコアに注意してください)は、Rに比べて非常にわかりにくいです。モデルを改善する方法についての私のアドバイスは次のとおりです。

  1. sklearnのRFは、max_features=1(「すべてのノードですべての機能を試す」のように)のひどいデフォルトを使用していました。次に、ランダムフォレストのようなランダムな列(/機能)の選択を行わなくなりました。これを例えばmax_features=0.33(Rのようにmtry)に変更して再実行します。新しいスコアを教えてください。

  2. 「ほとんどの機能は無視できるほどの重要性しか示していない」。次に、分類に従って、ドキュメントに従って機能の選択を行う必要があります。こちらのCrossValidated.SEのドキュメントおよびその他の記事を参照してください。例を使用して、残りのトレーニングとは異なる(たとえば20〜30%)ホールドアウトセットでFSを実行しsklearn.cross_validation.train_test_split()ます(はい、名前は少し誤解を招きます)。では、FS後に得点を教えてください。

  3. あなたは、前記「(データの約3分の1)不正なデータを除去した後、標識は多かれ少なかれ0 2%の-1 / + 1の各49%でした」。その後、あなたは深刻なクラスの不均衡を持っています。また、「混同行列は、モデルがクラス0でのみ成功し、+ 1と-1の間のケースの約50%で失敗することを示しています。これは、クラスの不均衡の症状です。層別サンプリングを使用するか、+ 1および-1クラスの例を使用して分類子をトレーニングします。OAA(One-Against-All)またはOAO(One-Against-One)分類子を実行できます。各クラスに1つずつ、3つのOAA分類子を試してください。最後にそれらのスコアを教えてください。


6
参考までに、scikit 0.16.1では、max_featuresのデフォルトは「auto」であり、「auto」がsqrt(number_features)に変換される1ではありません。
firefly2442 2015年

1

良いoob_scoreのようなものはありません、それはvalid_scoreとoob_scoreの違いが重要です。

oob_scoreをトレーニングセットの一部のサブセット(たとえば、oob_set)のスコアと考えてください。作成方法については、こちらを参照してください

oob_setはトレーニングセットから取得されます。そして、あなたはすでにあなたの検証セット(例えば、valid_set)を持っています。

validation_scoreが0.7365でoob_scoreが0.8329であるシナリオを想定してみましょう

このシナリオでは、トレーニングデータセットから直接取得したoob_setでモデルのパフォーマンスが向上しています。示すように、validation_setは別の期間のものです。(たとえば、training_setには「January」の月のレコードがあり、validation_setには「July」の月のレコードがあります)。したがって、モデルのパフォーマンスのテストではなく、oob_scoreは「Validation_setの代表性」のテストです。

モデルのパフォーマンスの指標としてスコアが使用されるため、適切な代表的なvalidation_setがあることを常に確認する必要があります。したがって、あなたの目標は、oob_scoreとvalid_scoreの違いをできるだけ少なくすることです。

私は通常、validation_scoreとoob_scoreを使用して、validation_setがどの程度優れているかを確認します。このテクニックはジェレミーハワードから学びました。


0

Q: sklearn、3クラス分類のランダムフォレストの適切なOobスコアは何ですか?

A:依存します。私の見解では、学習とテストのサンプルが同じ分布から抽出される場合、私の見解では、OOBは約3倍の相互検証に相当します。したがって、同じ質問を繰り返しますが、「3倍の相互検証」を使用すると、答えは同じになります。実際のテストサンプルは、別のディストリビューションのものです。」

データセットをくれませんか?私はそれを少し楽しんで、無料で何ができるかを教えてくれます。


0

質問に対する別の見方:まず第一に、あなたはあなたが行うすべての誤分類に損失を関連付ける必要があります。誤分類に対するこの支払い済み/損失/ペナルティは、(おそらく)False Positive(FP)とFalse Negatives(FN)で異なります。一部の分類では、がんの検出は、FNよりもFPの方が多いとされています。スパムフィルターと呼ばれる他のいくつかは、友達からのメール(FP)をブロックするよりも、特定のスパム(FN)を許可します。このロジックに基づいて、目的に合ったF1スコアまたは精度を使用できます(たとえば、スパムフィルターにFPがなく、スコアが.1であれば、心配するスパムが10%少ないので、私は満足できるでしょう。 。一方、他の誰かは.9でも不満を感じる可能性があります(スパムの90%がフィルター処理されます。その場合、良いスコアは何でしょうか?)

弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.