代码之家  ›  专栏  ›  技术社区  ›  Behnam Hedayat

如何在mlr3proba的嵌套交叉验证中转换“2级ParamUty”类?

  •  0
  • Behnam Hedayat  · 技术社区  · 3 年前

    对于生存分析,我使用 mlr3proba R的包装。
    我的数据集由39个特征(连续和因子,我将其全部转换为整数和数字)和目标(时间和状态)组成。
    我想调整超参数:num_nodes,在 Param_set .
    这是一个 ParamUty 具有默认值的类参数: 32,32 .
    所以我决定改变它。
    我编写了如下代码,用于优化超参数 surv.deephit 学习者使用“嵌套交叉验证”(10个内折叠和3个外折叠)。

    #task definition
    task.mlr <- TaskSurv$new(id = "id", backend = main.dataset, event = 'status', time = 'time')
    
    #learner definition
    dh.learner <- lrn('surv.deephit') 
       
    #resampling method
    resampling <- rsmp('cv', folds =10)
      
    #tuner method
    tuner <- tnr('random_search')
    
    #measure method
    measure <- msr('surv.harrellC')
    
    #termination method
    terminator <- trm('stagnation')
    
    #search_space definition(for num_nodes)
    search_space <- ps(num_nodes = p_fct(list(c(32,64,128,256)), trafo = function(x) c(sample(x,1), sample(x,1))))
    
    #To check search_space
    generate_design_random(search_space,10)$transpose()
    

    当我使用转置运行最后一行代码时,它展示了一个num_nodes列表,每个节点都包含一对类别,如下所示:

    [[1]]$num_nodes
    [1] 64 128
    
    [[2]]$num_nodes
    [1] 32 256
    
    ...
    

    然后我写了以下代码:

    #defining autotuner
    at <- AutoTuner$new(dh.learner, resampling, measure, terminator, tuner, search_space)
    
    #outer cross validation
    resampling_outer <- rsmp('cv', folds = 3)
    
    # nested resampling
    nest_rsm <- resample(task.mlr, at, resampling_outer)
    

    但在嵌套重采样的结果中,它显示num_nodes为 c(32,64,128,256) 。大致如下:

        num_nodes 
       c(32,64,128,256)
        num_nodes
       c(32,64,128,256)
       ... 
    

    如何将“Param_Uty”转换为2个级别(例如。 32,64 )?
    老实说,我已经仔细搜索了这个话题,毕竟我没有找到合适的答案,所以我很感激你的帮助。

    0 回复  |  直到 3 年前
        1
  •  2
  •   RaphaelS    3 年前

    您好,感谢您使用mlr3proba。实际上,我刚刚完成了一篇教程,正好回答了这个问题!它涵盖了在mlr3proba中训练、调整和评估神经网络。对于您的具体问题,本教程的相关部分如下:

    library(paradox)
    search_space = ps(
      nodes = p_int(lower = 1, upper = 32),
      k = p_int(lower = 1, upper = 4)
    )
    
    search_space$trafo = function(x, param_set) {
      x$num_nodes = rep(x$nodes, x$k)
      x$nodes = x$k = NULL
      return(x)
    }
    

    在这里,我调整了两个新的超参数,一个表示每层的节点数,另一个表示层数,然后我创建了一个转换,将它们组合成所需的超参数, num_nodes .

    您应该能够将此应用于您的示例。

    推荐文章