We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent df36bce · commit 4227c9b · Copy full SHA for 4227c9b
ggml/src/ggml-cuda/fattn-common.cuh
@@ -539,11 +539,15 @@ static __global__ void flash_attn_mask_to_KV_max(
539
all_inf = warp_reduce_all(all_inf);
540
541
if (!all_inf) {
542
- KV_max_sj += FATTN_KQ_STRIDE;
543
break;
544
}
545
546
+ // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
547
+ // If the break was triggered it's the lower edge of the tile with the first non-masked values.
548
+ // In either case, walk back the decrementation by FATTN_KQ_STRIDE.
549
+ KV_max_sj += FATTN_KQ_STRIDE;
550
+
551
if (threadIdx.x != 0) {
552
return;
553
0 commit comments