Description
Hi,
When I call single_prefill_with_kv_cache, a large number of zeros appears in the output. Comparing it against an eager (naive) implementation, the same positions are not zero. My test code is below, and the test data is at the link below. My flashinfer version is 0.2.7.post1.
https://drive.google.com/file/d/1kl0rP8js2IfqFMh_qlLtO7_k3EJDaFvP/view?usp=drive_link
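For anyone without the data file from the link above, a shapes-only stand-in like the following exercises the same API call with random bfloat16 inputs (this is an added sketch, not the original data; whether the zeros reproduce with random inputs is not verified):

```python
import torch
import flashinfer

q_len, kv_len = 1275, 1275
num_q_heads, num_kv_heads, head_dim = 28, 4, 128

# Random bf16 tensors with the same shapes as the dumped activations.
q = torch.randn(q_len, num_q_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")

# Lower-triangular boolean mask, same layout as in the full script below.
mask = torch.tril(torch.full((q_len, kv_len), True), diagonal=kv_len - q_len).cuda()

o = flashinfer.single_prefill_with_kv_cache(
    q, k, v,
    pos_encoding_mode="NONE",
    causal=True,
    custom_mask=mask,
    use_fp16_qk_reduction=False,
)
print(o.shape)  # expected (1275, 28, 128)
```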
Reproduction Steps
```python
# -*- coding: utf-8 -*-
#!/bin/env python3
import importlib
import os
import sys

import numpy as np
import torch
import flashinfer

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

if getattr(flashinfer, "__file__", None) is None:
    # An empty package with the same name was imported; fall back to the local
    # flashinfer checkout next to this script instead.
    sys.modules.pop("flashinfer", None)
    pkg_dir = os.path.join(os.path.dirname(__file__), "..", "flashinfer")
    sys.path.insert(0, os.path.abspath(pkg_dir))
    flashinfer = importlib.import_module("flashinfer")

print("flashinfer file:", getattr(flashinfer, "__file__", None))


def flashinfer_multihead_attention(q, k, v, mask):
    o = flashinfer.single_prefill_with_kv_cache(
        q, k, v,
        pos_encoding_mode="NONE",
        causal=True,
        custom_mask=mask,
        use_fp16_qk_reduction=False,
    )
    return o


def naive_multihead_attention(q, k, v, mask):
    q_len, kv_len = q.size(0), k.size(0)
    num_q_heads, num_kv_heads = q.size(1), k.size(1)
    head_dim = q.size(2)
    q = q.reshape(q_len, num_q_heads, head_dim).transpose(0, 1)    # (num_q_heads, q_len, head_dim)
    k = k.reshape(kv_len, num_kv_heads, head_dim).transpose(0, 1)  # (num_kv_heads, kv_len, head_dim)
    v = v.reshape(kv_len, num_kv_heads, head_dim).transpose(0, 1)  # (num_kv_heads, kv_len, head_dim)
    # Support GQA: repeat KV heads to match Q heads
    if num_q_heads != num_kv_heads:
        assert num_q_heads % num_kv_heads == 0, (
            f"num_q_heads ({num_q_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
        )
        repeat_factor = num_q_heads // num_kv_heads
        # Repeat each KV head repeat_factor times
        k = k.repeat_interleave(repeat_factor, dim=0)  # (num_q_heads, kv_len, head_dim)
        v = v.repeat_interleave(repeat_factor, dim=0)  # (num_q_heads, kv_len, head_dim)
    # Compute attention scores
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5)
    attn_scores = attn_scores.masked_fill(mask.unsqueeze(0) == False, float("-inf"))
    attn_weights = torch.nn.functional.softmax(attn_scores.to(torch.float32), dim=-1).to(q.dtype)
    # Apply attention weights to values
    attn_output = torch.matmul(attn_weights, v).transpose(0, 1)  # (q_len, num_q_heads, head_dim)
    return attn_output


if __name__ == "__main__":
    q_len = 1275
    kv_len = 1275
    num_q_heads = 28
    num_kv_heads = 4
    head_dim = 128
    max_pos = 4096
    num = 27

    data = np.fromfile("layers.27.SelfAttention.qkv.out.out.bin", dtype=np.float32).reshape([1275, 4608])[:, :3584].reshape(q_len, num_q_heads, head_dim)
    q = torch.from_numpy(data).to(torch.bfloat16).cuda()
    data = np.fromfile("layers.27.SelfAttention.qkv.out.out.bin", dtype=np.float32).reshape([1275, 4608])[:, 3584:4096].reshape(kv_len, num_kv_heads, head_dim)
    k = torch.from_numpy(data).to(torch.bfloat16).cuda()
    data = np.fromfile("layers.27.SelfAttention.qkv.out.out.bin", dtype=np.float32).reshape([1275, 4608])[:, 4096:4608].reshape(kv_len, num_kv_heads, head_dim)
    v = torch.from_numpy(data).to(torch.bfloat16).cuda()

    q_clone = q.clone()
    k_clone = k.clone()
    v_clone = v.clone()

    mask = torch.tril(torch.full((q_len, kv_len), True), diagonal=kv_len - q_len).cuda()

    o = flashinfer_multihead_attention(q, k, v, mask)
    print("o.shape = ", o.shape)
    data_c_order = np.ascontiguousarray(o.to(torch.float32).cpu().numpy())
    # np.savetxt("data/attn_out_flashinfer.txt", data_c_order.astype(np.float32).reshape(-1, 1), fmt="%.6f")

    o = naive_multihead_attention(q_clone, k_clone, v_clone, mask)
    o = o.reshape(q_len, num_q_heads * head_dim)
    print(o.shape)
    # np.savetxt("data/naive_out_bf16.txt", o.cpu().to(torch.float32).numpy().astype(np.float32).reshape(-1, 1), fmt="%.6f")
```
Expected Behavior
The flashinfer output and the eager output should agree within the expected numerical accuracy for bfloat16.
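One way to express this acceptance criterion (the tolerances below are an assumption for bfloat16, not values documented by FlashInfer, and o_fi / o_ref are the hypothetical names from the sketch above):

```python
# Compare in fp32 to avoid further rounding; tolerances are assumed bf16-level values.
torch.testing.assert_close(
    o_fi.to(torch.float32),   # flashinfer output (hypothetical name)
    o_ref.to(torch.float32),  # eager reference output (hypothetical name)
    rtol=1.6e-2,
    atol=1e-2,
)
```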
Actual Behavior
The flashinfer output contains a large number of zero values; the eager output is non-zero at the same positions.
Environment
- flashinfer version: 0.2.7.post1
- PyTorch version: 2.7.1
- CUDA version: 12.6
- GPU: NVIDIA H800