
Extending xgboost.XGBClassifier

  • andrew · Tech Community · 7 years ago

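    I have a class XGBExtended that extends xgboost.XGBClassifier, the scikit-learn API of xgboost. I am running into problems with its get_params method: it seems to return only the attributes I define in XGBExtended.__init__, while the attributes defined in the parent's init method (xgboost.XGBClassifier.__init__) are ignored.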

    In [182]: import xgboost as xgb
         ...: 
         ...: class XGBExtended(xgb.XGBClassifier):
         ...:   def __init__(self, foo):
         ...:     super(XGBExtended, self).__init__()
         ...:     self.foo = foo
         ...: 
         ...: clf = XGBExtended(foo = 1)
         ...: 
         ...: clf.get_params()
         ...: 
    ---------------------------------------------------------------------------
    KeyError                                  Traceback (most recent call last)
    <ipython-input-182-431c4c3f334b> in <module>()
          8 clf = XGBExtended(foo = 1)
          9 
    ---> 10 clf.get_params()
    
    /Users/andrewhannigan/lib/xgboost/python-package/xgboost/sklearn.pyc in get_params(self, deep)
        188         if isinstance(self.kwargs, dict):  # if kwargs is a dict, update params accordingly
        189             params.update(self.kwargs)
    --> 190         if params['missing'] is np.nan:
        191             params['missing'] = None  # sklearn doesn't handle nan. see #4725
        192         if not params.get('eval_metric', True):
    
    KeyError: 'missing'
    

    So I get an error because 'missing' is not a key of the params dict inside XGBClassifier.get_params:

    In [183]: %debug
    > /Users/andrewhannigan/lib/xgboost/python-package/xgboost/sklearn.py(190)get_params()
        188         if isinstance(self.kwargs, dict):  # if kwargs is a dict, update params accordingly
        189             params.update(self.kwargs)
    --> 190         if params['missing'] is np.nan:
        191             params['missing'] = None  # sklearn doesn't handle nan. see #4725
        192         if not params.get('eval_metric', True):
    
    ipdb> params
    {'foo': 1}
    ipdb> self.__dict__
    {'n_jobs': 1, 'seed': None, 'silent': True, 'missing': nan, 'nthread': None, 'min_child_weight': 1, 'random_state': 0, 'kwargs': {}, 'objective': 'binary:logistic', 'foo': 1, 'max_depth': 3, 'reg_alpha': 0, 'colsample_bylevel': 1, 'scale_pos_weight': 1, '_Booster': None, 'learning_rate': 0.1, 'max_delta_step': 0, 'base_score': 0.5, 'n_estimators': 100, 'booster': 'gbtree', 'colsample_bytree': 1, 'subsample': 1, 'reg_lambda': 1, 'gamma': 0}
    ipdb> 
    

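    As you can see, params contains only foo, even though self.__dict__ holds all the attributes set by the parent constructor. For some reason, BaseEstimator.get_params, which is called from xgboost.XGBClassifier.get_params, only picks up foo. Unfortunately, even when I explicitly call get_params with deep=True, it still doesn't work: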

    ipdb> super(XGBModel, self).get_params(deep=True)
    {'foo': 1}
    ipdb> 
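As a side note, deep=True cannot help in this situation: in scikit-learn it only recurses into parameters that are themselves estimators (reporting their settings as name__subparam); it never adds attributes that are missing from the constructor signature. A minimal sketch with two hypothetical estimators, Inner and Outer, assuming scikit-learn is installed:

```python
from sklearn.base import BaseEstimator


class Inner(BaseEstimator):
    # Hypothetical estimator with a single hyperparameter.
    def __init__(self, a=1):
        self.a = a


class Outer(BaseEstimator):
    # Hypothetical estimator whose only hyperparameter is another estimator.
    def __init__(self, inner=None):
        self.inner = inner


est = Outer(inner=Inner(a=5))
# An attribute that is not a constructor argument: get_params never reports it.
est.extra = "not a constructor argument"

print(sorted(est.get_params(deep=False)))  # ['inner']
print(sorted(est.get_params(deep=True)))   # ['inner', 'inner__a']
```

deep=True only adds the nested inner__a entry; extra is absent in both cases, just as the parent's attributes are absent from XGBExtended's params.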
    

    System specs:

    In [186]: print IPython.sys_info()
    {'commit_hash': u'1149d1700',
     'commit_source': 'installation',
     'default_encoding': 'UTF-8',
     'ipython_path': '/Users/andrewhannigan/virtualenvironment/nimble_ai/lib/python2.7/site-packages/IPython',
     'ipython_version': '5.4.1',
     'os_name': 'posix',
     'platform': 'Darwin-14.5.0-x86_64-i386-64bit',
     'sys_executable': '/usr/local/Cellar/python/2.7.10/Frameworks/Python.framework/Versions/2.7/Resources/Python.app/Contents/MacOS/Python',
     'sys_platform': 'darwin',
     'sys_version': '2.7.10 (default, Jul  3 2015, 12:05:53) \n[GCC 4.2.1 Compatible Apple LLVM 6.1.0 (clang-602.0.53)]'}
    
  • Vivek Kumar · 7 years ago

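    The problem here is that the subclass is declared incorrectly. When you declare __init__ with only foo, you override the parent's constructor signature. scikit-learn's BaseEstimator.get_params does not read self.__dict__; it collects parameter names by inspecting the signature of the class's __init__, so the only name it finds is foo. The parent's attributes are still set on the instance, but get_params never looks at them, and XGBClassifier.get_params then fails when it accesses params['missing'].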
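The mechanism can be illustrated with a simplified, stdlib-only sketch of the idea behind scikit-learn's parameter discovery. The helper param_names and the classes below are hypothetical; the real BaseEstimator._get_param_names additionally handles deprecated constructors and rejects *args:

```python
import inspect


def param_names(cls):
    """Simplified sketch: list the hyperparameter names of cls by reading the
    signature of cls.__init__, skipping `self` and any **kwargs."""
    init = cls.__init__
    if init is object.__init__:
        # No explicit constructor means no hyperparameters.
        return []
    sig = inspect.signature(init)
    return sorted(
        name
        for name, p in sig.parameters.items()
        if name != "self" and p.kind != p.VAR_KEYWORD
    )


class Parent:
    # Hypothetical stand-in for XGBClassifier.
    def __init__(self, max_depth=3, learning_rate=0.1):
        self.max_depth = max_depth
        self.learning_rate = learning_rate


class BrokenChild(Parent):
    # Same pattern as XGBExtended in the question: the new signature
    # hides the parent's parameters.
    def __init__(self, foo):
        super(BrokenChild, self).__init__()
        self.foo = foo


print(param_names(Parent))       # ['learning_rate', 'max_depth']
print(param_names(BrokenChild))  # ['foo']
```

Even though a BrokenChild instance carries max_depth and learning_rate as attributes, the signature only advertises foo.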

    You should use the following instead:

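    import xgboost as xgb

    class XGBExtended(xgb.XGBClassifier):
        def __init__(self, foo, max_depth=3, learning_rate=0.1,
                     n_estimators=100, silent=True,
                     objective="binary:logistic",
                     nthread=-1, gamma=0, min_child_weight=1,
                     max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
                     reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
                     base_score=0.5, seed=0, missing=None, **kwargs):

            # Pass the parent's parameters by keyword, not by position: the
            # positional order of XGBClassifier.__init__ is not stable across
            # xgboost versions (the version in the question also has booster,
            # n_jobs and random_state), so positional passing can silently
            # assign values to the wrong parameters.
            super(XGBExtended, self).__init__(
                max_depth=max_depth, learning_rate=learning_rate,
                n_estimators=n_estimators, silent=silent, objective=objective,
                nthread=nthread, gamma=gamma, min_child_weight=min_child_weight,
                max_delta_step=max_delta_step, subsample=subsample,
                colsample_bytree=colsample_bytree,
                colsample_bylevel=colsample_bylevel,
                reg_alpha=reg_alpha, reg_lambda=reg_lambda,
                scale_pos_weight=scale_pos_weight, base_score=base_score,
                seed=seed, missing=missing, **kwargs)

            # Store the custom parameter under the same name as the
            # constructor argument so that get_params() can read it back.
            self.foo = foo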
    

    After this, you will not get any error:

    clf = XGBExtended(foo=1)
    print(clf.get_params(deep=True))
    
    # Output:
    {'reg_alpha': 0, 'colsample_bytree': 1, 'silent': True, 
     'colsample_bylevel': 1, 'scale_pos_weight': 1, 'learning_rate': 0.1, 
     'missing': None, 'max_delta_step': 0, 'nthread': -1, 'base_score': 0.5, 
     'n_estimators': 100, 'subsample': 1, 'reg_lambda': 1, 'seed': 0, 
     'min_child_weight': 1, 'objective': 'binary:logistic', 
     'foo': 1, 'max_depth': 3, 'gamma': 0}
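The same pattern can be checked without xgboost. The sketch below uses a hypothetical Parent as a stand-in for XGBClassifier and assumes scikit-learn is installed; it shows that once the subclass re-declares the parent's parameters and forwards them by keyword, set_params and sklearn.base.clone (which rebuilds an estimator from get_params) both work:

```python
from sklearn.base import BaseEstimator, clone


class Parent(BaseEstimator):
    # Hypothetical stand-in for xgboost.XGBClassifier.
    def __init__(self, max_depth=3, learning_rate=0.1):
        self.max_depth = max_depth
        self.learning_rate = learning_rate


class Child(Parent):
    # Re-declare every parent hyperparameter that scikit-learn should see,
    # and forward them by keyword rather than by position.
    def __init__(self, foo, max_depth=3, learning_rate=0.1):
        super(Child, self).__init__(max_depth=max_depth,
                                    learning_rate=learning_rate)
        self.foo = foo


clf = Child(foo=1).set_params(max_depth=5)
print(clf.get_params())         # {'foo': 1, 'learning_rate': 0.1, 'max_depth': 5}
print(clone(clf).get_params())  # same dictionary, on a new object
```

Any parent parameter left out of Child's signature would still be set on the instance, but it would not appear in get_params and clone would reset it to the parent's default.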