代码之家  ›  专栏  ›  技术社区  ›  sachinruk

预测当前时间的生存概率

  •  0
  • sachinruk  · 技术社区  · 5 年前

    我用以下几句话训练我的生存模式:

    wft = WeibullAFTFitter()
    wft.fit(train, 'duration', event_col='y')
    

    在这之后,我想看看目前的生存概率( duration 列)。

    如果使用以下for循环,则当前执行此操作的方式:

    p_surv = np.zeros(len(test))
    for i in range(len(p_surv)):
        row = test.iloc[i:i+1].drop(dep_var, axis=1)
        t = test.iloc[i:i+1, col_num]
        p_surv[i] = wft.predict_survival_function(row, t).values[0][0]
    

    然而,考虑到Im使用for循环(200k+行),这确实很慢。另一种选择 wft.predict_survival_function(test, test['duration']) 将创建一个200000x200000矩阵,因为它根据所有提供的时间检查每一行。

    lifelines 是这样吗?

    0 回复  |  直到 5 年前
        1
  •  0
  •   sachinruk    5 年前

    好问题。我认为现在,最好的方法是重现预测生存函数所做的事情。也就是说,做这样的事情:

    def predict_cumulative_hazard_at_single_time(self, X, times, ancillary_X=None):
        lambda_, rho_ = self._prep_inputs_for_prediction_and_return_scores(X, ancillary_X)
        return (times / lambda_) ** rho_
    
    def predict_survival_function_at_single_time(self, X, times, ancillary_X=None):
        return np.exp(-self.predict_cumulative_hazard_at_single_time(X, times=times, ancillary_X=ancillary_X))
    
    
    wft.predict_survival_function_at_single_time = predict_survival_function_at_single_time.__get__(wft)
    wft.predict_cumulative_hazard_at_single_time = predict_cumulative_hazard_at_single_time.__get__(wft)
    
    p_surv2 = wft.predict_survival_function_at_single_time(test, test['duration'])