
pyspark - select users seen on at least two consecutive days

  •  0
  • Qubix  · 6 years ago

    I have a dataframe dataframe_actions with the fields: user_id , action , day . The user_id is unique for each user and day takes values between 1 and 31. I want to keep only the users that are seen on at least two consecutive days, for example:

    If a user is seen on day 1, day 2, day 4, day 8 and day 9, I want to keep them, because they were seen on at least two consecutive days.

    What I am doing now is clumsy and very slow (and doesn't seem to work):

    df_final = spark.sql(""" with t1( select user_id, day, row_number()
               over(partition by user_id order by day)-day diff from dataframe_actions), 
               t2( select user_id, day, collect_set(diff) over(partition by user_id) diff2 from t1) 
               select user_id, day from t2 where size(diff2) > 2""")
    

    I have no idea how to solve this.

    EDIT:

    | user_id | action | day |
    --------------------------
    | asdc24  | conn   |  1  |
    | asdc24  | conn   |  2  |
    | asdc24  | conn   |  5  |
    | adsfa6  | conn   |  1  |
    | adsfa6  | conn   |  3  |
    | asdc24  | conn   |  9  |
    | adsfa6  | conn   |  5  |
    | asdc24  | conn   |  11 |
    | adsfa6  | conn   |  10 |
    | asdc24  | conn   |  15 |
    

    should return

    | user_id | action | day |
    --------------------------
    | asdc24  | conn   |  1  |
    | asdc24  | conn   |  2  |
    | asdc24  | conn   |  5  |
    | asdc24  | conn   |  9  |
    | asdc24  | conn   |  11 |
    | asdc24  | conn   |  15 |
    

    because only this user connected on at least two consecutive days (days 1 and 2).

    2 Answers  |  6 years ago
        1
  •  1
  •   Vamsi Prabhala    6 years ago

    Use lag to get the previous day for each user, subtract it from the current row's day, and then check whether at least one of the differences is 1. This is done with a group by and a filter afterwards.

    from pyspark.sql import functions as f
    from pyspark.sql import Window
    w = Window.partitionBy(dataframe_actions.user_id).orderBy(dataframe_actions.day)
    # difference between the current day and the previous day of the same user
    user_prev = dataframe_actions.withColumn('prev_day_diff',dataframe_actions.day-f.lag(dataframe_actions.day).over(w))
    # count, per user, how often that difference is exactly 1 (i.e. two consecutive days)
    res = user_prev.groupBy(user_prev.user_id).agg(f.sum(f.when(user_prev.prev_day_diff==1,1).otherwise(0)).alias('diff_1'))
    res.filter(res.diff_1 >= 1).show()
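
    This first version only returns the matching user_ids. If the original rows are also needed, one way (a sketch assuming the res and dataframe_actions above) is to semi-join them back:

    # keep every original row of the users identified by res
    matching_users = res.filter(res.diff_1 >= 1).select('user_id')
    dataframe_actions.join(matching_users, 'user_id', 'left_semi').show()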
    

    Another way is the row-number difference method. This makes it possible to select all the columns for a given user.

    w = Window.partitionBy(dataframe_actions.user_id).orderBy(dataframe_actions.day)
    # day minus row number stays constant within a run of consecutive days
    rownum_diff = dataframe_actions.withColumn('rdiff',dataframe_actions.day-f.row_number().over(w))
    # flag rows whose rdiff equals the previous row's rdiff, i.e. the second day of a run
    rownum_diff = rownum_diff.withColumn('in_run',f.when(f.lag('rdiff').over(w)==f.col('rdiff'),1).otherwise(0))
    w1 = Window.partitionBy(rownum_diff.user_id)
    # per-user total of flagged rows; >= 1 means at least two consecutive days
    counts_per_user = rownum_diff.withColumn('cnt',f.sum('in_run').over(w1))
    cols_to_select = ['user_id','action','day']
    counts_per_user.filter(counts_per_user.cnt >= 1).select(*cols_to_select).show()
    
        2
  •  1
  •   stack0114106    6 years ago

    Another SQL approach, using the given input.

    Pyspark version

    >>> from pyspark.sql.functions import *
    >>> df = sc.parallelize([("asdc24","conn",1),
    ... ("asdc24","conn",2),
    ... ("asdc24","conn",5),
    ... ("adsfa6","conn",1),
    ... ("adsfa6","conn",3),
    ... ("asdc24","conn",9),
    ... ("adsfa6","conn",5),
    ... ("asdc24","conn",11),
    ... ("adsfa6","conn",10),
    ... ("asdc24","conn",15)]).toDF(["user_id","action","day"])
    >>> df.createOrReplaceTempView("qubix")
    >>> spark.sql(" select * from qubix order by user_id, day").show()
    +-------+------+---+
    |user_id|action|day|
    +-------+------+---+
    | adsfa6|  conn|  1|
    | adsfa6|  conn|  3|
    | adsfa6|  conn|  5|
    | adsfa6|  conn| 10|
    | asdc24|  conn|  1|
    | asdc24|  conn|  2|
    | asdc24|  conn|  5|
    | asdc24|  conn|  9|
    | asdc24|  conn| 11|
    | asdc24|  conn| 15|
    +-------+------+---+
    
    >>> spark.sql(""" with t1 (select user_id,action, day,lead(day) over(partition by user_id order by day) ld from qubix), t2 (select user_id from t1 where ld-t1.day=1 ) select * from qubix where user_id in (select user_id from t2) """).show()
    +-------+------+---+
    |user_id|action|day|
    +-------+------+---+
    | asdc24|  conn|  1|
    | asdc24|  conn|  2|
    | asdc24|  conn|  5|
    | asdc24|  conn|  9|
    | asdc24|  conn| 11|
    | asdc24|  conn| 15|
    +-------+------+---+
    
    >>>
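
    For comparison, the same lead-based filter can also be written with the DataFrame API; a minimal sketch, assuming the df created above:

    from pyspark.sql import functions as f
    from pyspark.sql import Window
    w = Window.partitionBy('user_id').orderBy('day')
    # users with at least one pair of days where the next day is day + 1
    consec = (df.withColumn('ld', f.lead('day').over(w))
                .filter(f.col('ld') - f.col('day') == 1)
                .select('user_id')
                .distinct())
    # keep all rows of those users
    df.join(consec, 'user_id', 'left_semi').orderBy('user_id', 'day').show()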
    

    Scala version

    scala> val df = Seq(("asdc24","conn",1),
         | ("asdc24","conn",2),
         | ("asdc24","conn",5),
         | ("adsfa6","conn",1),
         | ("adsfa6","conn",3),
         | ("asdc24","conn",9),
         | ("adsfa6","conn",5),
         | ("asdc24","conn",11),
         | ("adsfa6","conn",10),
         | ("asdc24","conn",15)).toDF("user_id","action","day")
    df: org.apache.spark.sql.DataFrame = [user_id: string, action: string ... 1 more field]
    
    scala> df.orderBy('user_id,'day).show(false)
    +-------+------+---+
    |user_id|action|day|
    +-------+------+---+
    |adsfa6 |conn  |1  |
    |adsfa6 |conn  |3  |
    |adsfa6 |conn  |5  |
    |adsfa6 |conn  |10 |
    |asdc24 |conn  |1  |
    |asdc24 |conn  |2  |
    |asdc24 |conn  |5  |
    |asdc24 |conn  |9  |
    |asdc24 |conn  |11 |
    |asdc24 |conn  |15 |
    +-------+------+---+
    
    
    scala> df.createOrReplaceTempView("qubix")
    
    scala> spark.sql(""" with t1 (select user_id,action, day,lead(day) over(partition by user_id order by day) ld from qubix), t2 (select user_id from t1 where ld-t1.day=1 ) select * from qubix where user_id in (select user_id from t2) """).show(false)
    +-------+------+---+
    |user_id|action|day|
    +-------+------+---+
    |asdc24 |conn  |1  |
    |asdc24 |conn  |2  |
    |asdc24 |conn  |5  |
    |asdc24 |conn  |9  |
    |asdc24 |conn  |11 |
    |asdc24 |conn  |15 |
    +-------+------+---+
    
    
    scala>