
JAX complains about static start/stop/step

  • Igor Rivin  ·  10 months ago

    Here is a very simple computation in JAX that fails with a complaint about static indices:

    import jax
    import jax.numpy as jnp
    
    def get_slice(ar, k, i):
      return ar[i:i+k]
    
    vec_get_slice = jax.vmap(get_slice, in_axes=(None, None, 0))
    
    arr = jnp.array([1, 2, 3, 4, 5])
    
    vec_get_slice(arr, 2, jnp.arange(3))
    
    ---------------------------------------------------------------------------
    IndexError                                Traceback (most recent call last)
    <ipython-input-32-6c60650ce6b7> in <cell line: 1>()
    ----> 1 vec_get_slice(arr, 2, jnp.arange(3))
    
        [... skipping hidden 3 frame]
    
    4 frames
    <ipython-input-29-9528369725c2> in get_slice(ar, k, i)
          1 def get_slice(ar, k, i):
    ----> 2   return ar[i:i+k]
    
    /usr/local/lib/python3.10/dist-packages/jax/_src/array.py in __getitem__(self, idx)
        346           return out
        347 
    --> 348     return lax_numpy._rewriting_take(self, idx)
        349 
        350   def __iter__(self):
    
    /usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
       4602 
       4603   treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
    -> 4604   return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
       4605                  unique_indices, mode, fill_value)
       4606 
    
    /usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
       4611             unique_indices, mode, fill_value):
       4612   idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
    -> 4613   indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
       4614   y = arr
       4615 
    
    /usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in _index_to_gather(x_shape, idx, normalize_indices)
       4854                "dynamic_update_slice (JAX does not support dynamically sized "
       4855                "arrays within JIT compiled functions).")
    -> 4856         raise IndexError(msg)
       4857 
       4858       start, step, slice_size = _preprocess_slice(i, x_shape[x_axis])
    
    Horrible error output below. I am obviously missing something simple, but what?
    
    
    IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
      val = Array([0, 1, 2], dtype=int32)
      batch_dim = 0, Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
      val = Array([2, 3, 4], dtype=int32)
      batch_dim = 0, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
    
    1 answer
  • jakevdp  ·  10 months ago

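    Indices passed to a slice in JAX must be static. Values that are mapped over in vmap are not static: because you are mapping over the start index, your slice indices are not static, and you see this error.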

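    There is good news, though: the size of the subarray is controlled by k, which is unmapped in your code and therefore static; only the location of the slice (given by i) is dynamic. This is exactly the situation jax.lax.dynamic_slice was designed for, so you can rewrite your code like this:

    import jax
    import jax.numpy as jnp
    
    def get_slice(ar, k, i):
      return jax.lax.dynamic_slice(ar, (i,), (k,))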
    
    vec_get_slice = jax.vmap(get_slice, in_axes=(None, None, 0))
    
    arr = jnp.array([1, 2, 3, 4, 5])
    
    vec_get_slice(arr, 2, jnp.arange(3))
    # Array([[1, 2],
    #        [2, 3],
    #        [3, 4]], dtype=int32)
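
Two related details may be worth keeping in mind. The sketch below is an illustration and not part of the original answer; the wrapper name `windows` is hypothetical. It assumes you also want to wrap the vmapped function in jax.jit, in which case k would be traced by default and has to be marked static (here with static_argnums). It also shows that dynamic_slice clamps a start index that would run past the end of the array, rather than raising an error.

```python
import functools

import jax
import jax.numpy as jnp


def get_slice(ar, k, i):
    # k fixes the output shape, so it must be a static Python int;
    # i is allowed to be a traced (dynamic) value.
    return jax.lax.dynamic_slice(ar, (i,), (k,))


# Hypothetical wrapper: under jit all arguments are traced by default,
# so argument 1 (k) is declared static.
@functools.partial(jax.jit, static_argnums=1)
def windows(ar, k, starts):
    return jax.vmap(get_slice, in_axes=(None, None, 0))(ar, k, starts)


arr = jnp.array([1, 2, 3, 4, 5])

# One length-2 window for each start index 0, 1, 2.
out = windows(arr, 2, jnp.arange(3))

# A start of 4 with size 2 would overrun the array, so the start is
# clamped to 3 and the last two elements are returned.
clamped = jax.lax.dynamic_slice(arr, (4,), (2,))
```

Because of the clamping, an out-of-range start silently yields the last valid window, so it may be worth validating the start indices yourself if that is not the behaviour you want.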
    