代码之家  ›  专栏  ›  技术社区  ›  Masterbuilder

用一个hotecoder表示的Spark结构

  •  0
  • Masterbuilder  · 技术社区  · 6 年前

    我有一个有两列的数据框,

    +---+-------+
    | id|  fruit|
    +---+-------+
    |  0|  apple|
    |  1| banana|
    |  2|coconut|
    |  1| banana|
    |  2|coconut|
    +---+-------+
    

    而且我有一个包含所有物品的通用清单,

    fruitList: Seq[String] = WrappedArray(apple, coconut, banana)
    

    现在,我想在数据框中创建一个新列,该数组包含1个、0个数组,其中1个表示存在的项,如果该项不存在该行,则表示0个。

    期望输出

        +---+-----------+
        | id|  fruitlist|
        +---+-----------+
        |  0|  [1,0,0]  |
        |  1| [0,1,0]   |
        |  2|[0,0,1]    |
        |  1| [0,1,0]   |
        |  2|[0,0,1]    |
        +---+-----------+
    

    这是我试过的,

    import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
    
    val df = spark.createDataFrame(Seq(
      (0, "apple"),
      (1, "banana"),
      (2, "coconut"),
      (1, "banana"),
      (2, "coconut")
    )).toDF("id", "fruit")
    
    df.show
    import org.apache.spark.sql.functions._
    val fruitList = df.select(collect_set("fruit")).first().getAs[Seq[String]](0)
    print(fruitList)
    

    我试图用一个hotecoder来解决这个问题,但是在转换成稠密向量之后,结果是这样的,这不是我所需要的。

        +---+-------+----------+-------------+---------+
    | id|  fruit|fruitIndex|     fruitVec|       vd|
    +---+-------+----------+-------------+---------+
    |  0|  apple|       2.0|    (2,[],[])|[0.0,0.0]|
    |  1| banana|       1.0|(2,[1],[1.0])|[0.0,1.0]|
    |  2|coconut|       0.0|(2,[0],[1.0])|[1.0,0.0]|
    |  1| banana|       1.0|(2,[1],[1.0])|[0.0,1.0]|
    |  2|coconut|       0.0|(2,[0],[1.0])|[1.0,0.0]|
    +---+-------+----------+-------------+---------+
    
    1 回复  |  直到 6 年前
        1
  •  3
  •   Ramesh Maharjan    6 年前

    如果你的收藏是

    val fruitList: Seq[String] = Array("apple", "coconut", "banana")
    

    那你要么用 内置函数 自定义项函数

    内置函数(数组、when和lit)

    import org.apache.spark.sql.functions._
    df.withColumn("fruitList", array(fruitList.map(x => when(lit(x) === col("fruit"),1).otherwise(0)): _*)).show(false)
    

    自定义项函数

    import org.apache.spark.sql.functions._
    def containedUdf = udf((fruit: String) => fruitList.map(x => if(x == fruit) 1 else 0))
    
    df.withColumn("fruitList", containedUdf(col("fruit"))).show(false)
    

    它应该给你

    +---+-------+---------+
    |id |fruit  |fruitList|
    +---+-------+---------+
    |0  |apple  |[1, 0, 0]|
    |1  |banana |[0, 0, 1]|
    |2  |coconut|[0, 1, 0]|
    |1  |banana |[0, 0, 1]|
    |2  |coconut|[0, 1, 0]|
    +---+-------+---------+
    

    udf函数很容易理解和直接处理原始数据类型,但是如果可以使用优化的快速内置函数来执行相同的任务,则应该避免使用udf函数。

    我希望答案对你有帮助