代码之家  ›  专栏  ›  技术社区  ›  anon01

基于条件对某个组的转换创建新的列

  •  1
  • anon01  · 技术社区  · 6 年前

    是否有一种更有效的方法可以在分组后执行以下操作?

    对于每一个 group ,我想得到最大值 value 为此 time 等于3

    import numpy as np
    import pandas as pd
    
    
    d = dict(group=[1,1,1,1,1,2,2,2,2,2,3,3,3,3,3], times=[0,1,2,3,4]*3, values=np.random.rand(15))
    df = pd.DataFrame.from_dict(d)
    
    # e.g.:
    
        group  times    values
    0       1      0  0.277623
    1       1      1  0.227311
    2       1      2  0.798941
    3       1      3  0.861006
    4       1      4  0.486385
    5       2      0  0.543527
    6       2      1  0.347159
    7       2      2  0.138165
    8       2      3  0.152132
    9       2      4  0.402830
    10      3      0  0.688038
    11      3      1  0.450904
    12      3      2  0.351267
    13      3      3  0.195594
    14      3      4  0.834823
    

    以下内容似乎有效,但有点慢,而且不太简洁:

    for label, group in df.groupby(['group']):
        rows = group.index
        df.loc[rows,'new_value'] = group.loc[group.time <= 3, 'values'].max()
    
    1 回复  |  直到 6 年前
        1
  •  1
  •   cs95 abhishek58g    6 年前

    认为 你可以使用 where 分组前。为了获得更好的性能,请使用 transform :

    df['new_value'] = df['values'].where(df.times < 3).groupby(df.group).transform('max')    
    df
    
        group  times    values  new_value
    0       1      0  0.271137   0.751412
    1       1      1  0.262456   0.751412
    2       1      2  0.751412   0.751412
    3       1      3  0.364099   0.751412
    4       1      4  0.462447   0.751412
    5       2      0  0.022403   0.792396
    6       2      1  0.792396   0.792396
    7       2      2  0.181434   0.792396
    8       2      3  0.106931   0.792396
    9       2      4  0.226425   0.792396
    10      3      0  0.425845   0.535085
    11      3      1  0.527567   0.535085
    12      3      2  0.535085   0.535085
    13      3      3  0.194340   0.535085
    14      3      4  0.958947   0.535085
    

    这正是您当前代码返回的结果。


    哪里 确保我们不考虑时间3的值,因为 max 忽视NANS。这个 groupby 根据此中间结果计算。

    df['values'].where(df.times <= 3)
    
    0     0.271137
    1     0.262456
    2     0.751412
    3     0.364099
    4          NaN
    5     0.022403
    6     0.792396
    7     0.181434
    8     0.106931
    9          NaN
    10    0.425845
    11    0.527567
    12    0.535085
    13    0.194340
    14         NaN
    Name: values, dtype: float64