
Calling single_prefill_with_kv_cache results in a large number of zeros #1489

@QZH-eng

Description

Hi,
When I call single_prefill_with_kv_cache, the output contains a large number of zeros. I compared it against an eager (naive) implementation, and the same positions are non-zero there. My test code is below, and the test data is at the link below. My flashinfer version is 0.2.7.post1.
https://drive.google.com/file/d/1kl0rP8js2IfqFMh_qlLtO7_k3EJDaFvP/view?usp=drive_link

Reproduction Steps

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys, importlib
import os
import torch
import flashinfer
import numpy as np
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

if getattr(flashinfer, "__file__", None) is None:
    # An empty package from a same-named directory was imported; fall back to the sibling source tree
    sys.modules.pop("flashinfer", None)
    pkg_dir = os.path.join(os.path.dirname(__file__), "..", "flashinfer")
    sys.path.insert(0, os.path.abspath(pkg_dir))
    flashinfer = importlib.import_module("flashinfer")

print("flashinfer file:", getattr(flashinfer, "__file__", None))

def flashinfer_multihead_attention(q, k, v, mask):
    o = flashinfer.single_prefill_with_kv_cache(q, k, v, pos_encoding_mode="NONE", causal=True, custom_mask=mask, use_fp16_qk_reduction=False)
    return o

def naive_multihead_attention(q, k, v, mask):
    q_len, kv_len = q.size(0), k.size(0)
    num_q_heads, num_kv_heads = q.size(1), k.size(1)
    head_dim = q.size(2)

    q = q.reshape(q_len, num_q_heads, head_dim).transpose(0, 1)  # (num_q_heads, seq_len, head_dim)
    k = k.reshape(kv_len, num_kv_heads, head_dim).transpose(0, 1)  # (num_kv_heads, seq_len, head_dim)
    v = v.reshape(kv_len, num_kv_heads, head_dim).transpose(0, 1)  # (num_kv_heads, seq_len, head_dim)

    # Support GQA: repeat KV heads to match Q heads
    if num_q_heads != num_kv_heads:
        assert num_q_heads % num_kv_heads == 0, f"num_q_heads ({num_q_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
        repeat_factor = num_q_heads // num_kv_heads
        # Repeat each KV head repeat_factor times
        k = k.repeat_interleave(repeat_factor, dim=0)  # (num_q_heads, seq_len, head_dim)
        v = v.repeat_interleave(repeat_factor, dim=0)  # (num_q_heads, seq_len, head_dim)

    # Compute attention scores
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5)
    attn_scores = attn_scores.masked_fill(mask.unsqueeze(0) == False, float('-inf'))
    attn_weights = torch.nn.functional.softmax(attn_scores.to(torch.float32), dim=-1).to(q.dtype)

    # Apply attention weights to values
    attn_output = torch.matmul(attn_weights, v).transpose(0, 1)  # (seq_len, num_q_heads, head_dim)

    return attn_output

if __name__ == "__main__":

    q_len = 1275
    kv_len = 1275
    num_q_heads = 28
    num_kv_heads = 4
    head_dim = 128
    max_pos = 4096
    num = 27

    data = np.fromfile(f"layers.27.SelfAttention.qkv.out.out.bin", dtype=np.float32).reshape([1275, 4608])[:, : 3584].reshape(q_len, num_q_heads, head_dim)

    q = torch.from_numpy(data).to(torch.bfloat16).cuda()
    data = np.fromfile(f"layers.27.SelfAttention.qkv.out.out.bin", dtype=np.float32).reshape([1275, 4608])[:, 3584 : 4096].reshape(kv_len, num_kv_heads, head_dim)

    k = torch.from_numpy(data).to(torch.bfloat16).cuda()
    data = np.fromfile(f"layers.27.SelfAttention.qkv.out.out.bin", dtype=np.float32).reshape([1275, 4608])[:, 4096 : 4608].reshape(kv_len, num_kv_heads, head_dim)

    v = torch.from_numpy(data).to(torch.bfloat16).cuda()
    q_clone = q.clone()
    k_clone = k.clone()
    v_clone = v.clone()

    mask = torch.tril(torch.full((q_len, kv_len), True), diagonal=kv_len - q_len).cuda()

    o = flashinfer_multihead_attention(q, k, v, mask)
    print("o.shape = ", o.shape)
    data_c_order = np.ascontiguousarray(o.to(torch.float32).cpu().numpy())
    # np.savetxt("data/attn_out_flashinfer.txt", data_c_order.astype(np.float32).reshape(-1, 1), fmt="%.6f")


    o = naive_multihead_attention(q_clone, k_clone, v_clone, mask)
    o = o.reshape(q_len, num_q_heads * head_dim)
    print(o.shape)
    # np.savetxt("data/navie_out_bf16.txt", o.cpu().to(torch.float32).numpy().astype(np.float32).reshape(-1, 1), fmt="%.6f")

Expected Behavior

The flashinfer inference result and the eager inference result should agree within the expected numerical accuracy for bf16.
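
One concrete way to phrase "within accuracy error" (a sketch, reusing the hypothetical float32 copies o_fi32 and o_ref32 from the check above):

# loose bf16-level tolerance; raises with a mismatch summary if the results diverge beyond it
torch.testing.assert_close(o_fi32, o_ref32, rtol=2e-2, atol=2e-2)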

Actual Behavior

The flashinfer inference result contains a large number of zero values, while the eager inference result is non-zero at the same positions.
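
As an independent reference, the eager result can also be cross-checked against PyTorch's built-in scaled_dot_product_attention (a sketch, not part of the original script; the helper name sdpa_reference is hypothetical), which accepts the same boolean mask once the KV heads are repeated for GQA:

def sdpa_reference(q, k, v, mask):
    # q: (q_len, num_q_heads, head_dim); k, v: (kv_len, num_kv_heads, head_dim)
    # mask: (q_len, kv_len) bool, True = attend; returns (q_len, num_q_heads, head_dim)
    num_q_heads, num_kv_heads = q.size(1), k.size(1)
    repeat = num_q_heads // num_kv_heads
    q = q.transpose(0, 1)                                    # (num_q_heads, q_len, head_dim)
    k = k.transpose(0, 1).repeat_interleave(repeat, dim=0)   # (num_q_heads, kv_len, head_dim)
    v = v.transpose(0, 1).repeat_interleave(repeat, dim=0)
    o = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return o.transpose(0, 1)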

Environment

  • flashinfer version: 0.2.7.post1
  • PyTorch version: 2.7.1
  • CUDA version: 12.6
  • GPU: NVIDIA H800
