
Inefficient sampling in a simple binomial GLM using rstan MCMC

  • Sebastian Sauer  ·  6 years ago

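    I am trying to fit a binomial GLM using the rethinking package (which relies on rstan MCMC).

    The model fits, but the sampling is inefficient, and Rhat indicates that something went wrong. I don't understand the reason for this fitting problem.

    Here is the data:

    library(readr)
    d <- read_csv("https://gist.githubusercontent.com/sebastiansauer/a2519de39da49d70c4a9e7a10191cb97/raw/election.csv")
    d <- as.data.frame(d)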
    

    This is the model:

    m1_stan <- map2stan( 
     alist(
        afd_votes ~ dbinom(votes_total, p),
        logit(p) <- beta0 + beta1*foreigner_n,
        beta0 ~ dnorm(0, 10),
        beta1 ~ dnorm(0, 10)
      ),
      data = d, 
      WAIC = FALSE,
      iter = 1000)
    

    The fit diagnostics (Rhat, number of effective samples) indicate that something went wrong:

          Mean StdDev lower 0.89 upper 0.89 n_eff Rhat
    beta0 -3.75      0      -3.75      -3.75     3 2.21
    beta1  0.00      0       0.00       0.00    10 1.25
    

    The trace plot does not show a "fat hairy caterpillar":

    [Image: traceplot m0_stan]

    I then increased adapt_delta and max_treedepth. This improved the sampling process somewhat:

          Mean StdDev lower 0.89 upper 0.89 n_eff Rhat
    beta0 18.1   0.09      18.11      18.16    28 1.06
    beta1  0.0   0.00       0.00       0.00    28 1.06
    

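    But as the trace plot shows, there is still a problem:

    [Image: traceplot2]

    [Image: pairs plot]

    I also tried:

    • I centered/z-standardized the predictor (which produced this error: "Error in sampler$call_sampler(args_list[[i]]) : Initialization failed.")
    • I tried a normal (Gaussian) model (but the outcome is count data)
    • I checked that there are no missing values (there are none)
    • I increased the number of iterations to 4000, without any improvement
    • I increased the standard deviation of the priors (the model then takes a long time to fit)

    But nothing has helped so far. What could be the reason for the inefficient fitting? What could I try?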

     mean(d_short$afd_votes)
    [1] 19655.83
    

    Data excerpt:

     head(d)
      afd_votes votes_total foreigner_n
    1     11647      170396       16100
    2      9023      138075       12600
    3     11176      130875       11000
    4     11578      156268        9299
    5     10390      150173       25099
    6     11161      130514       13000
    

    Session info:

    sessionInfo()
    R version 3.5.0 (2018-04-23)
    Platform: x86_64-apple-darwin15.6.0 (64-bit)
    Running under: macOS High Sierra 10.13.6
    
    Matrix products: default
    BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
    LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib
    
    locale:
    [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
    
    attached base packages:
    [1] parallel  stats     graphics  grDevices utils     datasets  methods   base     
    
    other attached packages:
     [1] viridis_0.5.1      viridisLite_0.3.0  sjmisc_2.7.3       pradadata_0.1.3    rethinking_1.59    rstan_2.17.3       StanHeaders_2.17.2 forcats_0.3.0      stringr_1.3.1     
    [10] dplyr_0.7.6        purrr_0.2.5        readr_1.1.1        tidyr_0.8.1        tibble_1.4.2       ggplot2_3.0.0      tidyverse_1.2.1   
    
    loaded via a namespace (and not attached):
     [1] httr_1.3.1         jsonlite_1.5       modelr_0.1.2       assertthat_0.2.0   stats4_3.5.0       cellranger_1.1.0   yaml_2.1.19        pillar_1.3.0       backports_1.1.2   
    [10] lattice_0.20-35    glue_1.3.0         digest_0.6.15      rvest_0.3.2        snakecase_0.9.1    colorspace_1.3-2   htmltools_0.3.6    plyr_1.8.4         pkgconfig_2.0.1   
    [19] broom_0.5.0        haven_1.1.2        bookdown_0.7       mvtnorm_1.0-8      scales_0.5.0       stringdist_0.9.5.1 sjlabelled_1.0.12  withr_2.1.2        RcppTOML_0.1.3    
    [28] lazyeval_0.2.1     cli_1.0.0          magrittr_1.5       crayon_1.3.4       readxl_1.1.0       evaluate_0.11      nlme_3.1-137       MASS_7.3-50        xml2_1.2.0        
    [37] blogdown_0.8       tools_3.5.0        loo_2.0.0          data.table_1.11.4  hms_0.4.2          matrixStats_0.54.0 munsell_0.5.0      prediction_0.3.6   bindrcpp_0.2.2    
    [46] compiler_3.5.0     rlang_0.2.1        grid_3.5.0         rstudioapi_0.7     labeling_0.3       rmarkdown_1.10     gtable_0.2.0       codetools_0.2-15   inline_0.3.15     
    [55] curl_3.2           R6_2.2.2           gridExtra_2.3      lubridate_1.7.4    knitr_1.20         bindr_0.1.1        rprojroot_1.3-2    KernSmooth_2.23-15 stringi_1.2.4     
    [64] Rcpp_0.12.18       tidyselect_0.2.4   xfun_0.3           coda_0.19-1       
    
    1 answer

  •   merv    6 years ago

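    Stan does better with unit-scale, uncorrelated parameters. From the Stan manual, §28.4 Model Conditioning and Curvature:

    Ideally, all parameters should be programmed so that they have unit scale and so that posterior correlation is reduced; together, these properties mean that there is no rotation or scaling required for optimal performance of Stan's algorithms. For Hamiltonian Monte Carlo, this implies a unit mass matrix, which requires no adaptation as it is where the algorithm initializes. Riemannian Hamiltonian Monte Carlo performs this conditioning on the fly at every step, but such conditioning is very expensive computationally.

    In your case, beta1 is tied to foreigner_n, which does not have unit scale, so it is unbalanced relative to beta0. In addition, because foreigner_n is not centered, both betas change the location of p during sampling, hence the posterior correlation.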
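    To make the scale imbalance concrete, here is a minimal base R sketch. It assumes the six rows from the data excerpt in the question are representative of the full data set, and the names x_raw, x_z and delta are purely illustrative. It compares how far a small step in beta1 moves logit(p) on the raw scale and on the standardized scale.

```r
# Six foreigner_n values copied from the head(d) output in the question
# (an illustrative subset, not the full data set)
x_raw <- c(16100, 12600, 11000, 9299, 25099, 13000)
x_z   <- as.numeric(scale(x_raw))   # centered, unit standard deviation

delta <- 0.001                      # a small step in beta1

# Shift in logit(p) caused by that step, per observation
shift_raw <- delta * x_raw          # roughly 9 to 25 logit units
shift_z   <- delta * x_z            # a few thousandths of a logit unit

range(abs(shift_raw))
range(abs(shift_z))

# On the raw scale, beta1 = 0.001 (with beta0 = 0) already pushes p close to 1
plogis(0 + delta * x_raw)
```

    On the raw scale, the sampler has to move beta1 in steps several orders of magnitude smaller than those for beta0 to stay in a region where p is plausible.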

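    Transforming foreigner_n to be centered and unit scale lets the model converge quickly and yields high effective sample sizes. I would also argue that the betas in this form are easier to interpret, since beta0 concerns only the location of p, while beta1 concerns only how variation in foreigner_n explains variation in afd_votes/votes_total.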
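    If coefficients on the original foreigner_n scale are needed, the standardized ones can be mapped back algebraically: with z = (x - m)/s, the raw-scale slope is beta1/s and the raw-scale intercept is beta0 - beta1*m/s. The sketch below checks this with a maximum-likelihood glm() fit rather than MCMC, using only the six rows from the data excerpt in the question; d6, m, s and the other names are illustrative assumptions, not part of the original analysis.

```r
# Six rows copied from the head(d) output in the question (illustrative subset only)
d6 <- data.frame(
  afd_votes   = c(11647, 9023, 11176, 11578, 10390, 11161),
  votes_total = c(170396, 138075, 130875, 156268, 150173, 130514),
  foreigner_n = c(16100, 12600, 11000, 9299, 25099, 13000)
)
m <- mean(d6$foreigner_n)
s <- sd(d6$foreigner_n)
d6$foreigner_z <- (d6$foreigner_n - m) / s

# Maximum-likelihood binomial GLM on the standardized predictor
fit_z <- glm(cbind(afd_votes, votes_total - afd_votes) ~ foreigner_z,
             family = binomial, data = d6)
b_z <- unname(coef(fit_z))

# Map the standardized coefficients back to the raw foreigner_n scale
beta1_raw <- b_z[2] / s
beta0_raw <- b_z[1] - b_z[2] * m / s

# Both parameterizations describe the same fitted proportions
p_z   <- plogis(b_z[1] + b_z[2] * d6$foreigner_z)
p_raw <- plogis(beta0_raw + beta1_raw * d6$foreigner_n)
all.equal(p_z, p_raw)
```

    The same transformation can be applied draw by draw to posterior samples of beta0 and beta1 from the standardized model.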

    library(readr)
    library(rethinking)
    
    d <- read_csv("https://gist.githubusercontent.com/sebastiansauer/a2519de39da49d70c4a9e7a10191cb97/raw/election.csv")
    d <- as.data.frame(d)
    d$foreigner_z <- scale(d$foreigner_n)
    
    m1 <- alist(
      afd_votes ~ dbinom(votes_total, p),
      logit(p) <- beta0 + beta1*foreigner_z,
      c(beta0, beta1) ~ dnorm(0, 1)
    )
    
    m1_stan <- map2stan(m1, data = d, WAIC = FALSE,
                        iter = 10000, chains = 4, cores = 4)
    

    Examining the sampling, we have

    > summary(m1_stan)
    Inference for Stan model: afd_votes ~ dbinom(votes_total, p). 
    4 chains, each with iter=10000; warmup=5000; thin=1;  
    post-warmup draws per chain=5000, total post-warmup draws=20000.
    
                  mean se_mean   sd         2.5%          25%          50%          75%        97.5% n_eff Rhat 
    beta0        -1.95    0.00 0.00        -1.95        -1.95        -1.95        -1.95        -1.95 16352    1 
    beta1        -0.24    0.00 0.00        -0.24        -0.24        -0.24        -0.24        -0.24 13456    1 
    dev      861952.93    0.02 1.97    861950.98    861951.50    861952.32    861953.73    861958.26  9348    1 
    lp__  -17523871.11    0.01 0.99 -17523873.77 -17523871.51 -17523870.80 -17523870.39 -17523870.13  9348    1
    
    Samples were drawn using NUTS(diag_e) at Sat Sep  1 11:48:55 2018.
    For each parameter, n_eff is a crude measure of effective sample size, 
    and Rhat is the potential scale reduction factor on split chains (at
    convergence, Rhat=1).
    

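    Looking at the pairs plot, we see that the correlation between the betas is reduced to 0.15:

    [Image: pairs plot]


    My initial intuition was that the uncentered foreigner_n was the main problem. At the same time, I was somewhat confused, because Stan uses HMC/NUTS, which I had always understood to be fairly robust to correlated latent variables. However, I noticed comments in the Stan manual about practical problems with scale invariance due to numerical instability, which were also commented on by Michael Betancourt in a CrossValidated answer (although it is a rather old post). So I wanted to test whether centering or scaling had the larger effect on improving the sampling.

    Centering alone

    Centering alone still results in rather poor performance. Note that the effective sample size is literally one effective sample per chain.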

    library(readr)
    library(rethinking)
    
    d <- read_csv("https://gist.githubusercontent.com/sebastiansauer/a2519de39da49d70c4a9e7a10191cb97/raw/election.csv")
    d <- as.data.frame(d)
    d$foreigner_c <- d$foreigner_n - mean(d$foreigner_n)
    
    m2 <- alist(
      afd_votes ~ dbinom(votes_total, p),
      logit(p) <- beta0 + beta1*foreigner_c,
      c(beta0, beta1) ~ dnorm(0, 1)
    )
    
    m2_stan <- map2stan(m2, data = d, WAIC = FALSE,
                        iter = 10000, chains = 4, cores = 4)
    

    Inference for Stan model: afd_votes ~ dbinom(votes_total, p).
    4 chains, each with iter=10000; warmup=5000; thin=1; 
    post-warmup draws per chain=5000, total post-warmup draws=20000.
    
                  mean   se_mean          sd         2.5%          25%          50%         75%        97.5% n_eff Rhat
    beta0        -0.64       0.4        0.75        -1.95        -1.29        -0.54         0.2         0.42     4 2.34
    beta1         0.00       0.0        0.00         0.00         0.00         0.00         0.0         0.00     4 2.35
    dev    18311608.99 8859262.1 17270228.21    861951.86   3379501.84  14661443.24  37563992.4  46468786.08     4 1.75
    lp__  -26248697.70 4429630.9  8635113.76 -40327285.85 -35874888.93 -24423614.49 -18782644.5 -17523870.54     4 1.75
    
    Samples were drawn using NUTS(diag_e) at Sun Sep  2 18:59:52 2018.
    For each parameter, n_eff is a crude measure of effective sample size,
    and Rhat is the potential scale reduction factor on split chains (at 
    convergence, Rhat=1).
    

    There still appears to be a problem:

    [Image: pairs plot]
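    One plausible reading of this result is that centering leaves the two coefficients on very different scales. The base R sketch below is a rough check of that idea under stated assumptions: it uses only the six rows from the data excerpt in the question, and it uses the standard errors of a maximum-likelihood glm() fit as a stand-in for posterior widths. The helper se_ratio and the other names are illustrative.

```r
# Six rows copied from the head(d) output in the question (illustrative subset only)
d6 <- data.frame(
  afd_votes   = c(11647, 9023, 11176, 11578, 10390, 11161),
  votes_total = c(170396, 138075, 130875, 156268, 150173, 130514),
  foreigner_n = c(16100, 12600, 11000, 9299, 25099, 13000)
)
s <- sd(d6$foreigner_n)
d6$foreigner_c <- d6$foreigner_n - mean(d6$foreigner_n)  # centered only
d6$foreigner_z <- d6$foreigner_c / s                      # centered and scaled

fit_c <- glm(cbind(afd_votes, votes_total - afd_votes) ~ foreigner_c,
             family = binomial, data = d6)
fit_z <- glm(cbind(afd_votes, votes_total - afd_votes) ~ foreigner_z,
             family = binomial, data = d6)

# Ratio of intercept standard error to slope standard error
se_ratio <- function(fit) {
  se <- sqrt(diag(vcov(fit)))
  unname(se[1] / se[2])
}

ratio_c <- se_ratio(fit_c)
ratio_z <- se_ratio(fit_z)
c(centered = ratio_c, standardized = ratio_z)
```

    With centering only, the intercept is thousands of times less tightly determined than the slope; dividing by the standard deviation brings the two to a comparable scale.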

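    Scaling alone

    Scaling greatly improves the sampling! Although the resulting posterior is still rather highly correlated, the effective sample sizes are in an acceptable range, though well below those of the fully standardized model.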

    library(readr)
    library(rethinking)
    
    d <- read_csv("https://gist.githubusercontent.com/sebastiansauer/a2519de39da49d70c4a9e7a10191cb97/raw/election.csv")
    d <- as.data.frame(d)
    d$foreigner_s <- d$foreigner_n / sd(d$foreigner_n)
    
    m3 <- alist(
      afd_votes ~ dbinom(votes_total, p),
      logit(p) <- beta0 + beta1*foreigner_s,
      c(beta0, beta1) ~ dnorm(0, 1)
    )
    
    m3_stan <- map2stan(m3, data = d, WAIC = FALSE,
                        iter = 10000, chains = 4, cores = 4)
    

    yielding

    Inference for Stan model: afd_votes ~ dbinom(votes_total, p).
    4 chains, each with iter=10000; warmup=5000; thin=1; 
    post-warmup draws per chain=5000, total post-warmup draws=20000.
    
                  mean se_mean   sd         2.5%          25%          50%          75%        97.5% n_eff Rhat
    beta0        -1.58    0.00 0.00        -1.58        -1.58        -1.58        -1.58        -1.57  5147    1
    beta1        -0.24    0.00 0.00        -0.24        -0.24        -0.24        -0.24        -0.24  5615    1
    dev      861952.93    0.03 2.01    861950.98    861951.50    861952.31    861953.69    861958.31  5593    1
    lp__  -17523870.45    0.01 1.00 -17523873.15 -17523870.83 -17523870.14 -17523869.74 -17523869.48  5591    1
    
    Samples were drawn using NUTS(diag_e) at Sun Sep  2 19:02:00 2018.
    For each parameter, n_eff is a crude measure of effective sample size,
    and Rhat is the potential scale reduction factor on split chains (at 
    convergence, Rhat=1).
    

    The pairs plot shows that there is still significant correlation:

    [Image: pairs plot]
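    The remaining correlation under scaling alone is expected, because the correlation between two coefficients does not change when one of them is rescaled; only centering shifts it. As a rough check, the base R sketch below computes the intercept/slope correlation from the curvature of a maximum-likelihood glm() fit for all four parameterizations. It uses only the six rows from the data excerpt in the question and is a local approximation, not a substitute for the MCMC pairs plots; coef_cor and the x_* column names are illustrative.

```r
# Six rows copied from the head(d) output in the question (illustrative subset only)
d6 <- data.frame(
  afd_votes   = c(11647, 9023, 11176, 11578, 10390, 11161),
  votes_total = c(170396, 138075, 130875, 156268, 150173, 130514),
  foreigner_n = c(16100, 12600, 11000, 9299, 25099, 13000)
)
m <- mean(d6$foreigner_n)
s <- sd(d6$foreigner_n)
d6$x_raw <- d6$foreigner_n            # untransformed
d6$x_c   <- d6$foreigner_n - m        # centered only
d6$x_s   <- d6$foreigner_n / s        # scaled only
d6$x_z   <- (d6$foreigner_n - m) / s  # centered and scaled

# Correlation between intercept and slope estimates for one predictor column
coef_cor <- function(xname) {
  f <- as.formula(paste("cbind(afd_votes, votes_total - afd_votes) ~", xname))
  fit <- glm(f, family = binomial, data = d6)
  cov2cor(vcov(fit))[1, 2]
}

cors <- sapply(c(raw = "x_raw", centered = "x_c",
                 scaled = "x_s", standardized = "x_z"), coef_cor)
round(cors, 3)
```

    In this approximation the raw and scaled versions share one correlation, and the centered and standardized versions share another, much weaker one, which is consistent with standardization needing both steps.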