
Commit 4afe692

Aya-ZIbra authored and facebook-github-bot committed
avoid propagation of NaN
Summary:
X-link: facebookresearch/FBGEMM#806

As title: introduce padding in the dequantization kernel to avoid passing NaNs to the output of FA3 in the prefill stage.

Reviewed By: jianyuh

Differential Revision: D69522001
1 parent 3fed238 commit 4afe692

File tree

1 file changed: +76 −54 lines changed


fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu

Lines changed: 76 additions & 54 deletions
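The core of the change is in the dequantization kernel: the per-batch valid length max_t is rounded up to the next multiple of 128, and the rows between the valid length and that boundary are zero-filled, so the FA3 prefill kernel consuming cache_K_dq/cache_V_dq never reads uninitialized bf16 values (which can decode as NaN). Below is a minimal host-side sketch of just that rounding arithmetic. The constant 128 is taken from the kernel's (max_t + 127) / 128 * 128 expression; the interpretation that it matches the attention kernel's tile granularity is an assumption, not something the diff states.

#include <cstdint>
#include <cstdio>

// Mirrors the kernel's rounding: max_t = (max_t + 127) / 128 * 128;
int64_t pad_to_multiple_of_128(int64_t max_t) {
  return (max_t + 127) / 128 * 128;
}

int main() {
  int64_t max_t = 1000;                           // valid tokens for one batch entry
  int64_t padded = pad_to_multiple_of_128(max_t); // 1024
  // Rows [1000, 1024) are the padding region the kernel zero-fills instead of
  // leaving whatever bytes at::empty() handed back.
  std::printf("valid=%lld padded=%lld zeroed=%lld\n",
              (long long)max_t, (long long)padded, (long long)(padded - max_t));
}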
@@ -1692,6 +1692,7 @@ at::Tensor xpos_qkv_decoding(
 
 #if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \
     (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
+template <bool ExternalQParam>
 __global__ void dequantize_fp8_cache_kernel(
     // This code currently represents FP8 version not int4
     at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
@@ -1711,60 +1712,69 @@ __global__ void dequantize_fp8_cache_kernel(
   auto D_H_q = cache_K.size(3);
   // TODO: support D_H < 128 for small model used in testing.
   CUDA_KERNEL_ASSERT(D_H == 128);
+  CUDA_KERNEL_ASSERT(4 * threadIdx.x < D_H)
+  const uint8_t offset_bytes = (ExternalQParam) ? 0 : 4;
+  CUDA_KERNEL_ASSERT(D_H_q - D_H == 0);
 
   auto b = blockIdx.x;
   // only need to dequantize this far.
   auto max_t = kv_seqlen[b];
 
   // one warp per T/H
-  for (auto t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH;
+  int h = 0, t = 0;
+  uint8_t *row_k{}, *row_v{};
+  c10::BFloat16 *row_k_dq{}, *row_v_dq{};
+  uint64_t packed{};
+  bfx8 kv_dq;
+  long t_h{};
+  for (t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH;
        t_h += blockDim.y * gridDim.y) {
-    auto h = t_h % N_KVH;
-    auto t = t_h / N_KVH;
-
-    auto* row_k = &cache_K[b][t][h][0]; // uint8_t*
-    auto* row_v = &cache_V[b][t][h][0];
-    bfx8 kv_dq;
-    uint8_t qparam_offset_bytes;
-    __half2* qparam_k_src;
-    __half2* qparam_v_src;
-    if (qparam_k_ptr) {
-      // read from standalone qparam tensor
-      qparam_offset_bytes = 0;
-      auto idx = b * (MAX_T * N_KVH) + t * N_KVH + h;
-      qparam_k_src = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
-      qparam_v_src = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
-    } else {
-      // read from first row
-      qparam_offset_bytes = 4;
-      qparam_k_src = reinterpret_cast<__half2*>(&row_k[0]);
-      qparam_v_src = reinterpret_cast<__half2*>(&row_v[0]);
-    }
-    // Assert the quantized row dim is as expected
-    CUDA_KERNEL_ASSERT(D_H_q - D_H == qparam_offset_bytes);
-    if (4 * threadIdx.x >= D_H) {
-      continue;
-    }
-    // each thread reads 4 x 8 bits
-
-    uint64_t kq = *reinterpret_cast<uint32_t*>(
-        &row_k[threadIdx.x * 4 + qparam_offset_bytes]);
-    uint64_t vq = *reinterpret_cast<uint32_t*>(
-        &row_v[threadIdx.x * 4 + qparam_offset_bytes]);
-
-    uint64_t packed = kq | (vq << 32);
+    h = t_h % N_KVH;
+    t = t_h / N_KVH;
+
+    row_k = &cache_K[b][t][h][0];
+    row_v = &cache_V[b][t][h][0];
+    row_k_dq = &cache_K_dq[b][t][h][0];
+    row_v_dq = &cache_V_dq[b][t][h][0];
+    // Calculate kv_dq for this row
+    {
+      __half2* qparam_k_src;
+      __half2* qparam_v_src;
+      if (ExternalQParam) {
+        size_t idx = b * (MAX_T * N_KVH) + t * N_KVH + h;
+        qparam_k_src = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
+        qparam_v_src = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
+      } else {
+        qparam_k_src = reinterpret_cast<__half2*>(&row_k[0]);
+        qparam_v_src = reinterpret_cast<__half2*>(&row_v[0]);
+      }
+      uint64_t kq =
+          *reinterpret_cast<uint32_t*>(&row_k[threadIdx.x * 4 + offset_bytes]);
+      uint64_t vq =
+          *reinterpret_cast<uint32_t*>(&row_v[threadIdx.x * 4 + offset_bytes]);
 
-    kv_dq = dequantize_packed_fp8(packed, *qparam_k_src, *qparam_v_src);
+      packed = kq | (vq << 32);
 
+      kv_dq = dequantize_packed_fp8(packed, *qparam_k_src, *qparam_v_src);
+    }
     // now, write our outputs
-    auto* row_k_dq = &cache_K_dq[b][t][h][0];
-    auto* row_v_dq = &cache_V_dq[b][t][h][0];
     // each thread writes 4 elements of type bf16
     *reinterpret_cast<uint2*>(&row_k_dq[4 * threadIdx.x]) =
         *reinterpret_cast<uint2*>(&kv_dq.vals[0]);
     *reinterpret_cast<uint2*>(&row_v_dq[4 * threadIdx.x]) =
         *reinterpret_cast<uint2*>(&kv_dq.vals[2]);
   }
+
+  max_t = (max_t + 127) / 128 * 128;
+  for (; t_h < max_t * N_KVH; t_h += blockDim.y * gridDim.y) {
+    h = t_h % N_KVH;
+    t = t_h / N_KVH;
+    row_k_dq = &cache_K_dq[b][t][h][0];
+    row_v_dq = &cache_V_dq[b][t][h][0];
+
+    memset(&row_k_dq[4 * threadIdx.x], 0, sizeof(uint2));
+    memset(&row_v_dq[4 * threadIdx.x], 0, sizeof(uint2));
+  }
 }
 
 // Cloned from dequantize_fp8_cache_kernel because
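In the rewritten kernel above, the loop index t_h and the row pointers are declared before the main grid-stride loop so that a second loop can resume from wherever the first one stopped and memset the remaining rows up to the padded boundary. The following is a simplified sketch of that two-loop structure; the kernel name, the single-output layout, and the assumption that one 32-thread warp covers a 128-element bf16 row are illustrative only, not the real kernel's signature.

#include <cuda_bf16.h>

// Sketch: first loop writes valid rows, second loop zero-fills the padded tail.
__global__ void dequant_with_padding_sketch(
    __nv_bfloat16* out, // [padded_t][128] row-major, 128 bf16 values per row
    int valid_t) {
  constexpr int D_H = 128;
  int t; // declared outside the loop so the tail loop resumes where this one ends
  for (t = threadIdx.y + blockIdx.y * blockDim.y; t < valid_t;
       t += blockDim.y * gridDim.y) {
    // ... dequantize and write the 4 bf16 values this thread owns (elided) ...
  }
  int padded_t = (valid_t + 127) / 128 * 128;
  for (; t < padded_t; t += blockDim.y * gridDim.y) {
    // Each of the 32 threads in the warp zeroes its 4 bf16 elements (8 bytes),
    // so the padded rows contain zeros rather than arbitrary (possibly NaN) bits.
    *reinterpret_cast<uint2*>(&out[t * D_H + 4 * threadIdx.x]) = make_uint2(0, 0);
  }
}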
@@ -1902,10 +1912,9 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
   // matching shape with the original paged KV and feed the same buffer
   // into this function at every layer to reuse it and prevent allocation.
 
-  // FIXME: T213958042
-  auto cache_K_dq = at::zeros(
+  auto cache_K_dq = at::empty(
       {B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
-  auto cache_V_dq = at::zeros(
+  auto cache_V_dq = at::empty(
       {B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
 
   if (B == 0) {
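The allocation change above pairs with the kernel padding: cache_K_dq and cache_V_dq were previously created with at::zeros (the FIXME workaround that cleared the whole buffer on every call) and are now created with at::empty, so the only data the consumer can see is what the kernel explicitly wrote or zero-filled. A small illustration of the difference, with hypothetical shapes:

#include <ATen/ATen.h>

int main() {
  // Hypothetical shapes: [B_KV, MAX_T, N_KVH, D_H].
  auto opts = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
  // Before: full-tensor clear on every call, even for short sequences.
  auto zeroed = at::zeros({2, 4096, 8, 128}, opts);
  // After: uninitialized memory; the dequantize kernel must cover every row
  // (valid rows dequantized, padded rows memset) before FA3 reads them.
  auto uninit = at::empty({2, 4096, 8, 128}, opts);
  return 0;
}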
@@ -1919,22 +1928,35 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
     block_tables_b_stride = block_tables.value().stride(0);
   }
 
-  constexpr int32_t kMaxBlocks = 256;
+  constexpr int32_t kMaxBlocks = 512;
   dim3 blocks(B, std::max<int32_t>(1, kMaxBlocks / B));
   dim3 threads(kThreadsPerWarp, kWarpsPerBlock);
   if (block_tables_ptr == nullptr) {
-    dequantize_fp8_cache_kernel<<<
-        blocks,
-        threads,
-        0,
-        at::cuda::getCurrentCUDAStream()>>>(
-        cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
-        cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
-        kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-        cache_K_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
-        cache_V_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
-        qparam_k_ptr,
-        qparam_v_ptr);
+    if (qparam_k_ptr) {
+      dequantize_fp8_cache_kernel<true>
+          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+              cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
+              cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
+              kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+              cache_K_dq
+                  .packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
+              cache_V_dq
+                  .packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
+              qparam_k_ptr,
+              qparam_v_ptr);
+    } else {
+      dequantize_fp8_cache_kernel<false>
+          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+              cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
+              cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
+              kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+              cache_K_dq
+                  .packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
+              cache_V_dq
+                  .packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
+              qparam_k_ptr,
+              qparam_v_ptr);
+    }
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
     dequantize_fp8_cache_kernel_paged<<<
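The launch site above now chooses a compile-time specialization of dequantize_fp8_cache_kernel from a runtime null check on qparam_k_ptr, which is what lets the kernel drop the per-row qparam_offset_bytes branch. A minimal, self-contained sketch of that dispatch pattern (kernel and names are made up for illustration):

#include <cuda_runtime.h>

// The boolean template parameter folds the branch away at compile time.
template <bool ExternalQParam>
__global__ void scale_sketch(const float* qparams, float* out) {
  out[threadIdx.x] = ExternalQParam ? qparams[threadIdx.x] : 1.0f;
}

void launch_scale_sketch(const float* qparams, float* out, cudaStream_t stream) {
  if (qparams != nullptr) {
    scale_sketch<true><<<1, 32, 0, stream>>>(qparams, out);
  } else {
    // qparams is null here, but the <false> specialization never reads it.
    scale_sketch<false><<<1, 32, 0, stream>>>(qparams, out);
  }
}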
