Spark DataFrame列をPythonリストに変換する


103

私は、mvvとcountの2つの列を持つデータフレームで作業します。

+---+-----+
|mvv|count|
+---+-----+
| 1 |  5  |
| 2 |  9  |
| 3 |  3  |
| 4 |  1  |

mvv値とカウント値を含む2つのリストを取得したいと思います。何かのようなもの

mvv = [1,2,3,4]
count = [5,9,3,1]

だから、私は次のコードを試してみました:最初の行は行のpythonリストを返すはずです。私は最初の値を見たかった:

mvv_list = mvv_count_df.select('mvv').collect()
firstvalue = mvv_list[0].getInt(0)

しかし、2行目にエラーメッセージが表示されます。

AttributeError:getInt


Spark 2.3以降、このコードは最速であり、OutOfMemory例外を引き起こす可能性は最も低くなりますlist(df.select('mvv').toPandas()['mvv'])ArrowはtoPandas大幅にスピードアップしたPySparkに統合されました。Spark 2.3以降を使用している場合は、他のアプローチを使用しないでください。ベンチマークの詳細については、私の回答を参照してください。
パワーズ

回答:


140

ご覧のように、あなたが行っているこの方法が機能しないのはなぜですか。まず、タイプから整数を取得しようとしています。収集の出力は次のようになります。

>>> mvv_list = mvv_count_df.select('mvv').collect()
>>> mvv_list[0]
Out: Row(mvv=1)

あなたがこのようなものを取るなら:

>>> firstvalue = mvv_list[0].mvv
Out: 1

mvv値を取得します。配列のすべての情報が必要な場合は、次のようにすることができます。

>>> mvv_array = [int(row.mvv) for row in mvv_list.collect()]
>>> mvv_array
Out: [1,2,3,4]

しかし、他の列で同じことを試みると、次のようになります。

>>> mvv_count = [int(row.count) for row in mvv_list.collect()]
Out: TypeError: int() argument must be a string or a number, not 'builtin_function_or_method'

これcountは、が組み込みメソッドであるために発生します。また、列の名前はと同じcountです。これを回避するには、の列名をcountに変更します_count

>>> mvv_list = mvv_list.selectExpr("mvv as mvv", "count as _count")
>>> mvv_count = [int(row._count) for row in mvv_list.collect()]

ただし、ディクショナリ構文を使用して列にアクセスできるため、この回避策は必要ありません。

>>> mvv_array = [int(row['mvv']) for row in mvv_list.collect()]
>>> mvv_count = [int(row['count']) for row in mvv_list.collect()]

そして、それは最終的に動作します!


最初の列ではうまく機能しますが、(sparkの関数カウント)のため、私が考える列カウントでは機能しません
a.moussa

カウントで何をしているのですか?コメントにここに追加してください。
チアゴバルディム2016

あなたの応答に感謝したがって、この行はmvv_count.select( 'mvv')。collect()]でmvv_list = [int(i.mvv)for iが機能しますが、mvv_countでiは機能しませんcount_list = [int(i.count)for i .select( 'count')。collect()]が無効な構文を返す
a.moussa

このselect('count')ようにこの使用法を追加する必要はありませんcount_list = [int(i.count) for i in mvv_list.collect()]。応答に例を追加します。
チアゴバルディム2016

1
@ a.moussa [i.['count'] for i in mvv_list.collect()]は、count関数ではなく「count」という名前の列を使用することを明示的にするように機能します
user989762

103

1つのライナーに続いて、必要なリストを提供します。

mvv = mvv_count_df.select("mvv").rdd.flatMap(lambda x: x).collect()

3
このソリューションのパフォーマンスは、ソリューションよりもはるかに高速です。mvv_list= [int(i.mvv)for i in mvv_count.select( 'mvv')。collect()]
Chanaka Fernando

これは私が見た中で断然最高のソリューションです。ありがとう。
hui chen

22

これにより、すべての要素がリストとして表示されます。

mvv_list = list(
    mvv_count_df.select('mvv').toPandas()['mvv']
)

1
これは、Spark 2.3以降で最も高速で効率的なソリューションです。私の回答でベンチマーク結果をご覧ください。
パワーズ

15

次のコードはあなたを助けます

mvv_count_df.select('mvv').rdd.map(lambda row : row[0]).collect()

3
これは受け入れられる答えになるはずです。その理由は、プロセス全体を通してスパークコンテキストにとどまり、その後、スパークコンテキストから抜け出すのではなく、最後に収集するためです。
AntiPawn79

15

私のデータについて、私はこれらのベンチマークを得ました:

>>> data.select(col).rdd.flatMap(lambda x: x).collect()

0.52秒

>>> [row[col] for row in data.collect()]

0.271秒

>>> list(data.select(col).toPandas()[col])

0.427秒

結果は同じです


1
toLocalIterator代わりに使用するcollectと、メモリ効率がさらに向上するはずです[row[col] for row in data.toLocalIterator()]
oglop

5

以下のエラーが発生した場合:

AttributeError: 'list' object has no attribute 'collect'

このコードはあなたの問題を解決します:

mvv_list = mvv_count_df.select('mvv').collect()

mvv_array = [int(i.mvv) for i in mvv_list]

私もそのエラーが発生し、この解決策で問題が解決しました。しかし、なぜエラーが発生したのですか?(多くの人はそれを
理解し

1

私はベンチマーク分析を行いlist(mvv_count_df.select('mvv').toPandas()['mvv'])、最速の方法です。びっくりしました。

Spark 2.4.5の5ノードi3.xlargeクラスター(各ノードには30.5 GBのRAMと4コア)を使用して、10万/ 1億行のデータセットでさまざまなアプローチを実行しました。データは、単一の列を含む20のSnappy圧縮Parquetファイルに均等に分散されました。

ベンチマーク結果(実行時間(秒))は次のとおりです。

+-------------------------------------------------------------+---------+-------------+
|                          Code                               | 100,000 | 100,000,000 |
+-------------------------------------------------------------+---------+-------------+
| df.select("col_name").rdd.flatMap(lambda x: x).collect()    |     0.4 | 55.3        |
| list(df.select('col_name').toPandas()['col_name'])          |     0.4 | 17.5        |
| df.select('col_name').rdd.map(lambda row : row[0]).collect()|     0.9 | 69          |
| [row[0] for row in df.select('col_name').collect()]         |     1.0 | OOM         |
| [r[0] for r in mid_df.select('col_name').toLocalIterator()] |     1.2 | *           |
+-------------------------------------------------------------+---------+-------------+

* cancelled after 800 seconds

ドライバーノードでデータを収集する際に従うべき黄金律:

  • 他の方法で問題を解決してください。ドライバーノードへのデータ収集はコストが高く、Sparkクラスターの能力を活用しないため、可能な限り回避する必要があります。
  • できるだけ少ない行を収集します。データを収集する前に、列を集計、重複排除、フィルター、および整理します。できるだけ少ないデータをドライバーノードに送信します。

toPandas Spark 2.3では大幅に改善されました。2.3より前のバージョンのSparkを使用している場合、これはおそらく最良の方法ではありません。

詳細とベンチマーク結果については、こちらをご覧ください。


1

考えられる解決策は、のcollect_list()関数を使用することですpyspark.sql.functions。これにより、すべての列の値がpyspark配列に集約され、収集時にpythonリストに変換されます。

mvv_list   = df.select(collect_list("mvv")).collect()[0][0]
count_list = df.select(collect_list("count")).collect()[0][0] 
弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.