我有两个数据帧。
第一个看起来像这样(通道的数量将根据类型而变化)此数据帧存储设备的类型和每个通道的值。
+-----+----------+----------+
| Type|X_ChannelA|Y_ChannelB|
+-----+----------+----------+
|TypeA| 11| 20|
+-----+----------+----------+
第二个数据帧是从csv导入的,由我生成。
现在我有这个格式(可以更改为任何需要的格式)
+-----+--------------+--------------+--------------+--------------+
| Type|X_ChannelA_min|X_ChannelA_max|Y_ChannelB_min|Y_ChannelB_max|
+-----+--------------+--------------+--------------+--------------+
|TypeA| 8| 12| 9| 13|
+-----+--------------+--------------+--------------+--------------+
现在,我想将实际通道值与最小值和最大值进行比较,并创建一个状态为\u的新列,如果该值介于最小值和最大值之间,则该列包含一个值,如果该值超过最小值或最大值,则该列包含一个零。
此示例的所需结果
+-----+----------+----------+-----------------+-----------------+
| Type|X_ChannelA|Y_ChannelB|X_ChannelA_status|Y_ChannelB_status|
+-----+----------+----------+-----------------+-----------------+
|TypeA| 11| 20| 1| 0|
+-----+----------+----------+-----------------+-----------------+
代码在此处:
val df_orig = spark.sparkContext.parallelize(Seq(
("TypeA", 11, 20)
)).toDF("Type", "X_ChannelA", "Y_ChannelB")
val df_def = spark.sparkContext.parallelize(Seq(
("TypeA", 8, 12, 9, 13)
)).toDF("Type", "X_ChannelA_min", "X_ChannelA_max", "Y_ChannelB_min", "Y_ChannelB_max")
我已经尝试了一些不同的事情,但都没有成功。
类似于通过获取所有通道的字符串数组来创建列,然后使用
val pattern = """[XYZP]_Channel.*"""
val fieldNames = df_orig.schema.fieldNames.filter(_.matches(pattern))
fieldNames.foreach(x => df.withColumn(s"${x}_status", <compare logic comes here>)
我的下一个想法是将df\u orig与df\u def连接起来,然后将channel\u value、channel\u min、channel\u max与concat\u ws添加到单个列中,用比较逻辑比较这些值,并将结果写入列中
+-----+----------+----------+----------------+----------------+-------------+...
| Type|X_ChannelA|Y_ChannelB|X_ChannelA_array|Y_ChannelB_array|X_ChannelA_st|
+-----+----------+----------+----------------+----------------+-------------+...
|TypeA| 11| 20| [11, 8, 12]| [20, 9, 13]| 1|
+-----+----------+----------+----------------+----------------+-------------+...
如果有一个更简单的解决方案,最好是朝着正确的方向推进。
编辑:如果我的描述不清楚,基本上我要找的是:
我要找的是
foreach channel in channellist (
ds.withColumn("<channel>_status", when($"<channel>" < $"<channel>_min" || $"<channel>" > $"<channel>_max"), 1).otherwise 0)
)
编辑:我找到了一个解决方案:
val df_joined = df_orig.join(df_def, Seq("Type"))
val pattern = """[XYZP]_Channel.*"""
val fieldNames = df_orig.schema.fieldNames.filter(_.matches(pattern))
val df_newnew = df_joined.select(col("*") +: (fieldNames.map(c => when(col(c) <= col(c+"_min") || col(c) >= col(c+"_max"), 1).otherwise(0).as(c+"_status))): _*)