diff --git a/csrc/cache.h b/csrc/cache.h index f2a5ec0acf5cd..cbe44c09eb624 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -58,6 +59,15 @@ void cp_gather_cache( torch::Tensor const& cu_seq_lens, // [BATCH+1] int64_t batch_size, std::optional seq_starts = std::nullopt); +// Gather and upconvert FP8 KV cache to BF16 workspace +void cp_gather_and_upconvert_fp8_kv_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] + torch::Tensor const& dst, // [TOT_TOKENS, 576] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& seq_lens, // [BATCH] + torch::Tensor const& workspace_starts, // [BATCH] + int64_t batch_size); + // Indexer K quantization and cache function void indexer_k_quant_and_cache( torch::Tensor& k, // [num_tokens, head_dim] @@ -72,4 +82,4 @@ void cp_gather_indexer_k_quant_cache( torch::Tensor& dst_k, // [num_tokens, head_dim] torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4] const torch::Tensor& block_table, // [batch_size, num_blocks] - const torch::Tensor& cu_seq_lens); // [batch_size + 1] \ No newline at end of file + const torch::Tensor& cu_seq_lens); // [batch_size + 1] diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 8a5457206c706..f11c5f24c12ec 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -2,6 +2,7 @@ #include #include #include +#include #include "cuda_utils.h" #include "cuda_compat.h" @@ -514,7 +515,8 @@ __global__ void indexer_k_quant_and_cache_kernel( const int quant_block_size, // quantization block size const int cache_block_size, // cache block size const int cache_stride, // stride for each token in kv_cache - const bool use_ue8m0 // use ue8m0 scale format + + const bool use_ue8m0 // use ue8m0 scale format ) { constexpr int VEC_SIZE = 4; const int64_t token_idx = blockIdx.x; @@ -1061,6 +1063,82 @@ void gather_and_maybe_dequant_cache( } namespace vllm { + +// Gather and upconvert FP8 KV cache tokens to BF16 workspace +// Similar to cp_gather_cache but specifically for FP8->BF16 conversion +__global__ void cp_gather_and_upconvert_fp8_kv_cache( + const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] + __nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576] + const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] + const int32_t* __restrict__ seq_lens, // [BATCH] + const int32_t* __restrict__ workspace_starts, // [BATCH] + const int32_t block_size, const int32_t head_dim, + const int64_t block_table_stride, const int64_t cache_block_stride, + const int64_t cache_entry_stride, const int64_t dst_entry_stride) { + const int64_t bid = blockIdx.x; // Batch ID + const int32_t num_splits = gridDim.y; + const int32_t split = blockIdx.y; + const int32_t seq_start = workspace_starts[bid]; + const int32_t seq_len = seq_lens[bid]; + const int32_t tot_slots = seq_len; + const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits); + + const int32_t split_start = split * split_slots; + const int32_t split_end = min((split + 1) * split_slots, tot_slots); + + const bool is_active_split = (split_start < tot_slots); + + if (!is_active_split) return; + + // Adjust the pointer for the block_table for this batch + const int32_t batch_offset = bid * block_table_stride; + int32_t offset = split_start; + int32_t offset_div = offset / block_size; + offset = offset % block_size; + const int32_t* batch_block_table = block_table + batch_offset; + + // Adjust dst pointer based on the 
cumulative sequence lengths + dst += seq_start * dst_entry_stride; + + const int tid = threadIdx.x; + + // Process each token in this split + for (int pid = split_start; pid < split_end; ++pid) { + auto block_id = batch_block_table[offset_div]; + const uint8_t* token_ptr = + src_cache + block_id * cache_block_stride + offset * cache_entry_stride; + __nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride; + + // FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16) + const uint8_t* no_pe_ptr = token_ptr; + const float* scales_ptr = reinterpret_cast(token_ptr + 512); + const __nv_bfloat16* rope_ptr = + reinterpret_cast(token_ptr + 512 + 16); + + // Parallelize fp8 dequant (512 elements) and rope copy (64 elements) + if (tid < 512) { + // FP8 dequantization + const int tile = tid >> 7; // each tile is 128 elements + const float scale = scales_ptr[tile]; + const uint8_t val = no_pe_ptr[tid]; + dst_ptr[tid] = + fp8::scaled_convert<__nv_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale); + } else if (tid < 576) { + // Rope copy (64 bf16 elements) + const int rope_idx = tid - 512; + dst_ptr[512 + rope_idx] = rope_ptr[rope_idx]; + } + + // Move to next token + offset += 1; + if (offset == block_size) { + offset_div += 1; + offset = 0; + } + } +} + template // Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by // block_size. @@ -1202,6 +1280,57 @@ void cp_gather_cache( } } +void cp_gather_and_upconvert_fp8_kv_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] + torch::Tensor const& dst, // [TOT_TOKENS, 576] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& seq_lens, // [BATCH] + torch::Tensor const& workspace_starts, // [BATCH] + int64_t batch_size) { + at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int32_t block_size = src_cache.size(1); + int32_t head_dim = dst.size(1); + + TORCH_CHECK(block_table.dtype() == torch::kInt32, + "block_table must be int32"); + TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32"); + TORCH_CHECK(workspace_starts.dtype() == torch::kInt32, + "workspace_starts must be int32"); + + TORCH_CHECK(src_cache.device() == dst.device(), + "src_cache and dst must be on the same device"); + TORCH_CHECK(src_cache.device() == block_table.device(), + "src_cache and block_table must be on the same device"); + TORCH_CHECK(src_cache.device() == seq_lens.device(), + "src_cache and seq_lens must be on the same device"); + TORCH_CHECK(src_cache.device() == workspace_starts.device(), + "src_cache and workspace_starts must be on the same device"); + + TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8"); + TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16"); + TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA"); + + int64_t block_table_stride = block_table.stride(0); + int64_t cache_block_stride = src_cache.stride(0); + int64_t cache_entry_stride = src_cache.stride(1); + int64_t dst_entry_stride = dst.stride(0); + + // Decide on the number of splits based on the batch size + int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 
4 : 16; + dim3 grid(batch_size, num_splits); + dim3 block(576); + + vllm::cp_gather_and_upconvert_fp8_kv_cache<<>>( + src_cache.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()), + block_table.data_ptr(), seq_lens.data_ptr(), + workspace_starts.data_ptr(), block_size, head_dim, + block_table_stride, cache_block_stride, cache_entry_stride, + dst_entry_stride); +} + // Macro to dispatch the kernel based on the data type. #define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ vllm::indexer_k_quant_and_cache_kernel \ diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index d4c6f8c67c516..83d4943d62776 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -754,6 +754,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache); + cache_ops.def( + "cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, " + "Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int " + "batch_size) -> ()"); + cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA, + &cp_gather_and_upconvert_fp8_kv_cache); + cache_ops.def( "indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor " "slot_mapping, " diff --git a/tests/conftest.py b/tests/conftest.py index 5b26a02823c56..b21cfd5ba85c4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -202,6 +202,27 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): cleanup_dist_env_and_memory() +@pytest.fixture +def workspace_init(): + """Initialize the workspace manager for tests that need it. + + This fixture initializes the workspace manager with a CUDA device + if available, and resets it after the test completes. Tests that + create a full vLLM engine should NOT use this fixture as the engine + will initialize the workspace manager itself. 
+ """ + from vllm.v1.worker.workspace import ( + init_workspace_manager, + reset_workspace_manager, + ) + + if torch.cuda.is_available(): + device = torch.device("cuda:0") + init_workspace_manager(device) + yield + reset_workspace_manager() + + @pytest.fixture(autouse=True) def dynamo_reset(): yield diff --git a/tests/kernels/moe/test_batched_deepgemm.py b/tests/kernels/moe/test_batched_deepgemm.py index 59cecd60d3d61..0ba3d8d4c958e 100644 --- a/tests/kernels/moe/test_batched_deepgemm.py +++ b/tests/kernels/moe/test_batched_deepgemm.py @@ -27,7 +27,7 @@ BLOCK_SIZE = [128, 128] @pytest.mark.parametrize("N", [512, 1024]) # intermediate dim per expert @pytest.mark.parametrize("topk", [2, 4]) def test_batched_deepgemm_vs_triton( - E: int, T: int, K: int, N: int, topk: int, monkeypatch + E: int, T: int, K: int, N: int, topk: int, monkeypatch, workspace_init ): """Compare BatchedDeepGemmExperts to BatchedTritonExperts.""" diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index dab1207d78031..2ef170f1ab308 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -248,6 +248,7 @@ def test_fused_moe_batched_experts( per_act_token_quant: bool, block_shape: list[int] | None, input_scales: bool, + workspace_init, ): """Note: float8_e4m3fn is not supported on CUDA architecture < 89, and those tests will be skipped on unsupported hardware.""" diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index b0ff1e64e3219..53a03f48e24ee 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -137,7 +137,7 @@ def setup_cuda(): @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() def test_w8a8_block_fp8_fused_moe( - M, N, K, E, topk, block_size, dtype, seed, monkeypatch + M, N, K, E, topk, block_size, dtype, seed, monkeypatch, workspace_init ): if topk > E: pytest.skip(f"Skipping test; topk={topk} > E={E}") diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index c15837f145705..0160694d7bb54 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -274,6 +274,7 @@ def test_cutlass_moe_8_bit_no_graph( per_act_token: bool, per_out_ch: bool, monkeypatch, + workspace_init, ep_size: int | None = None, ): current_platform.seed_everything(7) @@ -329,6 +330,7 @@ def test_cutlass_moe_8_bit_cuda_graph( per_act_token: bool, per_out_ch: bool, monkeypatch, + workspace_init, ): current_platform.seed_everything(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") @@ -385,9 +387,19 @@ def test_cutlass_moe_8_bit_EP( per_out_channel: bool, ep_size: int, monkeypatch, + workspace_init, ): test_cutlass_moe_8_bit_no_graph( - m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size + m, + n, + k, + e, + topk, + per_act_token, + per_out_channel, + monkeypatch, + workspace_init, + ep_size, ) @@ -419,9 +431,19 @@ def test_cutlass_moe_8_bit_EP_large( per_out_channel: bool, ep_size: int, monkeypatch, + workspace_init, ): test_cutlass_moe_8_bit_no_graph( - m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size + m, + n, + k, + e, + topk, + per_act_token, + per_out_channel, + monkeypatch, + workspace_init, + ep_size, ) @@ -445,6 +467,7 @@ def test_run_cutlass_moe_fp8( per_act_token: bool, per_out_channel: bool, ep_size: int, + workspace_init, ): current_platform.seed_everything(7) with set_current_vllm_config(vllm_config): diff --git 
a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 455ecacef5ec3..f427734ef09e2 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -29,6 +29,7 @@ from vllm.utils.deep_gemm import ( is_deep_gemm_supported, ) from vllm.utils.import_utils import has_deep_ep, has_deep_gemm +from vllm.v1.worker.workspace import init_workspace_manager from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch @@ -363,6 +364,9 @@ def _test_deepep_deepgemm_moe( w1_scale: torch.Tensor, w2_scale: torch.Tensor, ): + device = torch.device(f"cuda:{pgi.local_rank}") + init_workspace_manager(device) + current_platform.seed_everything(pgi.rank) w1 = w1.to(device=torch.cuda.current_device()) @@ -445,6 +449,7 @@ def test_ht_deepep_deepgemm_moe( topk: int, world_dp_size: tuple[int, int], disable_deepgemm_ue8m0, + workspace_init, ): """ Tests for High-Throughput DeepEP + DeepGemm integration. @@ -518,6 +523,7 @@ def test_ll_deepep_deepgemm_moe( block_size: list[int], world_dp_size: tuple[int, int], disable_deepgemm_ue8m0, + workspace_init, ): """ Tests for Low-Latency DeepEP + DeepGemm integration. diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index d78b8250463a9..e698ca92a1515 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ) from vllm.platforms import current_platform from vllm.utils.import_utils import has_deep_ep +from vllm.v1.worker.workspace import init_workspace_manager from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch @@ -342,6 +343,9 @@ def _deep_ep_moe( use_fp8_dispatch: bool, per_act_token_quant: bool, ): + device = torch.device(f"cuda:{pgi.local_rank}") + init_workspace_manager(device) + if not low_latency_mode: assert not use_fp8_dispatch, ( "FP8 dispatch interface is available only in low-latency mode" @@ -437,6 +441,7 @@ def test_deep_ep_moe( topk: int, world_dp_size: tuple[int, int], per_act_token_quant: bool, + workspace_init, ): low_latency_mode = False use_fp8_dispatch = False @@ -492,6 +497,7 @@ def test_low_latency_deep_ep_moe( topk: int, world_dp_size: tuple[int, int], use_fp8_dispatch: bool, + workspace_init, ): low_latency_mode = True diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index 9b1054f7d0ab8..442b561f8f315 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -143,7 +143,7 @@ NUM_EXPERTS = [32] @pytest.mark.parametrize("topk", TOPKS) @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels") -def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch): +def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_init): with monkeypatch.context() as mp: mp.setenv("VLLM_USE_DEEP_GEMM", "1") diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index a6977f222408d..d553e2820e5ff 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -206,6 +206,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( topk: int, activation: str, monkeypatch, + workspace_init, ): current_platform.seed_everything(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") diff --git 
a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index b2be03ecee2f1..133a8a4a30a60 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -51,7 +51,14 @@ MNK_FACTORS = [ @pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"]) @torch.inference_mode() def test_flashinfer_fp4_moe_no_graph( - m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, activation: str + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, + activation: str, + workspace_init, ): current_platform.seed_everything(7) with set_current_vllm_config( diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py index 98e80ec029777..384f43db479b5 100644 --- a/tests/kernels/moe/test_gpt_oss_triton_kernels.py +++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py @@ -269,7 +269,7 @@ class Case: ) @pytest.mark.parametrize("num_token", [2]) @pytest.mark.parametrize("tp", [1, 2, 4, 8]) -def test_equiv(num_token, a_dtype, w_dtype, tp): +def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init): from triton_kernels.tensor_details import layout if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"): diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index 2a30ef2355529..6ebf1016c166c 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -16,6 +16,7 @@ from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm.v1.worker.workspace import init_workspace_manager from .modular_kernel_tools.common import ( Config, @@ -77,6 +78,10 @@ def rank_worker( weights: WeightTensors, verbose: bool, ): + # Initialize workspace manager in child process + device = torch.device(f"cuda:{pgi.local_rank}") + init_workspace_manager(device) + current_platform.seed_everything(pgi.rank) # sanity check @@ -300,6 +305,7 @@ def test_modular_kernel_combinations_singlegpu( chunk_size: int | None, world_size: int, pytestconfig, + workspace_init, ): """Note: float8_e4m3fn is not supported on CUDA architecture < 89, and those tests will be skipped on unsupported hardware.""" diff --git a/tests/kernels/moe/test_modular_oai_triton_moe.py b/tests/kernels/moe/test_modular_oai_triton_moe.py index c8616f13bbf85..1abb08f878b2b 100644 --- a/tests/kernels/moe/test_modular_oai_triton_moe.py +++ b/tests/kernels/moe/test_modular_oai_triton_moe.py @@ -209,6 +209,7 @@ def test_oai_triton_moe( num_experts: int, topk: int, unfused: bool, + workspace_init, ): current_platform.seed_everything(0) ( diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 82659276af37c..ce99d9691fdc8 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -231,6 +231,7 @@ def test_fused_moe( padding: bool, chunk_size: int, monkeypatch, + workspace_init, ): current_platform.seed_everything(7) diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index aa544fe0e0f63..e67bd76a16181 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -40,7 +40,7 @@ MNK_FACTORS = [ @pytest.mark.parametrize("dtype", [torch.bfloat16]) @torch.inference_mode() def 
test_cutlass_fp4_moe_no_graph( - m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype + m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, workspace_init ): current_platform.seed_everything(7) with set_current_vllm_config( diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index f671b23d300ce..35e554e16cb38 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -46,6 +46,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ) from vllm.platforms import current_platform from vllm.utils.math_utils import round_up +from vllm.v1.worker.workspace import init_workspace_manager from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch @@ -181,6 +182,7 @@ def test_fused_moe_batched_experts( e: int, topk: int, dtype: torch.dtype, + workspace_init, ): current_platform.seed_everything(7) @@ -863,6 +865,9 @@ def _pplx_test_loop( make_weights: bool, test_fn: Callable, ): + device = torch.device(f"cuda:{pgi.local_rank}") + init_workspace_manager(device) + def format_result(msg, ex=None): if ex is not None: x = str(ex) diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py index b34d587eb362d..8049347280c5a 100644 --- a/tests/v1/attention/test_sparse_mla_backends.py +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -22,10 +22,14 @@ from tests.v1.attention.utils import ( ) from vllm import _custom_ops as ops from vllm.attention.ops import flashmla +from vllm.config import set_current_vllm_config from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseBackend -from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks +from vllm.v1.attention.backends.mla.flashmla_sparse import ( + FlashMLASparseBackend, + triton_convert_req_index_to_global_index, +) +from vllm.v1.attention.backends.utils import split_prefill_chunks SPARSE_BACKEND_BATCH_SPECS = { name: BATCH_SPECS[name] @@ -114,8 +118,12 @@ def _quantize_dequantize_fp8_ds_mla( @pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys())) @pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"]) @pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) +@pytest.mark.skipif( + torch.cuda.get_device_capability() < (9, 0), + reason="FlashMLASparseBackend requires CUDA 9.0 or higher", +) def test_sparse_backend_decode_correctness( - dist_init, batch_name, kv_cache_dtype, tensor_parallel_size + dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init ): if not torch.cuda.is_available(): pytest.skip("CUDA is required for sparse MLA decode test") @@ -320,28 +328,29 @@ def test_sparse_backend_decode_correctness( mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous()) impl_cls = FlashMLASparseBackend.get_impl_cls() - impl = impl_cls( - num_heads=num_heads, - head_size=head_size, - scale=scale, - num_kv_heads=1, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype=vllm_config.cache_config.cache_dtype, - logits_soft_cap=None, - attn_type="decoder", - kv_sharing_target_layer_name=None, - q_lora_rank=None, - kv_lora_rank=kv_lora_rank, - qk_nope_head_dim=qk_nope_head_dim, - qk_rope_head_dim=qk_rope_head_dim, - qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, - v_head_dim=v_head_dim, - kv_b_proj=mock_kv_b_proj, - indexer=mock_indexer, - ) + with 
set_current_vllm_config(vllm_config): + impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=1, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype=vllm_config.cache_config.cache_dtype, + logits_soft_cap=None, + attn_type="decoder", + kv_sharing_target_layer_name=None, + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_b_proj=mock_kv_b_proj, + indexer=mock_indexer, + ) - impl.process_weights_after_loading(dtype) + impl.process_weights_after_loading(dtype) layer = MockAttentionLayer(device) out_buffer = torch.empty( @@ -366,22 +375,192 @@ def test_sparse_backend_decode_correctness( torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5) +def _triton_convert_reference_impl( + req_ids: torch.Tensor, + block_table: torch.Tensor, + token_indices: torch.Tensor, + block_size: int, + num_topk_tokens: int, + HAS_PREFILL_WORKSPACE: bool = False, + prefill_workspace_request_ids: torch.Tensor | None = None, + prefill_workspace_starts: torch.Tensor | None = None, +) -> torch.Tensor: + """Reference implementation for triton_convert_req_index_to_global_index.""" + num_tokens = req_ids.shape[0] + max_blocks_per_req = block_table.shape[1] + result = torch.empty( + num_tokens, num_topk_tokens, dtype=torch.int32, device=req_ids.device + ) + + for token_id in range(num_tokens): + req_id = req_ids[token_id].item() + + # Determine if this token uses workspace or paged cache + use_prefill_workspace = False + workspace_start = 0 + if HAS_PREFILL_WORKSPACE and prefill_workspace_request_ids is not None: + assert prefill_workspace_starts is not None + prefill_req_id = prefill_workspace_request_ids[token_id].item() + if prefill_req_id >= 0: + use_prefill_workspace = True + workspace_start = prefill_workspace_starts[prefill_req_id].item() + + for idx_id in range(num_topk_tokens): + token_idx = token_indices[token_id, idx_id].item() + + if token_idx == -1: + result[token_id, idx_id] = -1 + elif use_prefill_workspace: + # Prefill + using prefill workspace: map to workspace offset + result[token_id, idx_id] = workspace_start + token_idx + else: + # Decode: map to paged cache + block_id = token_idx // block_size + if block_id >= max_blocks_per_req: + result[token_id, idx_id] = -1 + else: + block_num = block_table[req_id, block_id].item() + offset = token_idx % block_size + result[token_id, idx_id] = block_num * block_size + offset + + return result + + +@pytest.mark.parametrize("block_size", [16, 64, 128]) +@pytest.mark.parametrize("num_topk_tokens", [128, 256, 512]) +@pytest.mark.skipif( + torch.cuda.get_device_capability() < (9, 0), + reason="FlashMLASparseBackend requires CUDA 9.0 or higher", +) +def test_triton_convert_req_index_to_global_index_decode_only( + block_size, num_topk_tokens +): + device = torch.device("cuda") + num_tokens = 8 + num_requests = 4 + max_blocks_per_req = 10 + + req_id = torch.randint( + 0, num_requests, (num_tokens,), dtype=torch.int32, device=device + ) + block_table = torch.randint( + 0, 100, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device + ) + + token_indices = torch.randint( + 0, + block_size * max_blocks_per_req, + (num_tokens, num_topk_tokens), + dtype=torch.int32, + device=device, + ) + + # Set some to -1 to test masking + token_indices[0, :10] = -1 + token_indices[3, 50:60] = -1 + + # Set some to out of bounds + token_indices[2, 100:110] = 
max_blocks_per_req * block_size + token_indices[6, 150:160] = max_blocks_per_req * block_size + + result = triton_convert_req_index_to_global_index( + req_id, + block_table, + token_indices, + BLOCK_SIZE=block_size, + NUM_TOPK_TOKENS=num_topk_tokens, + ) + + reference_result = _triton_convert_reference_impl( + req_id, + block_table, + token_indices, + block_size, + num_topk_tokens, + ) + + torch.testing.assert_close(result, reference_result, rtol=0, atol=0) + + +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.skipif( + torch.cuda.get_device_capability() < (9, 0), + reason="FlashMLASparseBackend requires CUDA 9.0 or higher", +) +def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_size): + device = torch.device("cuda") + num_requests = 4 + max_blocks_per_req = 8 + num_topk_tokens = 128 + + # First 6 tokens are decode (reqs 0, 1), last 6 are prefill (reqs 2, 3) + req_id = torch.tensor( + [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], dtype=torch.int32, device=device + ) + prefill_workspace_request_ids = torch.tensor( + [-1, -1, -1, -1, -1, -1, 0, 0, 0, 1, 1, 1], dtype=torch.int32, device=device + ) + + # Workspace starts for the 2 prefill reqs: req 2 starts at 0, req 3 starts at 100 + prefill_workspace_starts = torch.tensor([0, 100], dtype=torch.int32, device=device) + + block_table = torch.randint( + 0, 50, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device + ) + token_indices = torch.randint( + 0, + block_size * max_blocks_per_req, + (req_id.shape[0], num_topk_tokens), + dtype=torch.int32, + device=device, + ) + + # Set some to -1 to test masking + token_indices[0, :10] = -1 + token_indices[3, 50:60] = -1 + + # Set some to out of bounds + token_indices[2, 100:110] = max_blocks_per_req * block_size + token_indices[6, 150:160] = max_blocks_per_req * block_size + + result = triton_convert_req_index_to_global_index( + req_id, + block_table, + token_indices, + BLOCK_SIZE=block_size, + NUM_TOPK_TOKENS=num_topk_tokens, + HAS_PREFILL_WORKSPACE=True, + prefill_workspace_request_ids=prefill_workspace_request_ids, + prefill_workspace_starts=prefill_workspace_starts, + ) + + reference_result = _triton_convert_reference_impl( + req_id, + block_table, + token_indices, + block_size, + num_topk_tokens, + HAS_PREFILL_WORKSPACE=True, + prefill_workspace_request_ids=prefill_workspace_request_ids, + prefill_workspace_starts=prefill_workspace_starts, + ) + + torch.testing.assert_close(result, reference_result, rtol=0, atol=0) + + @pytest.mark.parametrize( - "seq_lens,max_buf,start,expected", + "seq_lens,max_buf,expected", [ # Basic split: totals per chunk ≤ max_buf - (torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]), - # Non-zero start index - (torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]), - # Exact fits should split between items when adding the next would - # overflow - (torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]), + (torch.tensor([2, 3, 4, 2]), 5, [(0, 2), (2, 3), (3, 4)]), + # Exact fits should split between items when adding the next would overflow + (torch.tensor([5, 5, 5]), 5, [(0, 1), (1, 2), (2, 3)]), # All requests fit in a single chunk - (torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]), - # Large buffer with non-zero start - (torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]), + (torch.tensor([1, 1, 1]), 10, [(0, 3)]), + # Large buffer + (torch.tensor([4, 4, 4]), 100, [(0, 3)]), ], ) -def test_split_prefill_chunks(seq_lens, max_buf, start, expected): - out = split_prefill_chunks(seq_lens, max_buf, start) +def 
test_split_prefill_chunks(seq_lens, max_buf, expected): + out = split_prefill_chunks(seq_lens, max_buf) assert out == expected diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 6d862c5812560..52a58a082683d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2403,6 +2403,29 @@ def cp_gather_cache( ) +def cp_gather_and_upconvert_fp8_kv_cache( + src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + seq_lens: torch.Tensor, + workspace_starts: torch.Tensor, + batch_size: int, +) -> None: + """Gather and upconvert FP8 KV cache to BF16 workspace. + + Args: + src_cache: FP8 KV cache [num_blocks, block_size, 656] + dst: BF16 output workspace [total_tokens, 576] + block_table: Block indices [num_reqs, max_blocks] + seq_lens: Sequence lengths [num_reqs] + workspace_starts: Workspace start offsets [num_reqs] + batch_size: Number of requests + """ + torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache( + src_cache, dst, block_table, seq_lens, workspace_starts, batch_size + ) + + def indexer_k_quant_and_cache( k: torch.Tensor, kv_cache: torch.Tensor, diff --git a/vllm/envs.py b/vllm/envs.py index cb75ba1a62de9..d0f2798096263 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -239,6 +239,7 @@ if TYPE_CHECKING: VLLM_NCCL_INCLUDE_PATH: str | None = None VLLM_USE_FBGEMM: bool = False VLLM_GC_DEBUG: str = "" + VLLM_DEBUG_WORKSPACE: bool = False VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" @@ -1537,6 +1538,9 @@ environment_variables: dict[str, Callable[[], Any]] = { # - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with # top 5 collected objects "VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""), + # Debug workspace allocations. + # logging of workspace resize operations. + "VLLM_DEBUG_WORKSPACE": lambda: bool(int(os.getenv("VLLM_DEBUG_WORKSPACE", "0"))), # Disables parallel execution of shared_experts via separate cuda stream "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool( int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0")) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 075610ec588ae..9e75a7c08070e 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -22,12 +22,12 @@ from vllm.model_executor.layers.fused_moe.utils import ( from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv from vllm.v1.worker.ubatching import ( - dbo_current_ubatch_id, dbo_enabled, dbo_maybe_run_recv_hook, dbo_register_recv_hook, dbo_yield, ) +from vllm.v1.worker.workspace import current_workspace_manager logger = init_logger(__name__) @@ -661,25 +661,6 @@ def _slice_scales( return None -class SharedResizableBuffer: - def __init__(self): - self.buffer = None - - def get( - self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype - ) -> torch.Tensor: - assert shape != () - shape_numel = prod(shape) - if ( - self.buffer is None - or self.buffer.numel() < shape_numel - or self.buffer.device != device - or self.buffer.dtype != dtype - ): - self.buffer = torch.empty(shape_numel, device=device, dtype=dtype) - return self.buffer[:shape_numel].view(*shape) - - @final class FusedMoEModularKernel(torch.nn.Module): """ @@ -694,22 +675,6 @@ class FusedMoEModularKernel(torch.nn.Module): objects. 
""" - class SharedBuffers: - def __init__(self) -> None: - self.fused_out = SharedResizableBuffer() - self.workspace13 = SharedResizableBuffer() - self.workspace2 = SharedResizableBuffer() - - # Persistent buffers that are shared across `FusedMoEModularKernel` - # instances (layers), to save memory and allocattions. - # - # We have two sets of buffers to support dual batch overlap (DBO) where each - # microbatch (ubatch) should use its own set of buffers to avoid - # cross-ubatch contimination. - # NOTE that memory is lazily allocated for these buffers, meaning that if - # DBO isn't being used, the second SharedBuffers will be empty. - shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()] - def __init__( self, prepare_finalize: FusedMoEPrepareAndFinalize, @@ -806,10 +771,6 @@ class FusedMoEModularKernel(torch.nn.Module): assert M_full > 0 and M_chunk > 0 num_chunks, _ = self._chunk_info(M_full) - - # select per-ubatch buffers to avoid cross-ubatch reuse under DBO - ubatch_idx = dbo_current_ubatch_id() - buffers = self.shared_buffers[ubatch_idx] workspace_dtype = self.fused_experts.workspace_dtype(out_dtype) # Force worst-case allocation in profiling run for @@ -832,14 +793,11 @@ class FusedMoEModularKernel(torch.nn.Module): expert_tokens_meta, ) ) - buffers.workspace13.get( - max_workspace_13, device=device, dtype=workspace_dtype - ) - buffers.workspace2.get( - max_workspace_2, device=device, dtype=workspace_dtype - ) - buffers.fused_out.get( - max_fused_out_shape, device=device, dtype=workspace_dtype + + current_workspace_manager().get_simultaneous( + (max_workspace_13, workspace_dtype), + (max_workspace_2, workspace_dtype), + (max_fused_out_shape, out_dtype), ) # Get intermediate workspace shapes based off the chunked M size. @@ -866,22 +824,23 @@ class FusedMoEModularKernel(torch.nn.Module): # We can reuse the memory between cache1 and cache3 because by the # time we need cache3, we're done with cache1. - workspace13 = buffers.workspace13.get( - workspace13_shape, device=device, dtype=workspace_dtype - ) - workspace2 = buffers.workspace2.get( - workspace2_shape, device=device, dtype=workspace_dtype - ) - # Construct the entire output that can then be processed in chunks. # Reuse workspace13 for the output in the non-chunked case as long # as it is large enough. This will not always be the case for standard # format experts and with experts that have empty workspaces. 
if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape): + workspace13, workspace2 = current_workspace_manager().get_simultaneous( + (workspace13_shape, workspace_dtype), + (workspace2_shape, workspace_dtype), + ) fused_out = _resize_cache(workspace13, fused_out_shape) else: - fused_out = buffers.fused_out.get( - fused_out_shape, device=device, dtype=out_dtype + workspace13, workspace2, fused_out = ( + current_workspace_manager().get_simultaneous( + (workspace13_shape, workspace_dtype), + (workspace2_shape, workspace_dtype), + (fused_out_shape, out_dtype), + ) ) return workspace13, workspace2, fused_out diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index a9fa76deecbd2..146124153c79d 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -83,6 +83,7 @@ from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerMetadata, ) from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec +from vllm.v1.worker.workspace import current_workspace_manager from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP from .utils import ( @@ -616,8 +617,15 @@ def sparse_attn_indexer( # careful! this will be None in dummy run attn_metadata = get_forward_context().attn_metadata fp8_dtype = current_platform.fp8_dtype() + # assert isinstance(attn_metadata, dict) if not isinstance(attn_metadata, dict): + # Reserve workspace for indexer during profiling run + current_workspace_manager().get_simultaneous( + ((total_seq_lens, head_dim), torch.float8_e4m3fn), + ((total_seq_lens, 4), torch.uint8), + ) + return sparse_attn_indexer_fake( hidden_states, k_cache_prefix, @@ -651,17 +659,17 @@ def sparse_attn_indexer( topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: prefill_metadata = attn_metadata.prefill + + # Get the full shared workspace buffers once (will allocate on first use) + workspace_manager = current_workspace_manager() + k_fp8_full, k_scale_full = workspace_manager.get_simultaneous( + ((total_seq_lens, head_dim), fp8_dtype), + ((total_seq_lens, 4), torch.uint8), + ) + for chunk in prefill_metadata.chunks: - k_fp8 = torch.empty( - [chunk.total_seq_lens, head_dim], - device=k.device, - dtype=fp8_dtype, - ) - k_scale = torch.empty( - [chunk.total_seq_lens, 4], - device=k.device, - dtype=torch.uint8, - ) + k_fp8 = k_fp8_full[: chunk.total_seq_lens] + k_scale = k_scale_full[: chunk.total_seq_lens] ops.cp_gather_indexer_k_quant_cache( kv_cache, k_fp8, @@ -777,15 +785,6 @@ def sparse_attn_indexer_fake( total_seq_lens: int, topk_indices_buffer: torch.Tensor | None, ) -> torch.Tensor: - # profile run - # NOTE(Chen): create the max possible flattened_kv. So that - # profile_run can get correct memory usage. 
- _flattened_kv = torch.empty( - [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8 - ) - fp8_dtype = current_platform.fp8_dtype() - _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() - _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() return topk_indices_buffer diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 1eee1d225293b..f3052fbaf2a65 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -18,7 +18,7 @@ from vllm.attention.ops.flashmla import ( flash_mla_with_kvcache, get_mla_metadata, ) -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.platforms import current_platform @@ -30,13 +30,31 @@ from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + reshape_attn_output_for_spec_decode, + reshape_query_for_spec_decode, + split_decodes_and_prefills, + split_prefill_chunks, ) from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.workspace import current_workspace_manager if TYPE_CHECKING: from vllm.model_executor.models.deepseek_v2 import Indexer logger = init_logger(__name__) + +# For FP8 sparse attention we have two impelementations: +# 1. Mixed batch mode: use the FP8 decode kernel for both prefill and decode this is +# done by treating all tokens as single batch. +# 2. Separate prefill and decode mode: use the BF16 prefill kernel for prefill +# (upconverting the FP8 cache to BF16 then calling the prefill kernel) and using +# the FP8 decode kernel for decode. +# Currently we use #1 when the number of heads per rank is low (i.e. TP) since the BF16 +# prefill kernel requires padding the numer of heads to 128 while the decode does not +# so when the per ranke head count is below MIN_HEADS_FOR_BF16_PREFILL we use the mixed +# batch mode (#2). +MIN_HEADS_FOR_BF16_PREFILL = 32 + """ NOTE: FlashMLA Sparse uses an fp8 cache with the following format @@ -127,19 +145,72 @@ class FlashMLASparseMetadata: dummy_block_table: torch.Tensor cache_lens: torch.Tensor - fp8_extra_metadata: FP8KernelMetadata | None = None + @dataclass + class FP8SeperatePrefillDecode: + @dataclass + class Decode: + kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata" + decode_query_len: int # needed for reshape in spec decode + + @dataclass + class Prefill: + # Sequence lengths (context + query) for prefill requests + # Shape: [num_prefill_reqs] + seq_lens: torch.Tensor + + # Request ID for each token: -1 for decode tokens, request index + # (0, 1, 2, ...) for prefill tokens. + # Shape: [num_actual_tokens] + request_ids: torch.Tensor + + # Workspace start offsets for all prefill requests + # Shape: [num_prefill_reqs], adjusted in-place per chunk to be + # 0-indexed within each chunk. Used to map prefill tokens to workspace + # offsets in convert_logical_index_to_physical_index + workspace_starts: torch.Tensor + + @dataclass + class Chunk: + """Metadata for a chunk of prefill requests. + + Prefill requests may be chunked to fit within the fixed workspace size. 
+ """ + + seq_lens: torch.Tensor + tokens_slice: slice + block_table: torch.Tensor + req_start_idx: int + workspace_starts: torch.Tensor + chunk_tot_seqlen: int + + chunks: list[Chunk] + + num_prefills: int = 0 + num_decodes: int = 0 + num_prefill_tokens: int = 0 + num_decode_tokens: int = 0 + + decode: Decode | None = None + prefill: Prefill | None = None + + fp8_extra_metadata: FP8SeperatePrefillDecode | FP8KernelMetadata | None = None + fp8_use_mixed_batch: bool = False +# Kernel with prefill workspace support @triton.jit def _convert_req_index_to_global_index_kernel( req_id_ptr, # int32 [num_tokens] block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill + workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr # shapes (compile-time where possible) max_num_blocks_per_req: tl.constexpr, BLOCK_SIZE: tl.constexpr, BLOCK_N: tl.constexpr, # tile width along columns + HAS_PREFILL: tl.constexpr, # strides (in elements) bt_stride0, bt_stride1, @@ -165,7 +236,10 @@ def _convert_req_index_to_global_index_kernel( # Only token == -1 should propagate as -1 is_invalid_tok = tok < 0 - + is_prefill = False + if HAS_PREFILL: + prefill_req_id = tl.load(prefill_request_id_ptr + token_id) + is_prefill = prefill_req_id >= 0 # Compute block id and in-block offset block_id = tok // BLOCK_SIZE inblock_off = tok % BLOCK_SIZE @@ -173,12 +247,18 @@ def _convert_req_index_to_global_index_kernel( # Guard block_table access valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0) bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 - base = tl.load(bt_ptr, mask=valid_block, other=0) + is_invalid_tok |= ~valid_block + base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0) + out_val = base * BLOCK_SIZE + inblock_off - # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset - out_val = tl.where( - is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off - ) + # Override with prefill output if prefill is enabled + if HAS_PREFILL: + workspace_start = tl.load( + workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0 + ) + prefill_out = workspace_start + tok + out_val = tl.where(is_prefill, prefill_out, out_val) + out_val = tl.where(is_invalid_tok, -1, out_val) # Store results out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 @@ -192,6 +272,9 @@ def triton_convert_req_index_to_global_index( BLOCK_SIZE: int = 64, NUM_TOPK_TOKENS: int = 2048, BLOCK_N: int = 128, # tile width along columns + HAS_PREFILL_WORKSPACE: bool = False, + prefill_workspace_request_ids: torch.Tensor | None = None, + prefill_workspace_starts: torch.Tensor | None = None, ): """ out[token_id, indice_id] = @@ -202,17 +285,32 @@ def triton_convert_req_index_to_global_index( Only when token_indices[token_id, indice_id] == -1 do we output -1. For safety, we also output -1 if the derived block_id would be out-of-bounds. + + When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets + instead of global cache slots. prefill_workspace_request_ids and + prefill_workspace_starts must be provided. 
+ + prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else + prefill request index (maps to prefill_workspace_starts) + prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace + starts for each prefill request """ assert req_id.dtype == torch.int32 assert block_table.dtype == torch.int32 assert token_indices.dtype == torch.int32 assert token_indices.shape[1] == NUM_TOPK_TOKENS assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( - f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible byBLOCK_N ({BLOCK_N})" + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})" ) + if HAS_PREFILL_WORKSPACE: + assert prefill_workspace_request_ids is not None + assert prefill_workspace_starts is not None + assert prefill_workspace_request_ids.dtype == torch.int32 + assert prefill_workspace_starts.dtype == torch.int32 + num_tokens = req_id.shape[0] - num_requests, max_num_blocks_per_req = block_table.shape + max_num_blocks_per_req = block_table.shape[1] tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N # Ensure contiguous tensors on the same device @@ -226,6 +324,13 @@ def triton_convert_req_index_to_global_index( ti_stride0, ti_stride1 = token_indices_c.stride() out_stride0, out_stride1 = out.stride() + # Prepare prefill pointers + if HAS_PREFILL_WORKSPACE: + assert prefill_workspace_request_ids is not None # for mypy + assert prefill_workspace_starts is not None # for mypy + assert prefill_workspace_request_ids.is_contiguous() + assert prefill_workspace_starts.is_contiguous() + # Exact 2D grid: tokens × column tiles grid = (num_tokens, tiles_per_row) @@ -234,10 +339,13 @@ def triton_convert_req_index_to_global_index( block_table_c, token_indices_c, out, + prefill_workspace_request_ids, + prefill_workspace_starts, # shapes / constexprs max_num_blocks_per_req, BLOCK_SIZE, BLOCK_N, + HAS_PREFILL_WORKSPACE, # strides bt_stride0, bt_stride1, @@ -249,7 +357,16 @@ def triton_convert_req_index_to_global_index( return out -@dataclass +def get_prefill_workspace_size(max_model_len: int): + # NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size. + # May be tuned later. 
+ # Memory usage: 5 * max_model_len * 576 * 2 bytes + # Example: DeepSeek-V3.2 with max_model_len=163840 -> + # 5 * 163840 * 576 * 2 = ~900 MB + # This fits nicely below the typical MoE workspace size of >2GB so this is "free" + return max_model_len * 5 + + class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]): _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH @@ -259,29 +376,42 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad layer_names: list[str], vllm_config: VllmConfig, device: torch.device, - ): + ) -> None: + self.vllm_config = vllm_config + self.layer_names = layer_names cache_config = vllm_config.cache_config self.kv_cache_spec = kv_cache_spec self.model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config self.device = device + # Treat requests with query length <= 1 as decodes to match the + # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2) + self._init_reorder_batch_threshold(1, supports_spec_as_decode=True) + props = torch.cuda.get_device_properties(device) sm_count = props.multi_processor_count self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) + self.topk_tokens = vllm_config.model_config.hf_config.index_topk self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla" - self.topk_tokens_tensor = torch.tensor( - [self.topk_tokens], device=device, dtype=torch.int32 + max_num_seqs = vllm_config.scheduler_config.max_num_seqs + # Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG) + self.topk_tokens_tensor = torch.full( + (max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32 ) - self.max_model_len_tensor = torch.tensor( - [self.model_config.max_model_len], device=device, dtype=torch.int32 + # Shape: [max_num_seqs], all elements = max_model_len + self.max_model_len_tensor = torch.full( + (max_num_seqs,), + self.model_config.max_model_len, + device=device, + dtype=torch.int32, ) # this is ignored by `flash_mla_with_kvcache` if indices not None self.dummy_block_table = torch.empty( - (1, 1), dtype=torch.int32, device=self.device + (max_num_seqs, 1), dtype=torch.int32, device=self.device ) # Equation taken from FlashMLA/csrc/pybind.cpp @@ -299,10 +429,9 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad dtype=torch.int32, device=device, ) + # Sized for per-request batching (num_decodes + 1) self.num_splits_buffer = torch.empty( - # We pack all the tokens into one batch for sparse attention. - # Otherwise, we can exceed the sm of `get_mla_metadata`. 
- (2,), + (max_num_seqs + 1,), dtype=torch.int32, device=device, ) @@ -312,30 +441,171 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad device=device, ) - def build( + def _build_fp8_mixed_decode_prefill( self, - common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False, - ) -> FlashMLASparseMetadata: - num_tokens = common_attn_metadata.num_actual_tokens - starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32) - seg_lengths = np.diff(starts) - req_id_per_token = np.repeat( - np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths - ) - # Zero-fill for cudagraphs - self.req_id_per_token_buffer.fill_(0) - self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( - torch.from_numpy(req_id_per_token), non_blocking=True - ) - req_id_per_token = self.req_id_per_token_buffer[:num_tokens] + ) -> "FlashMLASparseMetadata.FP8KernelMetadata": + """Build FP8 metadata treating all tokens as one mixed batch. + + This matches main branch's approach and avoids the BF16 prefill kernel + which has head padding overhead when num_heads is small (high TP case). + """ + num_tokens = common_attn_metadata.num_actual_tokens + + # Build metadata for all tokens as a single batch + tile_scheduler_metadata, num_splits = get_mla_metadata( + cache_seqlens=self.topk_tokens_tensor[:1], # Single batch + num_q_tokens_per_head_k=num_tokens * self.num_heads, + topk=self.topk_tokens, + num_heads_q=self.num_heads, + num_heads_k=1, + is_fp8_kvcache=True, + ) + + num_sm_parts = tile_scheduler_metadata.size(0) + tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[ + :num_sm_parts + ] + tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata) + num_splits_view = self.num_splits_buffer[:2] + num_splits_view.copy_(num_splits) + + fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata( + scheduler_metadata=tile_scheduler_metadata_buffer, + num_splits=num_splits_view, + cache_lens=self.max_model_len_tensor[:1], + dummy_block_table=self.dummy_block_table[:1], + ) + + return fp8_metadata + + def _build_fp8_separate_prefill_decode( + self, + common_attn_metadata: CommonAttentionMetadata, + ) -> "FlashMLASparseMetadata.FP8SeperatePrefillDecode": + num_tokens = common_attn_metadata.num_actual_tokens + + (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold or 1, + require_uniform=True, + ) + ) + + FP8Meta = FlashMLASparseMetadata.FP8SeperatePrefillDecode + fp8_metadata = FP8Meta( + num_decodes=num_decodes, + num_prefills=num_prefills, + num_decode_tokens=num_decode_tokens, + num_prefill_tokens=num_prefill_tokens, + ) + + # Extract prefill sequence lengths (context + query, not just query) + # Decode requests come first in the batch, prefill requests follow + prefill_seq_lens = None + prefill_request_id = None + prefill_workspace_starts = None + prefill_chunks = None + + # For pure decode batches, prefill_request_id will be None + # For mixed batches, it will have -1 for decode and request_id for prefill + if num_prefills > 0: + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + seq_lens = common_attn_metadata.seq_lens + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + prefill_seq_lens_cpu = seq_lens_cpu[num_decodes:] + prefill_seq_lens = seq_lens[num_decodes:] + + # Build prefill_request_id: -1 for decode, request index for + # prefill. 
This enables a single + # convert_logical_index_to_physical_index call for all tokens + prefill_request_id = torch.full( + (num_tokens,), -1, dtype=torch.int32, device=self.device + ) + # Map prefill tokens to their request IDs (0, 1, 2, ...) + for req_idx in range(num_prefills): + # Get query token range for this prefill request + global_req_idx = num_decodes + req_idx + req_query_start = query_start_loc_cpu[global_req_idx] + req_query_end = query_start_loc_cpu[global_req_idx + 1] + prefill_request_id[req_query_start:req_query_end] = req_idx + + # will be adjusted by chunk loop + prefill_workspace_starts_cpu = torch.zeros( + num_prefills, dtype=torch.int32, pin_memory=True + ) + prefill_workspace_starts_cpu[1:] = torch.cumsum( + prefill_seq_lens_cpu[:-1], dim=0 + ) + # populated by non-blocking copy after prefill_workspace_starts_cpu is + # updated by each chunk + prefill_workspace_starts = torch.empty( + num_prefills, dtype=torch.int32, device=self.device + ) + + # Chunk prefill requests to fit within workspace size + max_prefill_buffer_size = get_prefill_workspace_size( + self.vllm_config.model_config.max_model_len + ) + chunk_bounds = split_prefill_chunks( + prefill_seq_lens_cpu, max_prefill_buffer_size + ) + + prefill_chunks = [] + for chunk_start, chunk_end in chunk_bounds: + # Adjust workspace_starts in-place per chunk to be + # 0-indexed within each chunk + # Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]] + # Initial: workspace_starts=[0,10,25,45] + # After: workspace_starts=[0,10,0,20] + # (chunk 0 starts at 0, chunk 1 starts at 0) + offset = prefill_workspace_starts_cpu[chunk_start].item() + prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset + + chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end] + chunk_tot_seqlen = prefill_seq_lens_cpu[chunk_start:chunk_end].sum() + token_start = query_start_loc_cpu[num_decodes + chunk_start].item() + token_end = query_start_loc_cpu[num_decodes + chunk_end].item() + tokens_slice = slice(token_start, token_end) + + # Create chunk view of gpu tensor + chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end] + chunk_block_table = common_attn_metadata.block_table_tensor[ + num_decodes + chunk_start : num_decodes + chunk_end + ] + + prefill_chunks.append( + FP8Meta.Prefill.Chunk( + seq_lens=chunk_seq_lens, + tokens_slice=tokens_slice, + block_table=chunk_block_table, + req_start_idx=chunk_start, + workspace_starts=chunk_workspace_starts, + chunk_tot_seqlen=chunk_tot_seqlen, + ) + ) + + prefill_workspace_starts.copy_( + prefill_workspace_starts_cpu, non_blocking=True + ) + + fp8_metadata.prefill = FP8Meta.Prefill( + seq_lens=prefill_seq_lens, + request_ids=prefill_request_id, + workspace_starts=prefill_workspace_starts, + chunks=prefill_chunks, + ) + + if num_decodes > 0: + # Compute decode_query_len for spec decode (uniform due to require_uniform) + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item() - fp8_extra_metadata = None - if self.use_fp8_kv_cache: tile_scheduler_metadata, num_splits = get_mla_metadata( - cache_seqlens=self.topk_tokens_tensor, - num_q_tokens_per_head_k=num_tokens * self.num_heads, + cache_seqlens=self.topk_tokens_tensor[:num_decodes], + num_q_tokens_per_head_k=decode_query_len * self.num_heads, topk=self.topk_tokens, num_heads_q=self.num_heads, num_heads_k=1, @@ -348,33 +618,70 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad :num_sm_parts ] 
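            # Descriptive note (not introduced by this change): get_mla_metadata()
            # returns tile_scheduler_metadata sized [num_sm_parts, ...] and
            # num_splits sized [num_decodes + 1]; both are copied into the
            # persistent buffers sliced here. decode_query_len is uniform across
            # decode requests (enforced by require_uniform=True above), which is
            # what lets the decode path later view q as
            # (num_decodes, decode_query_len, num_heads, head_dim).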
tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata) - self.num_splits_buffer.copy_(num_splits) + # num_splits has size [num_decodes + 1] + num_splits_view = self.num_splits_buffer[: num_decodes + 1] + num_splits_view.copy_(num_splits) - fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata( + kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata( scheduler_metadata=tile_scheduler_metadata_buffer, - num_splits=self.num_splits_buffer, - # cache_lens and block_table are basically unused in sparse case - # but the decode kernel will treat -1 and indices >= cache_lens - # as invalid so we make sure cache_lens is large enough to not - # accidentally mark indices invalid, we will use -1 exclusively - # to mark invalid indices - cache_lens=self.max_model_len_tensor, - dummy_block_table=self.dummy_block_table, + num_splits=num_splits_view, + dummy_block_table=self.dummy_block_table[:num_decodes], + cache_lens=self.max_model_len_tensor[:num_decodes], + ) + fp8_metadata.decode = FP8Meta.Decode( + kernel_metadata=kernel_meta, + decode_query_len=decode_query_len, ) + return fp8_metadata + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashMLASparseMetadata: + cm = common_attn_metadata + num_tokens = cm.num_actual_tokens + starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32) + seg_lengths = np.diff(starts) + req_id_per_token = np.repeat( + np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths + ) + # Zero-fill for cudagraphs + self.req_id_per_token_buffer.fill_(0) + self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( + torch.from_numpy(req_id_per_token), non_blocking=True + ) + req_id_per_token = self.req_id_per_token_buffer[:num_tokens] + + fp8_extra_metadata: ( + FlashMLASparseMetadata.FP8SeperatePrefillDecode + | FlashMLASparseMetadata.FP8KernelMetadata + | None + ) = None + fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL + if self.use_fp8_kv_cache: + if fp8_use_mixed_batch: + fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm) + else: + fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm) + metadata = FlashMLASparseMetadata( - num_reqs=common_attn_metadata.num_reqs, - max_query_len=common_attn_metadata.max_query_len, - max_seq_len=common_attn_metadata.max_seq_len, - num_actual_tokens=common_attn_metadata.num_actual_tokens, - query_start_loc=common_attn_metadata.query_start_loc, - slot_mapping=common_attn_metadata.slot_mapping, - block_table=common_attn_metadata.block_table_tensor, + num_reqs=cm.num_reqs, + max_query_len=cm.max_query_len, + max_seq_len=cm.max_seq_len, + num_actual_tokens=cm.num_actual_tokens, + query_start_loc=cm.query_start_loc, + slot_mapping=cm.slot_mapping, + block_table=cm.block_table_tensor, req_id_per_token=req_id_per_token, block_size=self.kv_cache_spec.block_size, topk_tokens=self.topk_tokens, fp8_extra_metadata=fp8_extra_metadata, + fp8_use_mixed_batch=fp8_use_mixed_batch, ) + return metadata @@ -414,12 +721,204 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): self.topk_indices_buffer = indexer.topk_indices_buffer self.padding = 128 if current_platform.is_device_capability(100) else 64 + if kv_cache_dtype == "fp8_ds_mla": + # Reserve workspace during initialization + vllm_config = get_current_vllm_config() + assert vllm_config is not None and vllm_config.model_config is not None + prefill_workspace_size = get_prefill_workspace_size( + vllm_config.model_config.max_model_len + ) + 
self.prefill_workspace_shape = (prefill_workspace_size, head_size) + (self.prefill_bf16_workspace,) = ( + current_workspace_manager().get_simultaneous( + (self.prefill_workspace_shape, torch.bfloat16) + ) + ) + def _forward_bf16_kv( self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, topk_indices: torch.Tensor, attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: + # Convert per-request indices to global slots (decode) or workspace + # offsets (prefill). + topk_indices = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=topk_indices.shape[1], + ) + + return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices) + + def _forward_fp8_kv_separate_prefill_decode( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: + fp8_metadata = attn_metadata.fp8_extra_metadata + assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode) + num_decodes = fp8_metadata.num_decodes + + prefill_request_ids = None + prefill_workspace_starts = None + has_prefill_workspace = False + if fp8_metadata.prefill is not None: + prefill_request_ids = fp8_metadata.prefill.request_ids + prefill_workspace_starts = fp8_metadata.prefill.workspace_starts + has_prefill_workspace = True + + # Convert per-request indices to global slots (decode) or workspace + # offsets (prefill). + # For FP8 cache: prefill uses workspace mapping (upconverted to BF16) + # For BF16 cache: always use global cache slots (no workspace) + # prefill_workspace_starts has been adjusted in-place per chunk so + # prefill indices automatically come out chunk-local + topk_indices = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=topk_indices.shape[1], + HAS_PREFILL_WORKSPACE=has_prefill_workspace, + prefill_workspace_request_ids=prefill_request_ids, + prefill_workspace_starts=prefill_workspace_starts, + ) + + fp8_metadata = attn_metadata.fp8_extra_metadata + assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode) + + def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor: + # Reshape q: (num_decode_tokens, num_heads, head_dim) + # -> (num_decodes, seq_len, num_heads, head_dim) + q = reshape_query_for_spec_decode(q, num_decodes) + seq_len = q.shape[1] + # Reshape topk_indices: (num_decode_tokens, topk) + # -> (num_decodes, seq_len, topk) + topk_indices = topk_indices.view(num_decodes, seq_len, -1) + assert fp8_metadata.decode is not None + attn_out, _ = self._fp8_flash_mla_kernel( + q=q, + kv_c_and_k_pe_cache=kv_c_and_k_pe_cache, + topk_indices=topk_indices, + kernel_metadata=fp8_metadata.decode.kernel_metadata, + ) + # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v) + # -> (num_decode_tokens, num_heads, head_dim_v) + return reshape_attn_output_for_spec_decode(attn_out) + + num_decode_tokens = fp8_metadata.num_decode_tokens + num_prefill_tokens = fp8_metadata.num_prefill_tokens + + # Pure decode: direct call without allocation + if num_decode_tokens > 0 and num_prefill_tokens == 0: + assert fp8_metadata.decode is not None + attn_out = _fp8_decode(q, topk_indices) + else: + # Mixed or pure prefill: allocate output tensor + attn_out = q.new_empty( + (attn_metadata.num_actual_tokens, self.num_heads, 
self.kv_lora_rank), + dtype=q.dtype, + device=q.device, + ) + + if num_decode_tokens > 0: + attn_out[:num_decode_tokens] = _fp8_decode( + q[:num_decode_tokens], topk_indices[:num_decode_tokens] + ) + + assert fp8_metadata.prefill is not None + for chunk in fp8_metadata.prefill.chunks: + chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen] + ops.cp_gather_and_upconvert_fp8_kv_cache( + kv_c_and_k_pe_cache, + chunk_workspace, + chunk.block_table, + chunk.seq_lens, + chunk.workspace_starts, + len(chunk.block_table), + ) + + chunk_q = q[chunk.tokens_slice] + chunk_topk_indices_workspace = topk_indices[chunk.tokens_slice] + + attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel( + chunk_q, + chunk_workspace, + chunk_topk_indices_workspace, + ) + + return attn_out + + def _forward_fp8_kv_mixed_batch( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: + """Mixed batch FP8 forward path that treats all tokens as one batch. + + This is equivalent to main branch's approach and avoids the BF16 + prefill kernel which has head padding overhead when num_heads is small. + Used when use_mixed_batch is True. + """ + # Convert per-request indices to global slots (decode) or workspace + # offsets (prefill). + topk_indices = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=topk_indices.shape[1], + ) + + assert attn_metadata.fp8_extra_metadata is not None + assert isinstance( + attn_metadata.fp8_extra_metadata, FlashMLASparseMetadata.FP8KernelMetadata + ) + fp8_metadata = attn_metadata.fp8_extra_metadata + + _attn_out, _ = self._fp8_flash_mla_kernel( + q=q.unsqueeze(0), # unsqueeze to add batch_dim: (T, H, D) -> (1, T, H, D) + kv_c_and_k_pe_cache=kv_c_and_k_pe_cache, + topk_indices=topk_indices.unsqueeze(0), # (T, topk) -> (1, T, topk) + kernel_metadata=fp8_metadata, + ) + + # Output is (1, T, H, D_v), squeeze back to (T, H, D_v) + return _attn_out.squeeze(0) + + def _fp8_flash_mla_kernel( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata, + ) -> torch.Tensor: + return flash_mla_with_kvcache( + q=q, + k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2), + block_table=kernel_metadata.dummy_block_table, + head_dim_v=512, + cache_seqlens=kernel_metadata.cache_lens, + tile_scheduler_metadata=kernel_metadata.scheduler_metadata, + num_splits=kernel_metadata.num_splits, + is_fp8_kvcache=True, + indices=topk_indices, + softmax_scale=self.softmax_scale, + ) + + def _bf16_flash_mla_kernel( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, ) -> torch.Tensor: num_tokens = q.shape[0] kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( @@ -445,31 +944,6 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): output = output[:, : self.num_heads, :] return output - def _forward_fp8_kv( - self, - q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - topk_indices: torch.Tensor, - attn_metadata: FlashMLASparseMetadata, - ) -> torch.Tensor: - assert attn_metadata.fp8_extra_metadata is not None - extra_metadata = attn_metadata.fp8_extra_metadata - - _attn_out, _ = flash_mla_with_kvcache( - q=q.unsqueeze(0), # unsqueeze to add batch_dim - k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2), - 
block_table=extra_metadata.dummy_block_table, - head_dim_v=512, - cache_seqlens=extra_metadata.cache_lens, - tile_scheduler_metadata=extra_metadata.scheduler_metadata, - num_splits=extra_metadata.num_splits, - is_fp8_kvcache=True, - indices=topk_indices.unsqueeze(0), # unsqueeze to add batch_dim - softmax_scale=self.softmax_scale, - ) - - return _attn_out - def forward( self, layer: AttentionLayer, @@ -477,7 +951,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): k_c_normed: torch.Tensor, # key in unified attn k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, - attn_metadata: FlashMLASparseMetadata, + attn_metadata: FlashMLASparseMetadata | None, output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, @@ -493,6 +967,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ) if attn_metadata is None: + # Dummy run - no need to allocate buffers # The zero fill is required when used with DP + EP # to ensure all ranks within a DP group compute the # same expert outputs. @@ -505,6 +980,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): q = q[:num_actual_toks, ...] k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] + topk_indices = self.topk_indices_buffer[:num_actual_toks] q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) # Convert from (B, N, P) to (N, B, P) @@ -514,16 +990,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): # Convert from (N, B, L) to (B, N, L) ql_nope = ql_nope.transpose(0, 1) - topk_indices = self.topk_indices_buffer[:num_actual_toks] - - # TODO: handle index / kv_cache correctly - topk_indices_global = triton_convert_req_index_to_global_index( - attn_metadata.req_id_per_token, - attn_metadata.block_table, - topk_indices, - BLOCK_SIZE=attn_metadata.block_size, - NUM_TOPK_TOKENS=attn_metadata.topk_tokens, - ) + use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla" q = torch.cat([ql_nope, q_pe], dim=-1) @@ -538,13 +1005,15 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): scale=layer._k_scale, ) - if self.kv_cache_dtype != "fp8_ds_mla": - attn_out = self._forward_bf16_kv( - q, kv_cache, topk_indices_global, attn_metadata + if not use_fp8_cache: + attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, attn_metadata) + elif attn_metadata.fp8_use_mixed_batch: + attn_out = self._forward_fp8_kv_mixed_batch( + q, kv_cache, topk_indices, attn_metadata ) else: - attn_out = self._forward_fp8_kv( - q, kv_cache, topk_indices_global, attn_metadata + attn_out = self._forward_fp8_kv_separate_prefill_decode( + q, kv_cache, topk_indices, attn_metadata ) self._v_up_proj(attn_out, out=output[:num_actual_toks]) diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 77f1ba00d5b04..d0696f60a08c7 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -18,6 +18,7 @@ from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, split_decodes_and_prefills, + split_prefill_chunks, ) logger = init_logger(__name__) @@ -176,40 +177,15 @@ def kv_spans_from_batches( def get_max_prefill_buffer_size(vllm_config: VllmConfig): max_model_len = vllm_config.model_config.max_model_len - # NOTE(Chen): 2 is a magic number for controlling the prefill buffer size. - # May be tuned later. 
- return max_model_len * 2 - - -def split_prefill_chunks( - seq_lens_cpu: torch.Tensor, max_prefill_buffer_size: int, reqs_start: int -) -> list[tuple[int, int]]: - """ - Split the prefill chunks into a list of tuples of (reqs_start, reqs_end) - such that the total sequence length of each chunk is less than the - maximum prefill buffer size. - - Args: - seq_lens_cpu: The sequence lengths of the prefill requests. - max_prefill_buffer_size: The maximum prefill buffer size. - reqs_start: The start index of the prefill requests. - - Returns: - A list of tuples of (reqs_start, reqs_end). - """ - chunk_seq_ids = [] - total_seq_lens = 0 - for i in range(reqs_start, len(seq_lens_cpu)): - cur_seq_len = seq_lens_cpu[i].item() - assert cur_seq_len <= max_prefill_buffer_size - total_seq_lens += cur_seq_len - if total_seq_lens > max_prefill_buffer_size: - chunk_seq_ids.append((reqs_start, i)) - reqs_start = i - total_seq_lens = cur_seq_len - if total_seq_lens > 0: - chunk_seq_ids.append((reqs_start, len(seq_lens_cpu))) - return chunk_seq_ids + # NOTE(Chen): 40 is a magic number for controlling the prefill buffer size. + # Each entry is 128 fp8 bytes and 4 scale bytes for a total of 132 bytes. + # The flashmla_sparse backend uses a workspace size of 5 * max_model_len. + # The memory usage of the workspace there is 576 * 2 bytes; so we size this as + # (576 * 2 // 132) * 5 = 40 to maximize this workspace size while still fitting + # within the flashmla_sparse workspace. + # For DeepSeek-V3.2, the max_model_len is 163840. + # 40 * 163840 * 132 = 865075200 bytes = 825 MB + return max_model_len * 40 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): @@ -302,9 +278,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): prefill_metadata = None if num_prefills > 0: chunk_seq_ids = split_prefill_chunks( - common_attn_metadata.seq_lens_cpu, + common_attn_metadata.seq_lens_cpu[num_decodes:], self.max_prefill_buffer_size, - num_decodes, + request_offset=num_decodes, ) chunks = [ self.build_one_prefill_chunk( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 79a1f7d4757d9..da43d87038234 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -937,6 +937,33 @@ def split_decodes_and_prefills( return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) +def split_prefill_chunks( + seq_lens_cpu: torch.Tensor, workspace_size: int, request_offset: int = 0 +) -> list[tuple[int, int]]: + """ + Split the prefill requests into chunks such that the total sequence length + of each chunk is less than or equal to the workspace size. + + Args: + seq_lens_cpu: The sequence lengths of the prefill requests on CPU. + workspace_size: The maximum workspace size (in tokens) per chunk. + request_offset: The offset to add to the request indices. + Returns: + A list of tuples of (reqs_start, reqs_end) representing chunk boundaries. 
+ """ + chunk_bounds = [] + i, n = 0, len(seq_lens_cpu) + assert torch.all(seq_lens_cpu <= workspace_size).item() + + while i < n: + start, chunk_total = i, 0 + while i < n and (chunk_total + (s := seq_lens_cpu[i].item())) <= workspace_size: + chunk_total += s + i += 1 + chunk_bounds.append((start + request_offset, i + request_offset)) + return chunk_bounds + + def reorder_batch_to_split_decodes_and_prefills( input_batch: "InputBatch", scheduler_output: "SchedulerOutput", diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3f20296c27ba7..978224faae65e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -162,6 +162,7 @@ from vllm.v1.worker.ubatch_utils import ( maybe_create_ubatch_slices, ) from vllm.v1.worker.utils import is_residual_scattered_for_sp +from vllm.v1.worker.workspace import lock_workspace from .utils import ( AttentionGroup, @@ -297,6 +298,7 @@ class GPUModelRunner( self.device = device self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype + self.kv_cache_dtype = kv_cache_dtype_str_to_dtype( cache_config.cache_dtype, self.model_config ) @@ -4597,6 +4599,10 @@ class GPUModelRunner( # after here. set_cudagraph_capturing_enabled(False) + # Lock workspace to prevent resizing during execution. + # Max workspace sizes should have been captured during warmup/profiling. + lock_workspace() + end_time = time.perf_counter() elapsed_time = end_time - start_time cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 25ac5aaf99818..21a8564f83c40 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -54,6 +54,7 @@ from vllm.v1.outputs import ( from vllm.v1.utils import report_usage_stats from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.worker_base import WorkerBase +from vllm.v1.worker.workspace import init_workspace_manager logger = init_logger(__name__) @@ -255,6 +256,10 @@ class Worker(WorkerBase): else: raise RuntimeError(f"Not support device type: {self.device_config.device}") + # Initialize workspace manager + num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1 + init_workspace_manager(self.device, num_ubatches) + # Construct the model runner if self.use_v2_model_runner: from vllm.v1.worker.gpu.model_runner import ( diff --git a/vllm/v1/worker/workspace.py b/vllm/v1/worker/workspace.py new file mode 100644 index 0000000000000..a16dde1f67800 --- /dev/null +++ b/vllm/v1/worker/workspace.py @@ -0,0 +1,245 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import inspect +import os +from itertools import accumulate +from math import prod +from typing import Optional + +import torch + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.utils.math_utils import round_up +from vllm.v1.worker.ubatching import dbo_current_ubatch_id + +logger = init_logger(__name__) + + +def _compute_bytes(shape: tuple[int, ...], dtype: torch.dtype) -> int: + return prod(shape) * dtype.itemsize + + +# Constants +_MB = 1024**2 +_GiB = 1024**3 + +# Global workspace manager instance +_manager: Optional["WorkspaceManager"] = None + + +class WorkspaceManager: + """Manager for workspace allocation. + + Manages workspace buffers for DBO (Dual Batch Overlap) execution. + Can be locked to prevent further growth during execution. 
+ """ + + def __init__(self, device: torch.device, num_ubatches: int | None = None): + self._device = device + # Cache num ubatches at init based on configuration (default to 1) + self._num_ubatches = num_ubatches if num_ubatches is not None else 1 + self._current_workspaces: list[torch.Tensor | None] = [None, None] + self._locked: bool = False + + @staticmethod + def _workspace_size_bytes(workspace: torch.Tensor | None) -> int: + """Get size of workspace in bytes.""" + if workspace is None: + return 0 + return workspace.numel() * workspace.element_size() + + def lock(self) -> None: + """Lock the workspace to prevent further growth. + + After locking, any attempt to allocate a larger workspace will raise + an assertion error. This ensures workspace size is fixed during execution. + """ + self._locked = True + if envs.VLLM_DEBUG_WORKSPACE: + logger.info( + "[WORKSPACE DEBUG] Workspace locked. Current sizes: %s", + [ + self._workspace_size_bytes(ws) / _MB + for ws in self._current_workspaces + if ws is not None + ], + ) + + def is_locked(self) -> bool: + """Check if workspace is locked.""" + return self._locked + + def get_simultaneous( + self, *shapes_and_dtypes: tuple[tuple[int, ...], torch.dtype] + ) -> list[torch.Tensor]: + """Get multiple workspace tensors simultaneously from a single allocation. + + Args: + *shapes_and_dtypes: One or more (shape, dtype) tuples. + + Returns: + List of tensor views into the workspace buffer, one per shape/dtype pair. + """ + actual_bytes = [_compute_bytes(s, d) for s, d in shapes_and_dtypes] + aligned_bytes = [round_up(actual, 256) for actual in actual_bytes] + total_bytes = sum(aligned_bytes) + + # Calculate cumulative offsets using itertools.accumulate + offsets = list(accumulate([0] + aligned_bytes[:-1])) + + current_workspace = self._ensure_workspace_size(total_bytes) + + return [ + current_workspace[offsets[i] : offsets[i] + actual_bytes[i]] + .view(shapes_and_dtypes[i][1]) + .reshape(shapes_and_dtypes[i][0]) + for i in range(len(shapes_and_dtypes)) + ] + + def _ensure_workspace_size(self, required_bytes: int) -> torch.Tensor: + """Ensure workspace is allocated and large enough, return current workspace. + + Args: + required_bytes: The number of bytes required. + + Returns: + The current workspace tensor. + """ + ubatch_id = dbo_current_ubatch_id() + current_workspace = self._current_workspaces[ubatch_id] + current_size = self._workspace_size_bytes(current_workspace) + + if current_size < required_bytes: + + def get_caller_info() -> str: + """Find first frame outside WorkspaceManager.""" + curr_frame = inspect.currentframe() + if curr_frame is None: + return "unknown" + # Walk up the stack skipping WorkspaceManager frames + curr_frame = curr_frame.f_back + while curr_frame is not None: + # TODO: This only catches instance methods (self), missing + # classmethods and staticmethods. Once Python 3.11+ is the + # minimum supported version, use co_qualname instead: + # qualname = curr_frame.f_code.co_qualname + # if qualname.startswith("WorkspaceManager."): + if isinstance(curr_frame.f_locals.get("self"), WorkspaceManager): + curr_frame = curr_frame.f_back + continue + filename = os.path.basename(curr_frame.f_code.co_filename) + return ( + f"{filename}:{curr_frame.f_lineno}:{curr_frame.f_code.co_name}" + ) + return "unknown" + + if self._locked: + raise AssertionError( + f"Workspace is locked but allocation from '{get_caller_info()}' " + f"requires {required_bytes / _MB:.2f} MB, current size is " + f"{current_size / _MB:.2f} MB. 
" + "Workspace growth is not allowed after locking." + ) + + for ubatch_id in range(self._num_ubatches): + current_workspace = self._current_workspaces[ubatch_id] + if current_workspace is None: + self._current_workspaces[ubatch_id] = torch.empty( + (required_bytes,), dtype=torch.uint8, device=self._device + ) + elif self._workspace_size_bytes(current_workspace) < required_bytes: + current_workspace.resize_(required_bytes) + + if envs.VLLM_DEBUG_WORKSPACE: + logger.info( + "[WORKSPACE DEBUG] Resized workspace from '%s': %.2f MB -> " + "%.2f MB (%d ubatches, total memory %.2f MB)", + get_caller_info(), + current_size / _MB, + required_bytes / _MB, + self._num_ubatches, + required_bytes * self._num_ubatches / _MB, + ) + + current_workspace = self._current_workspaces[dbo_current_ubatch_id()] + + return current_workspace + + +def is_workspace_manager_initialized() -> bool: + """Check if workspace manager has been initialized. + + Returns: + True if workspace manager is initialized, False otherwise. + """ + return _manager is not None + + +def current_workspace_manager() -> "WorkspaceManager": + """Get the current workspace manager instance. + + Raises: + AssertionError: If workspace manager has not been initialized. + """ + assert _manager is not None, ( + "WorkspaceManager not initialized. Call init_workspace_manager() " + "with a device before using workspace functions." + ) + return _manager + + +def init_workspace_manager( + device: torch.device, num_ubatches: int | None = None +) -> None: + """Initialize the workspace manager with a device. + + Must be called before using any workspace functions. Typically called + from GPUModelRunner.__init__. + + Args: + device: The device to allocate workspace on. + num_ubatches: Number of micro-batches. Defaults to 1. + """ + global _manager + if _manager is not None: + logger.warning( + "WorkspaceManager already initialized on device %s, " + "reinitializing on device %s", + _manager._device, + device, + ) + _manager = WorkspaceManager(device, num_ubatches) + + +def lock_workspace() -> None: + """Lock the workspace to prevent further growth. + + After calling this function, any attempt to allocate a workspace larger + than the current size will raise an AssertionError. This ensures that + workspace size is fixed during execution and prevents unexpected memory + allocations in the hot path. + + Example: + # During initialization + init_workspace_manager(device) + reserve_workspace(shape1, dtype1) + reserve_workspace(shape2, dtype2) + + # Lock after warmup/profiling + lock_workspace() + + # Now all get_workspace calls must fit in pre-allocated size + """ + current_workspace_manager().lock() + + +def reset_workspace_manager() -> None: + """Reset the workspace manager to uninitialized state. + + This is primarily intended for testing purposes to allow tests + to reinitialize the workspace manager cleanly. + """ + global _manager + _manager = None