@@ -1692,6 +1692,7 @@ at::Tensor xpos_qkv_decoding(
 
 #if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \
     (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
+template <bool ExternalQParam>
 __global__ void dequantize_fp8_cache_kernel(
     // This code currently represents FP8 version not int4
     at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
@@ -1711,60 +1712,69 @@ __global__ void dequantize_fp8_cache_kernel(
   auto D_H_q = cache_K.size(3);
   // TODO: support D_H < 128 for small model used in testing.
   CUDA_KERNEL_ASSERT(D_H == 128);
+  CUDA_KERNEL_ASSERT(4 * threadIdx.x < D_H);
+  const uint8_t offset_bytes = (ExternalQParam) ? 0 : 4;
+  CUDA_KERNEL_ASSERT(D_H_q - D_H == offset_bytes);
 
   auto b = blockIdx.x;
   // only need to dequantize this far.
   auto max_t = kv_seqlen[b];
 
   // one warp per T/H
-  for (auto t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH;
+  int h = 0, t = 0;
+  uint8_t *row_k{}, *row_v{};
+  c10::BFloat16 *row_k_dq{}, *row_v_dq{};
+  uint64_t packed{};
+  bfx8 kv_dq;
+  long t_h{};
+  for (t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH;
        t_h += blockDim.y * gridDim.y) {
-    auto h = t_h % N_KVH;
-    auto t = t_h / N_KVH;
-
-    auto* row_k = &cache_K[b][t][h][0]; // uint8_t*
-    auto* row_v = &cache_V[b][t][h][0];
-    bfx8 kv_dq;
-    uint8_t qparam_offset_bytes;
-    __half2* qparam_k_src;
-    __half2* qparam_v_src;
-    if (qparam_k_ptr) {
-      // read from standalone qparam tensor
-      qparam_offset_bytes = 0;
-      auto idx = b * (MAX_T * N_KVH) + t * N_KVH + h;
-      qparam_k_src = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
-      qparam_v_src = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
-    } else {
-      // read from first row
-      qparam_offset_bytes = 4;
-      qparam_k_src = reinterpret_cast<__half2*>(&row_k[0]);
-      qparam_v_src = reinterpret_cast<__half2*>(&row_v[0]);
-    }
-    // Assert the quantized row dim is as expected
-    CUDA_KERNEL_ASSERT(D_H_q - D_H == qparam_offset_bytes);
-    if (4 * threadIdx.x >= D_H) {
-      continue;
-    }
-    // each thread reads 4 x 8 bits
-
-    uint64_t kq = *reinterpret_cast<uint32_t*>(
-        &row_k[threadIdx.x * 4 + qparam_offset_bytes]);
-    uint64_t vq = *reinterpret_cast<uint32_t*>(
-        &row_v[threadIdx.x * 4 + qparam_offset_bytes]);
-
-    uint64_t packed = kq | (vq << 32);
+    h = t_h % N_KVH;
+    t = t_h / N_KVH;
+
+    row_k = &cache_K[b][t][h][0];
+    row_v = &cache_V[b][t][h][0];
+    row_k_dq = &cache_K_dq[b][t][h][0];
+    row_v_dq = &cache_V_dq[b][t][h][0];
+    // Calculate kv_dq for this row
+    {
+      __half2* qparam_k_src;
+      __half2* qparam_v_src;
+      if (ExternalQParam) {
+        size_t idx = b * (MAX_T * N_KVH) + t * N_KVH + h;
+        qparam_k_src = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
+        qparam_v_src = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
+      } else {
+        qparam_k_src = reinterpret_cast<__half2*>(&row_k[0]);
+        qparam_v_src = reinterpret_cast<__half2*>(&row_v[0]);
+      }
+      uint64_t kq =
+          *reinterpret_cast<uint32_t*>(&row_k[threadIdx.x * 4 + offset_bytes]);
+      uint64_t vq =
+          *reinterpret_cast<uint32_t*>(&row_v[threadIdx.x * 4 + offset_bytes]);
 
-    kv_dq = dequantize_packed_fp8(packed, *qparam_k_src, *qparam_v_src);
+      packed = kq | (vq << 32);
 
+      kv_dq = dequantize_packed_fp8(packed, *qparam_k_src, *qparam_v_src);
+    }
     // now, write our outputs
-    auto* row_k_dq = &cache_K_dq[b][t][h][0];
-    auto* row_v_dq = &cache_V_dq[b][t][h][0];
     // each thread writes 4 elements of type bf16
     *reinterpret_cast<uint2*>(&row_k_dq[4 * threadIdx.x]) =
         *reinterpret_cast<uint2*>(&kv_dq.vals[0]);
     *reinterpret_cast<uint2*>(&row_v_dq[4 * threadIdx.x]) =
         *reinterpret_cast<uint2*>(&kv_dq.vals[2]);
   }
+
+  max_t = (max_t + 127) / 128 * 128;
+  for (; t_h < max_t * N_KVH; t_h += blockDim.y * gridDim.y) {
+    h = t_h % N_KVH;
+    t = t_h / N_KVH;
+    row_k_dq = &cache_K_dq[b][t][h][0];
+    row_v_dq = &cache_V_dq[b][t][h][0];
+
+    memset(&row_k_dq[4 * threadIdx.x], 0, sizeof(uint2));
+    memset(&row_v_dq[4 * threadIdx.x], 0, sizeof(uint2));
+  }
 }
 
 // Cloned from dequantize_fp8_cache_kernel because
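Note on the new tail loop above: after the main grid-stride loop stops at `kv_seqlen[b]`, each thread keeps striding and zero-fills the dequantized K/V rows up to the next multiple of 128, which appears to be what allows the output buffers below to be allocated with `at::empty` instead of `at::zeros`. A minimal, self-contained sketch of that pattern follows; the kernel and sizes are hypothetical and only illustrate how the same per-thread stride covers the padded tail.

// Sketch of "grid-stride over valid rows, then zero-fill the padded tail".
// Hypothetical kernel, not the FBGEMM code.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void fill_then_pad(float* out, int valid_n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = blockDim.x * gridDim.x;
  for (; i < valid_n; i += stride) {
    out[i] = 1.0f; // stands in for the real dequantized value
  }
  // Round up to a multiple of 128 and zero the tail, mirroring
  // `max_t = (max_t + 127) / 128 * 128;` in the kernel above.
  const int padded_n = (valid_n + 127) / 128 * 128;
  for (; i < padded_n; i += stride) {
    out[i] = 0.0f;
  }
}

int main() {
  const int valid_n = 200;
  const int padded_n = (valid_n + 127) / 128 * 128; // 256
  float* d_out = nullptr;
  cudaMalloc(&d_out, padded_n * sizeof(float));
  fill_then_pad<<<4, 64>>>(d_out, valid_n);
  float h_out[256];
  cudaMemcpy(h_out, d_out, padded_n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("out[199]=%.1f out[200]=%.1f\n", h_out[199], h_out[200]); // 1.0 0.0
  cudaFree(d_out);
  return 0;
}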
@@ -1902,10 +1912,9 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
   // matching shape with the original paged KV and feed the same buffer
   // into this function at every layer to reuse it and prevent allocation.
 
-  // FIXME: T213958042
-  auto cache_K_dq = at::zeros(
+  auto cache_K_dq = at::empty(
       {B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
-  auto cache_V_dq = at::zeros(
+  auto cache_V_dq = at::empty(
      {B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
 
   if (B == 0) {
@@ -1919,22 +1928,35 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
     block_tables_b_stride = block_tables.value().stride(0);
   }
 
-  constexpr int32_t kMaxBlocks = 256;
+  constexpr int32_t kMaxBlocks = 512;
   dim3 blocks(B, std::max<int32_t>(1, kMaxBlocks / B));
   dim3 threads(kThreadsPerWarp, kWarpsPerBlock);
   if (block_tables_ptr == nullptr) {
-    dequantize_fp8_cache_kernel<<<
-        blocks,
-        threads,
-        0,
-        at::cuda::getCurrentCUDAStream()>>>(
-        cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
-        cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
-        kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-        cache_K_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
-        cache_V_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
-        qparam_k_ptr,
-        qparam_v_ptr);
+    if (qparam_k_ptr) {
+      dequantize_fp8_cache_kernel<true>
+          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+              cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
+              cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
+              kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+              cache_K_dq
+                  .packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
+              cache_V_dq
+                  .packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
+              qparam_k_ptr,
+              qparam_v_ptr);
+    } else {
+      dequantize_fp8_cache_kernel<false>
+          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+              cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
+              cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
+              kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+              cache_K_dq
+                  .packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
+              cache_V_dq
+                  .packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
+              qparam_k_ptr,
+              qparam_v_ptr);
+    }
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
     dequantize_fp8_cache_kernel_paged<<<
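The launch-site change in the last hunk turns the runtime `if (qparam_k_ptr)` check inside the kernel into a compile-time template argument: each instantiation of `dequantize_fp8_cache_kernel<ExternalQParam>` only ever reads its quantization parameters from one place. A small sketch of that dispatch pattern follows; the kernel and names are hypothetical, not the FBGEMM code.

// Sketch of mapping a runtime flag onto a bool template parameter so the
// branch on where the scale lives is resolved at compile time.
// Hypothetical example for illustration only (requires C++17 for constexpr-if).
#include <cuda_runtime.h>

template <bool ExternalQParam>
__global__ void dequant_rows(const float* in, const float* qparam, float* out,
                             int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) {
    return;
  }
  float scale;
  if constexpr (ExternalQParam) {
    scale = qparam[0]; // scale stored in a standalone tensor
  } else {
    scale = in[0]; // scale packed at the front of the quantized row
  }
  // With an in-row scale the payload starts one element later,
  // analogous to `offset_bytes` in the kernel above.
  const int offset = ExternalQParam ? 0 : 1;
  out[i] = in[i + offset] * scale;
}

void launch_dequant(const float* in, const float* qparam, float* out, int n,
                    cudaStream_t stream) {
  const dim3 threads(128);
  const dim3 blocks((n + threads.x - 1) / threads.x);
  if (qparam != nullptr) {
    dequant_rows<true><<<blocks, threads, 0, stream>>>(in, qparam, out, n);
  } else {
    dequant_rows<false><<<blocks, threads, 0, stream>>>(in, qparam, out, n);
  }
}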