代码之家  ›  专栏  ›  技术社区  ›  rnorouzian

减少重叠代码以提高速度

  •  0
  • rnorouzian  · 技术社区  · 3 年前

    我的函数 foo() 能正常工作,但我认为其中有一些冗余/重叠的代码,导致它运行得有点慢。

    例如,pre 和 pos 两个对象都对 pivot_wider()、unnest() 等函数进行了相同的调用。

    我的问题是:考虑到 pre 和 pos 两个对象之间调用的重叠,我的代码能否写得更简短、运行得更快?或者,能否用 base R 实现完全相同的输出?

    library(tidyverse)
    
    # Reshape summary statistics so that the arm with control == FALSE and the
    # arm with control == TRUE sit side by side for each outcome.
    #
    # data: a data frame (in the example below, one study at a time) with
    #   columns n, mpre, sdpre, mpos, sdpos, control (logical), outcome, post.
    # Returns a tibble with columns mT, sdT, nT (control == FALSE arm),
    #   mC, sdC, nC (control == TRUE arm), outcome and time: the pre-test rows
    #   first (time = 0), then the post-test rows sorted by `post`.
    foo <- function(data) {
      # Shared reshaping for one occasion ("pre" or "pos"): widen by `control`
      # (list-valued cells), expand the list-columns, keep the six statistic
      # columns plus `outcome` (and `post` when requested), drop duplicate
      # rows and rename to the mT..nC convention.
      reshape_occasion <- function(occasion, keep_post) {
        stats <- paste0(c("m", "sd"), occasion)
        wide_cols <- c(paste0(c(stats, "n"), "_FALSE"),
                       paste0(c(stats, "n"), "_TRUE"))
        keep <- c(wide_cols, "outcome", if (keep_post) "post")
        new_names <- c("mT", "sdT", "nT", "mC", "sdC", "nC", "outcome",
                       if (keep_post) "time")

        data %>%
          dplyr::select(n, all_of(stats), control, outcome, post) %>%
          pivot_wider(values_from = all_of(c(stats, "n")),
                      names_from = control, values_fn = list) %>%
          tidyr::unnest(everything()) %>%
          dplyr::select(all_of(keep)) %>%
          unique() %>%
          rlang::set_names(new_names)
      }

      # `time = 0` recycles to any row count, including zero rows; the former
      # rep(0, max(row_number())) errored when the pre-test result was empty.
      pre <- reshape_occasion("pre", keep_post = FALSE) %>%
        mutate(time = 0)
      pos <- reshape_occasion("pos", keep_post = TRUE) %>%
        arrange(time)

      bind_rows(pre, pos)
    }
    
    ## EXAMPLE OF USE:
    # Read the example data, split it by study, and reshape each study with foo().
    dat <- read.csv(
      file = "https://raw.githubusercontent.com/rnorouzian/m2/main/f.csv"
    )
    
    lapply(X = split(x = dat, f = dat[["study.name"]]), FUN = foo)
    
    0 回复  |  直到 3 年前
        1
  •  1
  •   AnilGoyal    3 年前

    根据评论中提出的进一步要求,将解决方案改写为函数 foo,这应该能满足需求:

    # Reshape pre/post summary statistics study by study, carrying extra
    # identifier columns through to the output.
    #
    # data: a data frame with columns study.name, n, mpre, sdpre, mpos, sdpos,
    #   control (logical), outcome, post, and every column named in `cols`.
    # cols: character vector of extra columns to keep. Previously this argument
    #   was overwritten inside the function and thus ignored; it now defaults
    #   to that same hard-coded value, so foo(data) keeps working.
    # Returns a single tibble: per study, pre-test rows (time = 0) and
    #   post-test rows (time = post), with the control == FALSE statistics as
    #   mT/sdT/nT and the control == TRUE statistics as mC/sdC/nC. List-columns
    #   are expanded, duplicate rows removed, and rows sorted by study.name.
    foo <- function(data, cols = c("rev.sign", "time_wk", "time_cat")) {
      # Split once; both occasions reuse the same per-study list.
      studies <- data %>% group_split(study.name)

      # Reshape one study for one occasion ("pre" or "pos").
      reshape_occasion <- function(study, occasion) {
        m_col <- paste0("m", occasion)
        sd_col <- paste0("sd", occasion)
        # New name = wide column produced by pivot_wider (default names_sep "_").
        renames <- c(
          mT = paste0(m_col, "_FALSE"), sdT = paste0(sd_col, "_FALSE"),
          nT = "n_FALSE",
          mC = paste0(m_col, "_TRUE"), sdC = paste0(sd_col, "_TRUE"),
          nC = "n_TRUE"
        )

        wide <- study %>%
          dplyr::select(all_of(cols), study.name, n, all_of(c(m_col, sd_col)),
                        control, outcome, post) %>%
          pivot_wider(values_from = all_of(c(m_col, sd_col, "n")),
                      names_from = control, values_fn = list)

        if (occasion == "pre") {
          wide %>%
            dplyr::select(all_of(cols), study.name, all_of(renames), outcome) %>%
            mutate(time = 0)
        } else {
          wide %>%
            dplyr::select(all_of(cols), study.name, all_of(renames), outcome,
                          time = post) %>%
            unique()
        }
      }

      map_dfr(studies, reshape_occasion, occasion = "pre") %>%
        bind_rows(map_dfr(studies, reshape_occasion, occasion = "pos")) %>%
        unnest(everything()) %>%
        unique() %>%
        arrange(study.name)
    }
    
    library(tidyverse)
    
    # Read the example data and reshape it, keeping three extra columns.
    dat <- read.csv(file = "https://raw.githubusercontent.com/rnorouzian/m2/main/f.csv")
    cols <- c("rev.sign", "time_wk", "time_cat")
    
    foo(data = dat, cols = cols)
    #> # A tibble: 522 x 12
    #>    rev.sign time_wk time_cat study.name    mT   sdT    nT    mC   sdC    nC
    #>    <lgl>      <int>    <int> <chr>      <dbl> <dbl> <int> <dbl> <dbl> <int>
    #>  1 FALSE          0        1 A1          1.68  1.07    25  1.44  1.08    25
    #>  2 FALSE          2        2 A1          1.68  1.07    25  1.44  1.08    25
    #>  3 FALSE          0        1 A1          7.4   2.22    25  1.08  1.12    25
    #>  4 FALSE          2        2 A1          8.08  1.75    25  1.48  1.08    25
    #>  5 FALSE          0        1 A2         60.3  21.5     13 28.9  13.9     13
    #>  6 FALSE          0        1 A2         82.9  11.4     13 28.9  13.9     13
    #>  7 FALSE          8        4 A2         60.3  21.5     13 28.9  13.9     13
    #>  8 FALSE          8        4 A2         82.9  11.4     13 28.9  13.9     13
    #>  9 FALSE          0        1 A2         74.3  11.4     13 32.5  18.8     13
    #> 10 FALSE          0        1 A2         88.7  11.8     13 32.5  18.8     13
    #> # ... with 512 more rows, and 2 more variables: outcome <int>, time <dbl>
    

    由 reprex 包 (v2.0.0) 于 2021-05-11 创建


    根据下方评论中提出的要求,建议采用以下策略。不过,输出的行数也会随之改变,因为希望加入输出的这些新列的取值并不都相同。

    cols <- c("rev.sign", "time_wk", "time_cat")
    
    # Stack the pre-test rows (time fixed at 0) on top of the post-test rows
    # (time taken from `post`), then expand list-columns, dedupe and sort.
    bind_rows(
      map_dfr(group_split(dat, study.name), function(study) {
        wide <- study %>%
          dplyr::select(all_of(cols), study.name, n, mpre, sdpre, control,
                        outcome, post) %>%
          pivot_wider(names_from = control, values_from = c(mpre, sdpre, n),
                      values_fn = list)
        wide %>%
          dplyr::select(all_of(cols), study.name,
                        mT = mpre_FALSE, sdT = sdpre_FALSE, nT = n_FALSE,
                        mC = mpre_TRUE, sdC = sdpre_TRUE, nC = n_TRUE,
                        outcome) %>%
          mutate(time = 0)
      }),
      map_dfr(group_split(dat, study.name), function(study) {
        wide <- study %>%
          dplyr::select(all_of(cols), study.name, n, mpos, sdpos, control,
                        outcome, post) %>%
          pivot_wider(names_from = control, values_from = c(mpos, sdpos, n),
                      values_fn = list)
        wide %>%
          dplyr::select(all_of(cols), study.name,
                        mT = mpos_FALSE, sdT = sdpos_FALSE, nT = n_FALSE,
                        mC = mpos_TRUE, sdC = sdpos_TRUE, nC = n_TRUE,
                        outcome, time = post) %>%
          unique()
      })
    ) %>%
      unnest(everything()) %>%
      unique() %>%
      arrange(study.name)
    
    # A tibble: 522 x 12
       rev.sign time_wk time_cat study.name    mT   sdT    nT    mC   sdC    nC outcome  time
       <lgl>      <int>    <int> <chr>      <dbl> <dbl> <int> <dbl> <dbl> <int>   <int> <dbl>
     1 FALSE          0        1 A1          1.68  1.07    25  1.44  1.08    25       1     0
     2 FALSE          2        2 A1          1.68  1.07    25  1.44  1.08    25       1     0
     3 FALSE          0        1 A1          7.4   2.22    25  1.08  1.12    25       1     1
     4 FALSE          2        2 A1          8.08  1.75    25  1.48  1.08    25       1     2
     5 FALSE          0        1 A2         60.3  21.5     13 28.9  13.9     13       1     0
     6 FALSE          0        1 A2         82.9  11.4     13 28.9  13.9     13       1     0
     7 FALSE          8        4 A2         60.3  21.5     13 28.9  13.9     13       1     0
     8 FALSE          8        4 A2         82.9  11.4     13 28.9  13.9     13       1     0
     9 FALSE          0        1 A2         74.3  11.4     13 32.5  18.8     13       1     1
    10 FALSE          0        1 A2         88.7  11.8     13 32.5  18.8     13       1     1
    # ... with 512 more rows
    

    作为对照,下面是只保留 study.name(不加入新列)时的相同代码:

    # Same reshaping as above but keeping only study.name as an identifier:
    # pre-test rows (time = 0) stacked on post-test rows (time = post).
    bind_rows(
      map_dfr(group_split(dat, study.name), function(study) {
        wide <- study %>%
          dplyr::select(study.name, n, mpre, sdpre, control, outcome, post) %>%
          pivot_wider(names_from = control, values_from = c(mpre, sdpre, n),
                      values_fn = list)
        wide %>%
          dplyr::select(study.name,
                        mT = mpre_FALSE, sdT = sdpre_FALSE, nT = n_FALSE,
                        mC = mpre_TRUE, sdC = sdpre_TRUE, nC = n_TRUE,
                        outcome) %>%
          mutate(time = 0)
      }),
      map_dfr(group_split(dat, study.name), function(study) {
        wide <- study %>%
          dplyr::select(study.name, n, mpos, sdpos, control, outcome, post) %>%
          pivot_wider(names_from = control, values_from = c(mpos, sdpos, n),
                      values_fn = list)
        wide %>%
          dplyr::select(study.name,
                        mT = mpos_FALSE, sdT = sdpos_FALSE, nT = n_FALSE,
                        mC = mpos_TRUE, sdC = sdpos_TRUE, nC = n_TRUE,
                        outcome, time = post) %>%
          unique()
      })
    ) %>%
      unnest(everything()) %>%
      unique() %>%
      arrange(study.name)
    # A tibble: 421 x 9
       study.name    mT   sdT    nT    mC   sdC    nC outcome  time
       <chr>      <dbl> <dbl> <int> <dbl> <dbl> <int>   <int> <dbl>
     1 A1          1.68  1.07    25  1.44  1.08    25       1     0
     2 A1          7.4   2.22    25  1.08  1.12    25       1     1
     3 A1          8.08  1.75    25  1.48  1.08    25       1     2
     4 A2         60.3  21.5     13 28.9  13.9     13       1     0
     5 A2         82.9  11.4     13 28.9  13.9     13       1     0
     6 A2         74.3  11.4     13 32.5  18.8     13       1     1
     7 A2         88.7  11.8     13 32.5  18.8     13       1     1
     8 A2         68.9  15.0     13 39.9  20.8     13       1     2
     9 A2         90.3   8.4     13 39.9  20.8     13       1     2
    10 B1         67.6  19.3     17 51.9  28.3     20       1     0
    # ... with 411 more rows
    

    检查 A1 的输出:一个结果中有 3 行,另一个结果中有 4 行。


    早期答案:这样可以减少一些冗余

    # Reshape pre/post summary statistics so the control == FALSE statistics
    # (mT, sdT, nT) and control == TRUE statistics (mC, sdC, nC) sit side by
    # side. Duplicate rows are dropped before the list-columns are expanded.
    foo1 <- function(data) {
      # Pre-test occasion: widen by `control`, rename, dedupe, tag with time 0.
      pre_wide <- pivot_wider(
        dplyr::select(data, n, mpre, sdpre, control, outcome, post),
        names_from = control,
        values_from = c(mpre, sdpre, n),
        values_fn = list
      )
      pre_rows <- unique(dplyr::select(
        pre_wide,
        mT = mpre_FALSE, sdT = sdpre_FALSE, nT = n_FALSE,
        mC = mpre_TRUE, sdC = sdpre_TRUE, nC = n_TRUE,
        outcome
      ))
      pre_rows <- mutate(pre_rows, time = 0)

      # Post-test occasion: same reshaping, but `post` is kept as `time`.
      pos_wide <- pivot_wider(
        dplyr::select(data, n, mpos, sdpos, control, outcome, post),
        names_from = control,
        values_from = c(mpos, sdpos, n),
        values_fn = list
      )
      pos_rows <- unique(dplyr::select(
        pos_wide,
        mT = mpos_FALSE, sdT = sdpos_FALSE, nT = n_FALSE,
        mC = mpos_TRUE, sdC = sdpos_TRUE, nC = n_TRUE,
        outcome, time = post
      ))

      # Expand the list-columns produced by values_fn = list.
      tidyr::unnest(bind_rows(pre_rows, pos_rows), everything())
    }
    

    不过,也可以像下面这样在单个管道中完成。可以根据期望的输出格式(数据框或列表等)调整所用的 map_* 函数。

    # Per-study version of foo1(): pre-test rows (deduped, time = 0) stacked on
    # post-test rows (deduped, time = post), then list-columns expanded.
    # study.name is used for splitting but not kept in the output.
    bind_rows(
      map_dfr(group_split(dat, study.name), function(study) {
        wide <- study %>%
          dplyr::select(study.name, n, mpre, sdpre, control, outcome, post) %>%
          pivot_wider(names_from = control, values_from = c(mpre, sdpre, n),
                      values_fn = list)
        wide %>%
          dplyr::select(mT = mpre_FALSE, sdT = sdpre_FALSE, nT = n_FALSE,
                        mC = mpre_TRUE, sdC = sdpre_TRUE, nC = n_TRUE,
                        outcome) %>%
          unique() %>%
          mutate(time = 0)
      }),
      map_dfr(group_split(dat, study.name), function(study) {
        wide <- study %>%
          dplyr::select(study.name, n, mpos, sdpos, control, outcome, post) %>%
          pivot_wider(names_from = control, values_from = c(mpos, sdpos, n),
                      values_fn = list)
        wide %>%
          dplyr::select(mT = mpos_FALSE, sdT = sdpos_FALSE, nT = n_FALSE,
                        mC = mpos_TRUE, sdC = sdpos_TRUE, nC = n_TRUE,
                        outcome, time = post) %>%
          unique()
      })
    ) %>%
      unnest(everything())
    
    # A tibble: 423 x 8
          mT   sdT    nT    mC   sdC    nC outcome  time
       <dbl> <dbl> <int> <dbl> <dbl> <int>   <int> <dbl>
     1  1.68  1.07    25  1.44  1.08    25       1     0
     2 60.3  21.5     13 28.9  13.9     13       1     0
     3 82.9  11.4     13 28.9  13.9     13       1     0
     4 67.6  19.3     17 51.9  28.3     20       1     0
     5 53.1  21.7     18 51.9  28.3     20       1     0
     6 59.4  18.4     20 51.9  28.3     20       1     0
     7 87.1  13.2     12 85.6   8.89    12       1     0
     8 90.6   8.79    27 85.6   8.89    12       1     0
     9 83.1  12.5     12 85.6   8.89    12       1     0
    10 62.5  12.0     13 60.2  17.6     13       1     0
    # ... with 413 more rows
    

    速度基准测试(benchmarking)结果

    library(rbenchmark)
    # Time three equivalent-intent approaches, each run 10 times (replications).
    # NOTE(review): 'reza' calls `foo` with one argument per study; it is meant
    # to be the question's foo(data). If this file is run top to bottom, the
    # later foo(data, cols) definition shadows it and a different function is
    # timed - confirm which `foo` is in scope before benchmarking.
    benchmark('reza' = {lapply(split(dat, dat$study.name), foo)},
              # foo1 applied to each study's subset separately.
              'anil' = {lapply(split(dat, dat$study.name), foo1)},
              # Single pipeline over the whole data: split by study inside,
              # build pre-test and post-test rows, stack, expand list-columns.
              'purr' = {dat %>% group_split(study.name) %>%
                  map_dfr(~ .x %>%
                            dplyr::select(study.name, n, mpre, sdpre, control, outcome, post) %>% 
                            pivot_wider(values_from = c(mpre, sdpre, n), names_from = control, values_fn = list) %>%
                            dplyr::select(mT = mpre_FALSE, sdT = sdpre_FALSE,nT = n_FALSE,mC = mpre_TRUE, sdC = sdpre_TRUE, nC = n_TRUE,outcome) %>% 
                            unique %>% mutate(time = 0)
                  ) %>%
                  bind_rows(
                    dat %>% group_split(study.name) %>%
                      map_dfr(~ .x %>%
                                dplyr::select(study.name, n, mpos, sdpos, control, outcome, post) %>% 
                                pivot_wider(values_from = c(mpos, sdpos, n), names_from = control, values_fn = list) %>%
                                dplyr::select(mT = mpos_FALSE,sdT = sdpos_FALSE,nT = n_FALSE,mC = mpos_TRUE,sdC = sdpos_TRUE,nC = n_TRUE,outcome,time = post) %>% unique
                      )
                  ) %>% unnest(everything())},
              replications = 10)
    
      test replications elapsed relative user.self sys.self user.child sys.child
    2 anil           10   14.99    1.624     14.94     0.03         NA        NA
    3 purr           10    9.23    1.000      9.17     0.05         NA        NA
    1 reza           10   25.33    2.744     25.23     0.05         NA        NA
    

    使用 microbenchmark 的结果

    Unit: milliseconds
     expr       min        lq     mean    median       uq      max neval
     reza 1388.6720 1480.1678 2127.733 1587.3243 3361.978 3749.096    10
     anil 1055.8224 1060.3372 1328.847 1086.8156 1259.200 3206.637    10
     purr  898.9936  902.0576 1095.040  941.0366  976.386 2451.063    10