Here is a very simple computation in JAX that fails with a complaint about static indices:
import jax
import jax.numpy as jnp

def get_slice(ar, k, i):
    return ar[i:i+k]

vec_get_slice = jax.vmap(get_slice, in_axes=(None, None, 0))
arr = jnp.array([1, 2, 3, 4, 5])
vec_get_slice(arr, 2, jnp.arange(3))
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-32-6c60650ce6b7> in <cell line: 1>()
----> 1 vec_get_slice(arr, 2, jnp.arange(3))
[... skipping hidden 3 frame]
<ipython-input-29-9528369725c2> in get_slice(ar, k, i)
1 def get_slice(ar, k, i):
----> 2 return ar[i:i+k]
/usr/local/lib/python3.10/dist-packages/jax/_src/array.py in __getitem__(self, idx)
346 return out
347
--> 348 return lax_numpy._rewriting_take(self, idx)
349
350 def __iter__(self):
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
4602
4603 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 4604 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
4605 unique_indices, mode, fill_value)
4606
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
4611 unique_indices, mode, fill_value):
4612 idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 4613 indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
4614 y = arr
4615
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in _index_to_gather(x_shape, idx, normalize_indices)
4854 "dynamic_update_slice (JAX does not support dynamically sized "
4855 "arrays within JIT compiled functions).")
-> 4856 raise IndexError(msg)
4857
4858 start, step, slice_size = _preprocess_slice(i, x_shape[x_axis])
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
val = Array([0, 1, 2], dtype=int32)
batch_dim = 0, Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
val = Array([2, 3, 4], dtype=int32)
batch_dim = 0, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
Horrible error output above. I am obviously missing something simple, but what?
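The last line of the error message points at the fix. Under `vmap`, `i` is a traced value, and NumPy-style slicing in JAX needs `start`/`stop`/`step` known at trace time; with both `i` and `i+k` traced, JAX cannot tell that the slice length is always `k`. A minimal sketch, assuming the goal is a fixed-length window of size `k` starting at each `i`: keep `k` a plain Python int and let `jax.lax.dynamic_slice` take the traced start index. Closing over `k` (rather than passing it through `vmap`) is just a defensive choice that guarantees it stays static.

```python
import jax
import jax.numpy as jnp

def get_slice(ar, k, i):
    # k fixes the output shape, so it must be a concrete Python int at trace
    # time; the start index i may be a traced (batched) value.
    return jax.lax.dynamic_slice(ar, (i,), (k,))

arr = jnp.array([1, 2, 3, 4, 5])
k = 2

# Close over k so vmap never sees it as an argument; only arr (unmapped)
# and the start indices (mapped over axis 0) are passed through vmap.
vec_get_slice = jax.vmap(lambda ar, i: get_slice(ar, k, i), in_axes=(None, 0))

out = vec_get_slice(arr, jnp.arange(3))
print(out)  # rows: [1 2], [2 3], [3 4]
```

Because `dynamic_slice` always returns exactly `k` elements, a start index too close to the end of the array is shifted back so the window stays in bounds, rather than producing a shorter slice as NumPy would.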