代码之家  ›  专栏  ›  技术社区  ›  user2811630

基于条件的多列之和

  •  0
  • user2811630  · 技术社区  · 6 年前

    我找到了一种方法,但当我有20多列要总结时,这似乎不是一个好主意,因为它会为每一列生成一个额外的列。

    想要得到的结果是:以“_val”结尾的所有列的值之和,其中值为0或1(或<2,我现在只想排除值3)

     val df1 = Seq(
      ("id1", 1, 0, 3),
      ("id2", 0, 0, 3),
      ("id3", 1, 1, 3))
      .toDF("id", "bla_val", "blub_val", "bli_val")
    

    在sum列中包含所需结果的我的解决方案

    val channelNames = df1.schema.fieldNames.filter(_.endsWith("_val"))
    val ch = channelNames.map(x => col(x+"_redval"))
    
    val df2 = df1.select(col("*") +: (channelNames.map(c =>
      when(col(c) === 1, lit(1))
        .otherwise(lit(0)).as(c+"_redval"))): _*) 
    
    val df3 = df2.withColumn("sum", ch.reduce(_+_))
    df3.show()
    

    示例输出:

    +---+-------+--------+-------+--------------+---------------+--------------+---+ | id|bla_val|blub_val|bli_val|bla_val_redval|blub_val_redval|bli_val_redval|sum| +---+-------+--------+-------+--------------+---------------+--------------+---+ |id1| 1| 0| 3| 1| 0| 0| 1| |id2| 0| 0| 3| 0| 0| 0| 0| |id3| 1| 1| 3| 1| 1| 0| 2| +---+-------+--------+-------+--------------+---------------+--------------+---+

    2 回复  |  直到 6 年前
        1
  •  1
  •   stack0114106    6 年前

    val df1 = Seq(
      ("id1", 1, 0, 3),
      ("id2", 0, 0, 3),
      ("id3", 1, 1, 3))
      .toDF("id", "bla_val", "blub_val", "bli_val")
    
    val newcols= df1.columns.filter(_.endsWith("_val")).map( x=> when(col(x)===1, lit(1)).otherwise(lit(0))).reduce(_+_)
    df1.withColumn("redval_count",newcols).show(false)
    

    输出:

    +---+-------+--------+-------+------------+
    |id |bla_val|blub_val|bli_val|redval_count|
    +---+-------+--------+-------+------------+
    |id1|1      |0       |3      |1           |
    |id2|0      |0       |3      |0           |
    |id3|1      |1       |3      |2           |
    +---+-------+--------+-------+------------+
    
        2
  •  -1
  •   Terry Dactyl    6 年前
    def sumNot3(s: Seq[Int]): Int = {
      s.filter(_ != 3).sum
    }
    
    val sumNot3Udf = udf(sumNot3(_: Seq[Int]))
    
    val channelNameCols = df1.schema.fieldNames.filter(_.endsWith("_val")).map(c => col(c))
    
    df1.select(sumNot3Udf(array(channelNameCols: _*)).as("sum"))