
How to compute the mean per group from a list of data.tables?

  • road_to_quantdom · 6 years ago

    Suppose we have a data.table like this:

    library(data.table)
    dt <- data.table(x=rnorm(10^6,100,10), letters=sample(LETTERS,10^6,TRUE))
    myList <- list(dt1=dt,dt2=dt,dt3=dt,dt4=dt,dt5=dt)
    

    If I want the mean of each group across all of the data.tables, I can do the following:

    bigDT <- rbindlist(myList)
    bigDT[,list('average'=mean(x)),by=letters]
    

    However, my real data are large: each dt has millions of rows, and each list is large too (500-1000 dts per list). The `by` grouping is similarly large.

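    Part of a function I plan to optimize with a genetic algorithm needs to compute these per-group means. Is there a faster way than rbind-ing the list first and then computing the per-group means on the combined data.table? Otherwise the optimizer will make many function calls to this potential bottleneck.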

    Any help would be greatly appreciated!

    microbenchmark(doThis())
    Unit: milliseconds
         expr     min       lq     mean   median       uq      max neval
     doThis() 151.512 154.3395 174.8071 167.7151 170.2952 440.9359   100
    
  • r2evans · 6 years ago

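    One approach is to compute the grouped mean of each table in the list, bind those small summaries together, and then take a weighted mean of them. Because each letter's count differs from table to table, you also need to keep .N.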
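Why the counts are needed can be written out directly. Using notation introduced here for illustration: K is the number of tables, n_{ig} is the number of rows of group g in table i (the .N above), and \bar{x}_{ig} is that group's mean in table i. Since each group's sum in table i equals n_{ig}\bar{x}_{ig}, the pooled mean is:

```latex
\bar{x}_g
  = \frac{\sum_{i=1}^{K} \sum_{j \in (i,g)} x_j}{\sum_{i=1}^{K} n_{ig}}
  = \frac{\sum_{i=1}^{K} n_{ig}\,\bar{x}_{ig}}{\sum_{i=1}^{K} n_{ig}}
```

A plain, unweighted average of the \bar{x}_{ig} matches this in general only when n_{ig} is the same in every table for that group g.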

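    I'll make each element of the list different so that we can actually verify the weighted-mean calculation (in the question all five elements are the same table). For reproducibility: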

    set.seed(1)
    myList <- replicate(5, data.table(x=rnorm(10^6,100,10), letters=sample(LETTERS,10^6,T)),
                        simplify=FALSE)
    myList[1:2]
    # [[1]]
    #                  x letters
    #       1:  93.73546       P
    #       2: 101.83643       I
    #       3:  91.64371       F
    #       4: 115.95281       V
    #       5: 103.29508       D
    #      ---                  
    #  999996: 109.24487       Q
    #  999997:  99.86486       K
    #  999998:  93.95941       J
    #  999999: 116.28763       O
    # 1000000: 106.93750       E
    # [[2]]
    #                  x letters
    #       1:  97.53576       R
    #       2: 105.27503       T
    #       3: 107.53592       L
    #       4: 102.21228       M
    #       5:  98.71087       G
    #      ---                  
    #  999996: 109.46843       C
    #  999997:  99.14458       M
    #  999998:  96.76845       Y
    #  999999:  94.22413       E
    # 1000000:  98.25855       K
    

    Doing this for just one table:

    head(myList[[1]][,.(mu = mean(x), n = .N), keyby=letters])
    #    letters        mu     n
    # 1:       A 100.04987 39005
    # 2:       B 100.01288 38576
    # 3:       C  99.97402 38547
    # 4:       D  99.99909 38460
    # 5:       E 100.03689 38030
    # 6:       F 100.02697 38293
    

    First, compute the per-group mean and count for each list element:

    myAgg <- rbindlist(lapply(myList, function(d) d[,.(mu = mean(x), n = .N), keyby="letters"]))
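Since the question's optimizer will call this step many times, both steps (per-table summary, then weighted combine) can be wrapped in a single function. This is a minimal sketch; the name groupMeans is hypothetical, and the column names x and letters are taken from the question.

```r
library(data.table)

# Hypothetical wrapper (name is illustrative): per-group mean of x by letters
# across a list of data.tables, without rbind-ing the full data.
groupMeans <- function(lst) {
  # Step 1: one small summary row per (table, letter): group mean and count
  partial <- rbindlist(lapply(lst, function(d) d[, .(mu = mean(x), n = .N), keyby = letters]))
  # Step 2: combine the partial means, weighting each by its group size
  partial[, .(mu = sum(n * mu) / sum(n)), keyby = letters]
}
```

Calling groupMeans(myList) returns one row per letter, sorted by letter. Only the small per-table summaries are bound together, never the millions of raw rows.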
    

    Now take the weighted mean, either manually or with Hmisc::wtd.mean:

    cbind(
      # just to verify the below answer is the same as the brute-force method of rbind-then-average
      rbindlist(myList)[,.(mu = mean(x)), keyby=letters],
      # either of these is your answer
      myAgg[,.(mu = sum(n*mu)/sum(n)),keyby=letters],
      myAgg[,.(mu = Hmisc::wtd.mean(mu, weights=n)),keyby=letters]
    )
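The same identity can be checked on tiny, made-up tables where the answer is easy to work out by hand (group A holds 1, 2, 10 and group B holds 3, 20). The sketch also uses base R's stats::weighted.mean as a dependency-free alternative to Hmisc::wtd.mean, and shows that dropping the counts gives the wrong result for group A.

```r
library(data.table)

# Two tiny tables with different group sizes (made-up data for illustration)
a <- data.table(x = c(1, 2, 3), letters = c("A", "A", "B"))
b <- data.table(x = c(10, 20),  letters = c("A", "B"))
lst <- list(a, b)

# Brute force: stack everything, then average by group
brute <- rbindlist(lst)[, .(mu = mean(x)), keyby = letters]

# Aggregate-then-combine: per-table means and counts, then weighted mean
agg <- rbindlist(lapply(lst, function(d) d[, .(mu = mean(x), n = .N), keyby = letters]))
combined <- agg[, .(mu = sum(n * mu) / sum(n)), keyby = letters]

# Same weighting via base R, without the Hmisc dependency
viaBase <- agg[, .(mu = weighted.mean(mu, n)), keyby = letters]

# Ignoring the counts: an unweighted mean of the per-table means
unweighted <- agg[, .(mu = mean(mu)), keyby = letters]

print(combined)
```

Here combined gives 13/3 for A and 11.5 for B, matching brute force, while the unweighted version gives 5.75 for A because table a contributes two A rows and table b only one.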
    #     letters        mu letters        mu letters        mu
    #  1:       A 100.02325       A 100.02325       A 100.02325
    #  2:       B 100.03473       B 100.03473       B 100.03473
    #  3:       C 100.00688       C 100.00688       C 100.00688
    #  4:       D 100.04041       D 100.04041       D 100.04041
    #  5:       E 100.00780       E 100.00780       E 100.00780
    #  6:       F 100.01202       F 100.01202       F 100.01202
    #  7:       G 100.01200       G 100.01200       G 100.01200
    #  8:       H  99.97232       H  99.97232       H  99.97232
    #  9:       I 100.00495       I 100.00495       I 100.00495
    # 10:       J 100.03019       J 100.03019       J 100.03019
    # 11:       K  99.96851       K  99.96851       K  99.96851
    # 12:       L 100.01850       L 100.01850       L 100.01850
    # 13:       M 100.00976       M 100.00976       M 100.00976
    # 14:       N 100.01299       N 100.01299       N 100.01299
    # 15:       O 100.02108       O 100.02108       O 100.02108
    # 16:       P 100.02052       P 100.02052       P 100.02052
    # 17:       Q 100.03814       Q 100.03814       Q 100.03814
    # 18:       R  99.99013       R  99.99013       R  99.99013
    # 19:       S  99.95219       S  99.95219       S  99.95219
    # 20:       T  99.97721       T  99.97721       T  99.97721
    # 21:       U  99.96310       U  99.96310       U  99.96310
    # 22:       V  99.94430       V  99.94430       V  99.94430
    # 23:       W  99.98877       W  99.98877       W  99.98877
    # 24:       X 100.07352       X 100.07352       X 100.07352
    # 25:       Y  99.96677       Y  99.96677       Y  99.96677
    # 26:       Z  99.99397       Z  99.99397       Z  99.99397
    #     letters        mu letters        mu letters        mu
    

    A quick benchmark for comparison:

    library(microbenchmark)
    microbenchmark(
      bruteforce = rbindlist(myList)[,.(mu = mean(x)), keyby=letters],
      # either of these is your answer
      baseR = {
        myAgg <- rbindlist(lapply(myList, function(d) d[,.(mu = mean(x), n = .N), keyby="letters"]))
        myAgg[,.(mu = sum(n*mu)/sum(n)),keyby=letters]
      },
      Hmisc =  {
        myAgg <- rbindlist(lapply(myList, function(d) d[,.(mu = mean(x), n = .N), keyby="letters"]))
        myAgg[,.(mu = Hmisc::wtd.mean(mu, weights=n)),keyby=letters]
      },
      times=50
    )
    # Unit: milliseconds
    #        expr      min       lq      mean    median       uq      max neval
    #  bruteforce 131.8770 139.4562 153.93202 151.95375 159.6329 315.6117    50
    #       baseR  89.7047  93.3623 109.20174  98.11670 115.0171 268.2517    50
    #       Hmisc  89.2784  91.5927  97.87455  93.73475  98.1655 119.2671    50
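A small variation, not part of the original answer and not benchmarked here: carry each group's sum instead of its mean, so the combine step is simply total sum divided by total count (mathematically the same weighted mean). The per-table step uses by rather than keyby, since sorting is only needed once at the end; whether that saves measurable time on real data is an assumption to test. The function name sumCountMeans is hypothetical.

```r
library(data.table)

# Hypothetical variant (not benchmarked): keep per-group sums and counts
# instead of means; the pooled mean is total sum / total count.
sumCountMeans <- function(lst) {
  # by (not keyby): skip sorting each small per-table summary
  partial <- rbindlist(lapply(lst, function(d) d[, .(s = sum(x), n = .N), by = letters]))
  # Sort once, while combining across tables
  partial[, .(mu = sum(s) / sum(n)), keyby = letters]
}

# Small simulated list in the same shape as the question's data
set.seed(1)
lst <- replicate(3, data.table(x = rnorm(1000, 100, 10),
                               letters = sample(LETTERS, 1000, TRUE)),
                 simplify = FALSE)
res <- sumCountMeans(lst)
```

To compare it with the other approaches, it could be added as one more expression in the microbenchmark call above.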
    