Skip to content

Commit 4227c9b

Browse files
CUDA: fix negative KV_max values in FA (#15321)
1 parent df36bce commit 4227c9b

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -539,11 +539,15 @@ static __global__ void flash_attn_mask_to_KV_max(
539539
all_inf = warp_reduce_all(all_inf);
540540

541541
if (!all_inf) {
542-
KV_max_sj += FATTN_KQ_STRIDE;
543542
break;
544543
}
545544
}
546545

546+
// If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
547+
// If the break was triggered it's the lower edge of the tile with the first non-masked values.
548+
// In either case, walk back the decrementation by FATTN_KQ_STRIDE.
549+
KV_max_sj += FATTN_KQ_STRIDE;
550+
547551
if (threadIdx.x != 0) {
548552
return;
549553
}

0 commit comments

Comments
 (0)