代码之家  ›  专栏  ›  技术社区  ›  Marsellus Wallace

Spark数据帧筛选器在随机情况下未按预期工作

  •  0
  • Marsellus Wallace  · 技术社区  · 5 年前

    这是我的数据帧

    df.groupBy($"label").count.show
    +-----+---------+                                                               
    |label|    count|
    +-----+---------+
    |  0.0|400000000|
    |  1.0| 10000000|
    +-----+---------+
    

    我尝试对label==0.0的记录进行子采样,如下所示:

    val r = scala.util.Random
    val df2 = df.filter($"label" === 1.0 || r.nextDouble > 0.5) // keep 50% of 0.0
    

    我的输出如下所示:

    df2.groupBy($"label").count.show
    +-----+--------+                                                                
    |label|   count|
    +-----+--------+
    |  1.0|10000000|
    +-----+--------+
    
    0 回复  |  直到 5 年前
        1
  •  4
  •   10465355 user11020637    5 年前

    r.nextDouble 是表达式中的常量,因此实际的计算结果与您的意思大不相同。根据实际采样值

    scala> r.setSeed(0)
    
    scala> $"label" === 1.0 || r.nextDouble > 0.5
    res0: org.apache.spark.sql.Column = ((label = 1.0) OR true)
    

    scala> r.setSeed(4096)
    
    scala> $"label" === 1.0 || r.nextDouble > 0.5
    res3: org.apache.spark.sql.Column = ((label = 1.0) OR false)
    

    true
    

    (保存所有记录)或

    label = 1.0 
    

    要生成随机数,您应该使用 corresponding SQL function

    scala> import org.apache.spark.sql.functions.rand
    import org.apache.spark.sql.functions.rand
    
    scala> $"label" === 1.0 || rand > 0.5
    res1: org.apache.spark.sql.Column = ((label = 1.0) OR (rand(3801516599083917286) > 0.5))
    

    尽管Spark已经提供了分层抽样工具:

    df.stat.sampleBy(
      "label",  // column
      Map(0.0 -> 0.5, 1.0 -> 1.0),  // fractions
      42 // seed 
    )