[BugFix] Fix MLA assert with CUTLASS MLA (#25478)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent faa58fa791
commit a986f17028
@@ -204,7 +204,7 @@ from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
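The new import is what lets MLACommonImpl (last two hunks) recover the engine configuration without threading vllm_config through its constructor: vLLM makes the active VllmConfig available as ambient state while the engine is being built. A minimal stdlib sketch of that context-manager pattern, with hypothetical names (set_current_config/get_current_config stand in for the real vllm.config helpers):

from contextlib import contextmanager

_CURRENT_CONFIG = None  # module-level slot, standing in for vLLM's global

@contextmanager
def set_current_config(cfg):
    """Make cfg visible to any code that runs inside the with-block."""
    global _CURRENT_CONFIG
    prev, _CURRENT_CONFIG = _CURRENT_CONFIG, cfg
    try:
        yield
    finally:
        _CURRENT_CONFIG = prev  # restore on exit, even on error

def get_current_config():
    assert _CURRENT_CONFIG is not None, "called outside set_current_config()"
    return _CURRENT_CONFIG

with set_current_config({"max_model_len": 4096}):
    assert get_current_config()["max_model_len"] == 4096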
@@ -436,6 +436,34 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """
     reorder_batch_threshold: ClassVar[int] = 1

+    @staticmethod
+    def determine_chunked_prefill_workspace_size(
+            vllm_config: VllmConfig) -> int:
+        scheduler_config = vllm_config.scheduler_config
+        cache_config = vllm_config.cache_config
+        model_config = vllm_config.model_config
+
+        chunked_prefill_workspace_size = min(
+            # Try for 8 full-length requests or at least 4 pages per request
+            max(8 * model_config.max_model_len,
+                4 * scheduler_config.max_num_seqs * cache_config.block_size),
+            # For long-context models, avoid over-allocating (which would
+            # limit kv-cache space) by capping this at 64k tokens,
+            # which would result in the workspace being:
+            #   2*(576)*(64*1024) = 144mb
+            # (assuming 576 MLA head dim, and fp16)
+            # and the up-projected context being:
+            #   2*(192*128)*(64*1024) = 3gb
+            # (assuming 192 QK head dim, 128 heads, and fp16)
+            64 * 1024)
+
+        # Enforce that we have enough for at least 1 page per request
+        chunked_prefill_workspace_size = max(
+            chunked_prefill_workspace_size,
+            scheduler_config.max_num_seqs * cache_config.block_size)
+
+        return chunked_prefill_workspace_size
+
     def __init__(self,
                  kv_cache_spec: AttentionSpec,
                  layer_names: list[str],
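To make the sizing rule concrete, here is a standalone sketch of the same arithmetic with hypothetical config values (illustrative numbers, not vLLM defaults):

def workspace_size(max_model_len: int, max_num_seqs: int,
                   block_size: int) -> int:
    size = min(
        max(8 * max_model_len, 4 * max_num_seqs * block_size),
        64 * 1024)  # cap for long-context models
    # floor: at least one page per request (this replaces the old assert)
    return max(size, max_num_seqs * block_size)

# Short-context model: the 64k cap never binds.
print(workspace_size(max_model_len=4096, max_num_seqs=256, block_size=16))
# max(32768, 16384) = 32768; under the cap; floor is 4096 -> 32768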
@@ -448,7 +476,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         scheduler_config = vllm_config.scheduler_config
         self.model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
-        cache_config = vllm_config.cache_config
         self.compilation_config = vllm_config.compilation_config
         self.device = device

@@ -468,22 +495,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         if self.aot_schedule:
             self.page_size = self.kv_cache_spec.block_size

-        self.chunked_prefill_workspace_size = min(
-            # Make sure there is enough for 8 full-length requests or at
-            # least 4 pages of cache per request
-            max(8 * self.model_config.max_model_len,
-                4 * scheduler_config.max_num_seqs * cache_config.block_size),
-            # For long-context models try not to over-allocate, limiting
-            # kv-cache space, capping it to 64k tokens,
-            # which would result in the workspace being:
-            #   2*(576)*(64*1024) = 144mb
-            # (assuming 576 MLA head dim, and fp16)
-            # which would result in up-projected context being
-            #   2*(192*128)*(64*1024) = 3gb
-            # (assuming 192 QK head dim, 128 heads, and fp16)
-            64 * 1024)
-        assert self.chunked_prefill_workspace_size >= \
-            scheduler_config.max_num_seqs * cache_config.block_size
+        self.chunked_prefill_workspace_size = \
+            self.determine_chunked_prefill_workspace_size(vllm_config)

         if self.dcp_world_size > 1:
             # Note(hc): The local kvcache is incomplete when DCP is triggered,
             # an additional kvcache allgather across the DCP group is therefore
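This replacement is the actual fix. CUTLASS MLA uses a large kv-cache page size (e.g. 128), so max_num_seqs * block_size can exceed the 64k-token cap, and the old bare assert fired at startup. A standalone sketch with hypothetical numbers shows the failure mode and the new clamping behavior:

# Hypothetical configuration in the spirit of the reported crash:
# a large page size combined with many concurrent sequences.
max_model_len, max_num_seqs, block_size = 131072, 1024, 128

size = min(max(8 * max_model_len,
               4 * max_num_seqs * block_size),
           64 * 1024)                 # -> 65536 (the cap binds)
floor = max_num_seqs * block_size     # -> 131072

print(size >= floor)     # False: the old `assert` raised here
print(max(size, floor))  # 131072: the new code clamps up instead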
@@ -999,6 +1013,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):

         self.dcp_world_size: Optional[int] = None

+        self.chunked_prefill_workspace_size = \
+            MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
+                get_current_vllm_config())
+
     def _flash_attn_varlen_diff_headdims(self,
                                          q,
                                          k,
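Making the computation a @staticmethod on the builder is what allows this reuse: the impl is constructed independently of any builder instance, yet both must agree on the workspace size or the profile-run allocation in the next hunk will not reflect reality. A minimal sketch of the sharing pattern (class names and the simplified formula are illustrative):

class Builder:
    @staticmethod
    def determine_size(cfg: dict) -> int:
        # single source of truth for the workspace size
        return max(min(8 * cfg["max_model_len"], 64 * 1024),
                   cfg["max_num_seqs"] * cfg["block_size"])

class Impl:
    def __init__(self, cfg: dict):
        # no Builder instance needed; same code path, same answer
        self.workspace_size = Builder.determine_size(cfg)

cfg = {"max_model_len": 4096, "max_num_seqs": 256, "block_size": 16}
assert Impl(cfg).workspace_size == Builder.determine_size(cfg)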
@@ -1513,6 +1531,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
                 " for MLACommonImpl")

+        if attn_metadata is None:
+            # During the profile run, simulate the worst-case output size
+            # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`,
+            # since this can be large
+            _ = torch.empty(
+                (self.chunked_prefill_workspace_size, self.num_heads,
+                 self.qk_nope_head_dim + self.v_head_dim),
+                device=k_c_normed.device,
+                dtype=k_c_normed.dtype,
+            )

         # The zero fill is required when used with DP + EP
         # to ensure all ranks within a DP group compute the
         # same expert outputs.
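To see why the memory profiler must account for this tensor, a back-of-envelope with the dimensions assumed in the comments above (128 heads, 128-dim nope and value heads, fp16; the real values come from the model config):

workspace_size = 64 * 1024              # tokens, from the 64k cap above
num_heads = 128
qk_nope_head_dim, v_head_dim = 128, 128
bytes_per_elem = 2                      # fp16

n_bytes = (workspace_size * num_heads *
           (qk_nope_head_dim + v_head_dim) * bytes_per_elem)
print(f"{n_bytes / 2**30:.1f} GiB")     # 4.0 GiB -- if the profile run never
# sees an allocation this large, a later real chunked prefill can OOM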