From a986f17028b3d113899363ddfbb569178181cc05 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Tue, 23 Sep 2025 21:09:43 -0400
Subject: [PATCH] [BugFix] Fix MLA assert with CUTLASS MLA (#25478)

Signed-off-by: Lucas Wilkinson
Signed-off-by: yewentao256
---
 vllm/v1/attention/backends/mla/common.py | 64 +++++++++++++++++-------
 1 file changed, 46 insertions(+), 18 deletions(-)

diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index a177117a50bd1..e84f2d89943e7 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -204,7 +204,7 @@ from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -436,6 +436,34 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """
     reorder_batch_threshold: ClassVar[int] = 1
 
+    @staticmethod
+    def determine_chunked_prefill_workspace_size(
+            vllm_config: VllmConfig) -> int:
+        scheduler_config = vllm_config.scheduler_config
+        cache_config = vllm_config.cache_config
+        model_config = vllm_config.model_config
+
+        chunked_prefill_workspace_size = min(
+            # Try for 8 full length requests or at least 4 pages per request
+            max(8 * model_config.max_model_len,
+                4 * scheduler_config.max_num_seqs * cache_config.block_size),
+            # For long-context models try not to over-allocate limiting
+            # kv-cache space, limiting it to 64k tokens,
+            # which would result in the workspace being:
+            # 2*(576)*(64*1024) = 144mb
+            # (assuming 576 MLA head dim, and fp16)
+            # which would result in up-projected context being
+            # 2*(192*128)*(64*1024) = 3gb
+            # (assuming 192 QK head dim, 128 heads, and fp16)
+            64 * 1024)
+
+        # Enforce that we have enough for at least 1 page per request
+        chunked_prefill_workspace_size = max(
+            chunked_prefill_workspace_size,
+            scheduler_config.max_num_seqs * cache_config.block_size)
+
+        return chunked_prefill_workspace_size
+
     def __init__(self,
                  kv_cache_spec: AttentionSpec,
                  layer_names: list[str],
@@ -448,7 +476,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         scheduler_config = vllm_config.scheduler_config
         self.model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
-        cache_config = vllm_config.cache_config
         self.compilation_config = vllm_config.compilation_config
         self.device = device
 
@@ -468,22 +495,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         if self.aot_schedule:
             self.page_size = self.kv_cache_spec.block_size
 
-        self.chunked_prefill_workspace_size = min(
-            # Max sure there is enough for 8 full length request or at least
-            # 4 pages of cache per request
-            max(8 * self.model_config.max_model_len,
-                4 * scheduler_config.max_num_seqs * cache_config.block_size),
-            # For long-context models try not to over-allocate limiting
-            # kv-cache space, limiting it to 64k tokens,
-            # which would result in the workspace being:
-            # 2*(576)*(64*1024) = 144mb
-            # (assuming 576 MLA head dim, and fp16)
-            # which would result in up-projected context being
-            # 2*(192*128)*(64*1024) = 3gb
-            # (assuming 192 QK head dim, 128 heads, and fp16)
-            64 * 1024)
-        assert self.chunked_prefill_workspace_size >= \
-            scheduler_config.max_num_seqs * cache_config.block_size
+        self.chunked_prefill_workspace_size = \
+            self.determine_chunked_prefill_workspace_size(vllm_config)
+
         if self.dcp_world_size > 1:
             # Note(hc): The local kvcache is incomplete when DCP is triggered,
             # an additional kvcache allgather across the DCP group is therefore
@@ -999,6 +1013,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
 
         self.dcp_world_size: Optional[int] = None
 
+        self.chunked_prefill_workspace_size = \
+            MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
+                get_current_vllm_config())
+
     def _flash_attn_varlen_diff_headdims(self,
                                          q,
                                          k,
@@ -1513,6 +1531,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
                 " for MLACommonImpl")
 
         if attn_metadata is None:
+            # During the profile run try to simulate the worst case output size
+            # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
+            # since this can be large
+            _ = torch.empty(
+                (self.chunked_prefill_workspace_size, self.num_heads,
+                 self.qk_nope_head_dim + self.v_head_dim),
+                device=k_c_normed.device,
+                dtype=k_c_normed.dtype,
+            )
+
             # The zero fill is required when used with DP + EP
             # to ensure all ranks within a DP group compute the
             # same expert outputs.
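For reference, a minimal standalone sketch (not part of the patch) of the sizing
logic the new helper centralizes, with illustrative config values; it shows the
case the removed assert could trip on and how the final max() clamp now
guarantees at least one page per request:

    # Sketch of the workspace sizing logic (illustrative numbers only).
    def workspace_size(max_model_len: int, max_num_seqs: int,
                       block_size: int) -> int:
        size = min(
            # 8 full-length requests or 4 pages per request, whichever is larger
            max(8 * max_model_len, 4 * max_num_seqs * block_size),
            # cap at 64k tokens for long-context models
            64 * 1024)
        # clamp so there is always at least 1 page per request
        return max(size, max_num_seqs * block_size)

    # e.g. max_num_seqs=1024, block_size=128: the 64k cap (65536) is smaller
    # than max_num_seqs * block_size (131072), which is what tripped the old
    # assert; the clamp now returns 131072 instead.
    print(workspace_size(max_model_len=32768, max_num_seqs=1024, block_size=128))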