这可能不是最正确或最好的解决方案。但这是我在挣扎了一段时间后想到的。如果其他人找到更好的解决方案,请张贴!这也可能与不同版本的金属过时。
我第一次试着用
_atomic<T>
它包含在我结构的金属语言中。这个
工作。经过一番努力,我终于检查了文档,并意识到模板目前被苹果限制为bool的,int的和uint的。
然后,我尝试使用原子int来“锁定”关键比较部分,但实际上没有成功地保护关键部分。我可能对这个实现做了一些错误的事情,并且可以看到它在工作。
但是
这样我就可以继续了。
kernel void compareX(const device Point2 *array [[ buffer(0) ]],
device atomic_int *result [[ buffer(1) ]],
uint id [[ thread_position_in_grid ]],
uint tid [[ thread_index_in_threadgroup ]],
uint bid [[ threadgroup_position_in_grid ]],
uint blockDim [[ threads_per_threadgroup ]]) {
threadgroup int shared_memory[THREADGROUP_SIZE];
uint i = bid * blockDim + tid;
shared_memory[tid] = i;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint s = 1; s < blockDim; s *= 2) {
if (tid % (2 * s) == 0) {
// aggregate the index to our smallest value in shared_memory
if ( array[shared_memory[tid + s]].x < array[shared_memory[tid]].x) {
shared_memory[tid] = shared_memory[tid + s];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
if (0 == tid ) {
// get the current index so we can test against that
int current = atomic_load_explicit(result, memory_order_relaxed);
if( array[shared_memory[0]].x < array[current].x) {
while(!atomic_compare_exchange_weak_explicit(result, ¤t, shared_memory[0], memory_order_relaxed, memory_order_relaxed)) {
// another thread won. Check if we still need to set it.
if (array[shared_memory[0]].x > array[current].x) {
// they won, and have a smaller value, ignore our best result
break;
}
}
}
}
}