代码之家  ›  专栏  ›  技术社区  ›  user2811630

基于列标识符比较多列的值

  •  1
  • user2811630  · 技术社区  · 6 年前

    我有两个数据帧。

    第一个看起来像这样(通道的数量将根据类型而变化)此数据帧存储设备的类型和每个通道的值。

    +-----+----------+----------+
    | Type|X_ChannelA|Y_ChannelB|
    +-----+----------+----------+
    |TypeA|        11|        20|
    +-----+----------+----------+
    

    第二个数据帧是从csv导入的,由我生成。

    现在我有这个格式(可以更改为任何需要的格式)

    +-----+--------------+--------------+--------------+--------------+
    | Type|X_ChannelA_min|X_ChannelA_max|Y_ChannelB_min|Y_ChannelB_max|
    +-----+--------------+--------------+--------------+--------------+
    |TypeA|             8|            12|             9|            13|
    +-----+--------------+--------------+--------------+--------------+
    

    现在,我想将实际通道值与最小值和最大值进行比较,并创建一个状态为\u的新列,如果该值介于最小值和最大值之间,则该列包含一个值,如果该值超过最小值或最大值,则该列包含一个零。

    此示例的所需结果

    +-----+----------+----------+-----------------+-----------------+
    | Type|X_ChannelA|Y_ChannelB|X_ChannelA_status|Y_ChannelB_status|
    +-----+----------+----------+-----------------+-----------------+
    |TypeA|        11|        20|                1|                0|
    +-----+----------+----------+-----------------+-----------------+
    

    代码在此处:

        val df_orig = spark.sparkContext.parallelize(Seq(
          ("TypeA", 11, 20)
        )).toDF("Type", "X_ChannelA", "Y_ChannelB")
    
        val df_def = spark.sparkContext.parallelize(Seq(
          ("TypeA", 8, 12, 9, 13)
        )).toDF("Type", "X_ChannelA_min", "X_ChannelA_max", "Y_ChannelB_min", "Y_ChannelB_max")
    

    我已经尝试了一些不同的事情,但都没有成功。

    类似于通过获取所有通道的字符串数组来创建列,然后使用

    val pattern = """[XYZP]_Channel.*"""
    val fieldNames = df_orig.schema.fieldNames.filter(_.matches(pattern))
    fieldNames.foreach(x => df.withColumn(s"${x}_status", <compare logic comes here>)
    

    我的下一个想法是将df\u orig与df\u def连接起来,然后将channel\u value、channel\u min、channel\u max与concat\u ws添加到单个列中,用比较逻辑比较这些值,并将结果写入列中

    +-----+----------+----------+----------------+----------------+-------------+...
    | Type|X_ChannelA|Y_ChannelB|X_ChannelA_array|Y_ChannelB_array|X_ChannelA_st|
    +-----+----------+----------+----------------+----------------+-------------+...
    |TypeA|        11|        20|     [11, 8, 12]|     [20, 9, 13]|            1|
    +-----+----------+----------+----------------+----------------+-------------+...
    

    如果有一个更简单的解决方案,最好是朝着正确的方向推进。

    编辑:如果我的描述不清楚,基本上我要找的是: 我要找的是

    foreach channel in channellist (
        ds.withColumn("<channel>_status", when($"<channel>" < $"<channel>_min" || $"<channel>" > $"<channel>_max"), 1).otherwise 0)
    )
    

    编辑:我找到了一个解决方案:

    val df_joined = df_orig.join(df_def, Seq("Type"))
    val pattern = """[XYZP]_Channel.*"""
    val fieldNames = df_orig.schema.fieldNames.filter(_.matches(pattern))
    val df_newnew = df_joined.select(col("*") +: (fieldNames.map(c => when(col(c) <= col(c+"_min") || col(c) >= col(c+"_max"), 1).otherwise(0).as(c+"_status))): _*)
    
    1 回复  |  直到 5 年前
        1
  •  2
  •   Ramesh Maharjan    6 年前

    join 就是要走的路。你必须利用 when 功能适当,如下所示

    import org.apache.spark.sql.functions._
    df_orig.join(df_def, Seq("Type"), "left")
      .withColumn("X_ChannelA_status", when(col("X_ChannelA") >= col("X_ChannelA_min") && col("X_ChannelA") <= col("X_ChannelA_max"), 1).otherwise(0))
      .withColumn("Y_ChannelB_status", when(col("Y_ChannelB") >= col("Y_ChannelB_min") && col("Y_ChannelB") <= col("Y_ChannelB_max"), 1).otherwise(0))
      .select("Type", "X_ChannelA", "Y_ChannelB", "X_ChannelA_status", "Y_ChannelB_status")
    

    您应该得到所需的输出

    +-----+----------+----------+-----------------+-----------------+
    |Type |X_ChannelA|Y_ChannelB|X_ChannelA_status|Y_ChannelB_status|
    +-----+----------+----------+-----------------+-----------------+
    |TypeA|11        |20        |1                |0                |
    +-----+----------+----------+-----------------+-----------------+
    

    已更新

    如果您的 通道数据帧 如果您不想像上面提到的那样硬编码所有列,那么您可以从 foldLeft (scala中的强大功能)

    但在此之前,您必须决定要迭代的列(即通道)

    val df_orig_Columns = df_orig.columns
    val columnsToIterate = df_orig_Columns.toSet - "Type"
    

    然后在你之后 参加 他们,使用 foldLeft公司 概括以上内容 withColumn 过程

    val joinedDF = df_orig.join(df_def, Seq("Type"), "left")
    
    import org.apache.spark.sql.functions._
    val finalDF = columnsToIterate.foldLeft(joinedDF){(tempDF, colName) => tempDF.withColumn(colName+"_status", when(col(colName) >= col(colName+"_min") && col(colName) <= col(colName+"_max"), 1).otherwise(0))}
    

    最后你 select 必要的列为

    val finalDFcolumns = df_orig_Columns ++ columnsToIterate.map(_+"_status")
    finalDF.select(finalDFcolumns.map(col): _*)
    

    我想就是这样。希望它不仅有用