
How do I convert the pandas DataFrames returned by RDD.mapPartitions() into a Spark DataFrame?

  •  3
  • snark  · 7 years ago

    I have a Python function that returns a pandas DataFrame. I call this function from pyspark on Spark 2.2.0 via RDD.mapPartitions(). However, I cannot convert the RDD returned by mapPartitions() into a Spark DataFrame; pandas raises this error:

    ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
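
    The error comes from pandas rather than Spark: a pandas DataFrame always refuses to be evaluated in a boolean context. A plausible reading (my interpretation, not stated in the original post) is that spark.createDataFrame() boolean-tests the first RDD element while inferring a schema, and that element is now a whole pandas DataFrame. The same error can be reproduced without Spark:

    import pandas as pd

    pdf = pd.DataFrame({"A": [1.0, 2.0]})
    if pdf:  # raises ValueError: The truth value of a DataFrame is ambiguous...
        pass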
    

    A minimal example that reproduces the problem:

    import pandas as pd
    
    def func(data):
        pdf = pd.DataFrame(list(data), columns=("A", "B", "C"))
        pdf += 10 # Add 10 to every value. The real function is a lot more complex!
        return [pdf]
    
    pdf = pd.DataFrame([(1.87, 0.6, 7.1), (-0.3, 0.1, 8.2), (2.8, 0.3, 6.1), (-0.2, 0.5, 5.9)], columns=("A", "B", "C"))
    
    sdf = spark.createDataFrame(pdf)
    sdf.show()
    rddIn = sdf.rdd
    
    for i in rddIn.collect():
        print(i)
    
    result = rddIn.mapPartitions(func)
    
    for i in result.collect():
        print(i)
    
    resDf = spark.createDataFrame(result) # --> ValueError!
    resDf.show()
    

    The output is:

    +----+---+---+
    |   A|  B|  C|
    +----+---+---+
    |1.87|0.6|7.1|
    |-0.3|0.1|8.2|
    | 2.8|0.3|6.1|
    |-0.2|0.5|5.9|
    +----+---+---+
    Row(A=1.87, B=0.6, C=7.1)
    Row(A=-0.3, B=0.1, C=8.2)
    Row(A=2.8, B=0.3, C=6.1)
    Row(A=-0.2, B=0.5, C=5.9)
           A     B     C
    0  11.87  10.6  17.1
         A     B     C
    0  9.7  10.1  18.2
          A     B     C
    0  12.8  10.3  16.1
         A     B     C
    0  9.8  10.5  15.9
    

    But the second-to-last line raises the ValueError shown above. I really want resDf.show() to look exactly like sdf.show(), except with 10 added to every value in the table. Ideally, the result RDD should also have the same structure as rddIn, the RDD that goes into mapPartitions().

    1 Answer  |  7 years ago
        1
  •  5
  •   Giannis · 6 years ago

    You have to convert the data to standard Python types and flatten the per-partition DataFrames:

    resDf = spark.createDataFrame(
        result.flatMap(lambda df: (r.tolist() for r in df.to_records()))
    )
    
    resDf.show()
    # +---+------------------+----+----+                                              
    # | _1|                _2|  _3|  _4|
    # +---+------------------+----+----+
    # |  0|11.870000000000001|10.6|17.1|
    # |  0|               9.7|10.1|18.2|
    # |  0|              12.8|10.3|16.1|
    # |  0|               9.8|10.5|15.9|
    # +---+------------------+----+----+
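
    The leading _1 column of zeros is the pandas index, which to_records() keeps by default, and the generic _1.._4 names appear because no schema is passed. If you want the output to keep the original A, B, C column names, a small variation (my sketch, assuming each per-partition frame has the same column order as sdf) is to drop the index and reuse the input schema:

    resDf = spark.createDataFrame(
        result.flatMap(lambda df: (r.tolist() for r in df.to_records(index=False))),
        schema=sdf.schema  # reuse the original column names and types
    )
    resDf.show()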
    

    If you are on Spark 2.3, the following should also work:

    from pyspark.sql.functions import pandas_udf, spark_partition_id, PandasUDFType
    
    @pandas_udf(sdf.schema, functionType=PandasUDFType.GROUPED_MAP)  
    def func(pdf):
        pdf += 10 
        return pdf
    
    sdf.groupBy(spark_partition_id().alias("_pid")).apply(func)
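
    For reference, a slightly more explicit variant of the same idea (my sketch, not part of the original answer, assuming the sdf defined in the question): materialise the partition id as a real column, include it in the UDF's declared schema, and drop it again after the apply.

    from pyspark.sql.functions import pandas_udf, spark_partition_id, PandasUDFType

    # Make the partition id an ordinary column so the grouping is explicit
    sdf_pid = sdf.withColumn("_pid", spark_partition_id())

    @pandas_udf(sdf_pid.schema, functionType=PandasUDFType.GROUPED_MAP)
    def add_ten(pdf):
        # pdf holds one partition's rows; bump only the data columns, not the helper key
        pdf[["A", "B", "C"]] = pdf[["A", "B", "C"]] + 10
        return pdf

    resDf = sdf_pid.groupBy("_pid").apply(add_ten).drop("_pid")
    resDf.show()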