
How to subsample an xarray Dataset using groupby?

  • Olga Botvinnik · 7 years ago

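    I'd like to use groupby to select groups and then take a 10% sample within each group. I'm using the code below, but I get IndexError: index 1330 is out of bounds for axis 0 with size 1330, even though subset definitely has nonzero dimensions.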

    I was using squeeze=True (see the GroupBy documentation), but that didn't help, so I changed it to squeeze=False.

    Thanks very much!

    import numpy as np

    # Set random seed for reproducibility
    np.random.seed(0)
    
    def select_random_cell_subset(x):
        size = int(0.1 * len(x.cell))
        random_cells = sorted(np.random.choice(x.cell, size=size, replace=False))
        print('number of random cells:', len(random_cells))
        print('\tsome random cells:', random_cells[:5])
        subset = x.sel(cell=random_cells)
        print('subset:', subset)
        return subset
    
    # squeeze=False because the final dataset is smaller than the original
    ds_subset = ds.groupby('group', squeeze=True).apply(select_random_cell_subset)
    ds_subset
    

    Here is the error:

    ---------------------------------------------------------------------------
    IndexError                                Traceback (most recent call last)
    <ipython-input-44-39c7803e9e40> in <module>()
         12 
         13 # squeeze=False because the final dataset is smaller than the original
    ---> 14 ds_subset = ds.groupby('group', squeeze=True).apply(select_random_cell_subset)
         15 ds_subset
    
    ~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/groupby.py in apply(self, func, **kwargs)
        615         kwargs.pop('shortcut', None)  # ignore shortcut if set (for now)
        616         applied = (func(ds, **kwargs) for ds in self._iter_grouped())
    --> 617         return self._combine(applied)
        618 
        619     def _combine(self, applied):
    
    ~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/groupby.py in _combine(self, applied)
        622         coord, dim, positions = self._infer_concat_args(applied_example)
        623         combined = concat(applied, dim)
    --> 624         combined = _maybe_reorder(combined, dim, positions)
        625         if coord is not None:
        626             combined[coord.name] = coord
    
    ~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/groupby.py in _maybe_reorder(xarray_obj, dim, positions)
        443         return xarray_obj
        444     else:
    --> 445         return xarray_obj[{dim: order}]
        446 
        447 
    
    ~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/dataset.py in __getitem__(self, key)
        716         """
        717         if utils.is_dict_like(key):
    --> 718             return self.isel(**key)
        719 
        720         if hashable(key):
    
    ~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/dataset.py in isel(self, drop, **indexers)
       1141         for name, var in iteritems(self._variables):
       1142             var_indexers = dict((k, v) for k, v in indexers if k in var.dims)
    -> 1143             new_var = var.isel(**var_indexers)
       1144             if not (drop and name in var_indexers):
       1145                 variables[name] = new_var
    
    ~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/variable.py in isel(self, **indexers)
        568             if dim in indexers:
        569                 key[i] = indexers[dim]
    --> 570         return self[tuple(key)]
        571 
        572     def squeeze(self, dim=None):
    
    ~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/variable.py in __getitem__(self, key)
        398         dims = tuple(dim for k, dim in zip(key, self.dims)
        399                      if not isinstance(k, integer_types))
    --> 400         values = self._indexable_data[key]
        401         # orthogonal indexing should ensure the dimensionality is consistent
        402         if hasattr(values, 'ndim'):
    
    ~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/indexing.py in __getitem__(self, key)
        476     def __getitem__(self, key):
        477         key = self._convert_key(key)
    --> 478         return self._ensure_ndarray(self.array[key])
        479 
        480     def __setitem__(self, key, value):
    
    IndexError: index 1330 is out of bounds for axis 0 with size 1330
    
    2 answers
  •  shoyer · 7 years ago

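    This is a perfectly sensible thing to do, but unfortunately it doesn't work yet. Xarray uses some heuristics to decide whether an apply operation is a "reduce" or a "transform", and in this case we mistakenly identify the grouped operation as a "transform" because the output reuses the original dimension name. As a transform, the concatenated result is then reordered using the positions of every cell in the original dataset, which is where the out-of-bounds index raised from _maybe_reorder comes from. I just filed a bug report.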
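    The failure can be imitated in plain NumPy. This is a simplified sketch, assuming the reorder step seen in the traceback (_maybe_reorder) behaves like an inverse permutation built from the original group positions; the toy data are hypothetical, and inverse_permutation here is a local re-implementation for illustration, not an import from xarray.

```python
import numpy as np


def inverse_permutation(indices):
    """Return `inv` such that indices[inv] == arange(len(indices))."""
    inverse = np.empty(len(indices), dtype=np.intp)
    inverse[indices] = np.arange(len(indices), dtype=np.intp)
    return inverse


# Hypothetical toy data: 20 cells split into two interleaved groups.
labels = np.array([0, 1] * 10)
positions = [np.flatnonzero(labels == g) for g in np.unique(labels)]

# The applied function returns only ~10% of each group (here 1 cell per group),
# so the concatenated result is much shorter than the original dimension.
subsets = [pos[: max(1, len(pos) // 10)] for pos in positions]
combined = np.concatenate(subsets)  # 2 elements

# Treating the result as a "transform", the reorder step builds its order
# from the ORIGINAL group positions, which cover all 20 cells...
order = inverse_permutation(np.concatenate(positions))

# ...so applying that order to the 2-element result goes out of bounds.
failed = False
try:
    combined[order]
except IndexError as err:
    failed = True
    print("IndexError:", err)
```

    The order array is as long as the original dimension, while the combined result only has one entry per sampled cell, which matches the shape of the error in the question (an index equal to the size of the smaller axis).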

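    The simplest workaround is probably to have the applied function return a boolean DataArray indicating which positions to keep. You can then use it in an indexing operation to select from the original object.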
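    A minimal sketch of the same boolean-mask idea done entirely in NumPy, assuming the group labels can be pulled out as a 1-D array along the cell dimension. Unlike a per-cell probability threshold, it draws an exact number of cells per group (int(frac * n), as in the question's code); per_group_sample_mask is a hypothetical helper name, not an xarray API.

```python
import numpy as np


def per_group_sample_mask(groups, frac=0.1, seed=0):
    """Boolean mask keeping int(frac * n) randomly chosen items from each group."""
    groups = np.asarray(groups)
    rng = np.random.default_rng(seed)
    mask = np.zeros(groups.shape[0], dtype=bool)
    for g in np.unique(groups):
        idx = np.flatnonzero(groups == g)  # positions of this group's cells
        n_keep = int(frac * len(idx))
        chosen = rng.choice(idx, size=n_keep, replace=False)
        mask[chosen] = True
    return mask


# Hypothetical group labels for 1,000 cells in three groups.
labels = np.repeat(["a", "b", "c"], [500, 300, 200])
mask = per_group_sample_mask(labels, frac=0.1)
print(mask.sum())  # → 100
```

    With a Dataset like the question's ds, the mask could then be applied positionally, e.g. ds.isel(cell=per_group_sample_mask(ds['group'].values)), assuming group is a 1-D variable along cell. Because the mask is computed outside groupby, the reduce-versus-transform heuristic never comes into play.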

  •  Olga Botvinnik · 7 years ago

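    Here's how I implemented it. As @shoyer suggested above, I returned a boolean xarray.DataArray from the applied function and then used it to select cells from the original dataset: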

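    import numpy as np
    import xarray as xr

    # Set random seed for reproducibility
    np.random.seed(0)

    def select_random_cell_subset(x, threshold=0.1):
        # Keep each cell with probability `threshold`; return a boolean mask over
        # the group's full `cell` dimension instead of a smaller subset
        random_bools = xr.DataArray(np.random.uniform(size=len(x.cell)) <= threshold,
                                    dims='cell', coords=dict(cell=x.cell))
        return random_bools

    # The mask is as long as each group, so groupby handles it as a "transform"
    subset_bools = ds.groupby('group').apply(select_random_cell_subset,
                                             threshold=0.1)
    ds_subset = ds.sel(cell=subset_bools)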