[BugFix] Fix prefix caching V0 MLA (#14255)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: Ying Zhong <zhongyingmatrix@gmail.com>
This commit is contained in:
parent a7ea35aa67
commit 4dacaa4a83
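The core of the fix is visible in the hunks below: the context-chunking workspace, previously set up only when chunked prefill was enabled, is now also set up when prefix caching is enabled, and it is renamed from `chunked_prefill_workspace` to `context_chunk_workspace`. A minimal standalone sketch of the new gating predicate (illustrative only, not code from the repository):

```python
# Illustrative sketch of the condition this commit changes (not vLLM code).
# With prefix caching, a prefill can have prior computed context even when
# chunked prefill is off, so the context-chunk workspace must exist in that
# case as well.
def needs_context_chunk_workspace(chunked_prefill_enabled: bool,
                                  enable_prefix_caching: bool) -> bool:
    # Before this commit only `chunked_prefill_enabled` was checked.
    return chunked_prefill_enabled or enable_prefix_caching
```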
@@ -313,9 +313,10 @@ class MLACommonState(AttentionState, Generic[T]):
         cache_config = runner.cache_config
 
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        self.enable_prefix_caching = cache_config.enable_prefix_caching
 
-        if self.chunked_prefill_enabled:
-            self.chunked_prefill_workspace_size = min(
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
+            self.context_chunk_workspace_size = min(
                 # Max sure there is enough for 8 full length request or at least
                 # 4 pages of cache per request
                 max(
@@ -330,7 +331,7 @@ class MLACommonState(AttentionState, Generic[T]):
                 # 2*(192*128)*(64*1024) = 3gb
                 # (assuming 192 QK head dim, 128 heads, and fp16)
                 128 * 1024)
-            assert self.chunked_prefill_workspace_size >= \
+            assert self.context_chunk_workspace_size >= \
                 scheduler_config.max_num_seqs * cache_config.block_size
 
     @contextmanager
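As a rough check on the sizing comment in the hunk above (`2*(192*128)*(64*1024) = 3gb`), here is a back-of-the-envelope computation using the values the comment assumes (192 QK head dim, 128 heads, fp16); it is an illustration, not code from the repository:

```python
# Back-of-the-envelope check of the "3gb" comment above; the numbers are the
# assumptions stated in that comment, not values read from a config.
bytes_per_elem = 2                 # fp16
qk_head_dim = 192
num_heads = 128
workspace_tokens = 64 * 1024       # 64K tokens of up-projected context
total_bytes = bytes_per_elem * qk_head_dim * num_heads * workspace_tokens
print(f"{total_bytes / 2**30:.1f} GiB")   # -> 3.0 GiB
```

The assert that follows the `min(...)` keeps the workspace large enough to hold at least one cache block per scheduled sequence.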
@@ -430,23 +431,23 @@ class MLACommonState(AttentionState, Generic[T]):
                 "TritonMLAState does not support encoder/decoder yet")
 
     def begin_forward(self, model_input):
-        if self.chunked_prefill_enabled:
-            if not hasattr(self, "chunked_prefill_workspace"):
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
+            if not hasattr(self, "context_chunk_workspace"):
                 # not self.runner.device does not return the correct device
                 # for this process, (init_device sets the correct device but
                 # only on the Worker). The only way Ive figured out to get the
                 # correct device is to allocate the workspace on the first call
                 # to begin_forward and use the device of the input tokens
                 assert model_input.input_tokens is not None
-                self.chunked_prefill_workspace = torch.empty(
-                    (self.chunked_prefill_workspace_size,
+                self.context_chunk_workspace = torch.empty(
+                    (self.context_chunk_workspace_size,
                      self.model_config.get_head_size()),
                     dtype=self.model_config.dtype,
                     device=model_input.input_tokens.device,
                 )
 
-            model_input.attn_metadata.chunked_prefill_workspace = \
-                self.chunked_prefill_workspace
+            model_input.attn_metadata.context_chunk_workspace = \
+                self.context_chunk_workspace
 
 
 @dataclass
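The `begin_forward` change above keeps the existing lazy-allocation pattern: the correct device is only known once real input tokens arrive, so the workspace is created on the first call and cached on the object. A self-contained sketch of that pattern, with illustrative class and attribute names that are not vLLM's:

```python
import torch

# Sketch of lazy, device-aware workspace allocation as described in the
# comments above. Sizes and names here are assumptions for illustration.
class _LazyWorkspaceOwner:
    def __init__(self, num_tokens: int, head_size: int,
                 dtype: torch.dtype = torch.float16):
        self.num_tokens = num_tokens
        self.head_size = head_size
        self.dtype = dtype

    def begin_forward(self, input_tokens: torch.Tensor) -> torch.Tensor:
        if not hasattr(self, "workspace"):
            # Allocate once, on whatever device the first inputs live on.
            self.workspace = torch.empty(
                (self.num_tokens, self.head_size),
                dtype=self.dtype,
                device=input_tokens.device,
            )
        return self.workspace

owner = _LazyWorkspaceOwner(num_tokens=1024, head_size=64)
ws = owner.begin_forward(torch.zeros(8, dtype=torch.long))
print(ws.shape, ws.device)
```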
@@ -537,7 +538,7 @@ class MLACommonMetadata(AttentionMetadata):
     context_chunk_seq_tot: Optional[List[int]] = None
     context_chunk_max_seq_lens: Optional[List[int]] = None
     # Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted
-    chunked_prefill_workspace: Optional[torch.Tensor] = None
+    context_chunk_workspace: Optional[torch.Tensor] = None
 
     def __post_init__(self):
         supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
@@ -747,11 +748,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
         self.block_size = input_builder.block_size
         self.chunked_prefill_enabled = \
             self.runner.scheduler_config.chunked_prefill_enabled
+        self.enable_prefix_caching = \
+            self.runner.cache_config.enable_prefix_caching
 
-        if self.chunked_prefill_enabled:
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
             attn_state = self.input_builder.runner.attn_state
-            self.chunked_prefill_workspace_size = \
-                attn_state.chunked_prefill_workspace_size
+            self.context_chunk_workspace_size = \
+                attn_state.context_chunk_workspace_size
             self.page_size = self.runner.block_size
 
     def prepare(self):
@@ -920,7 +923,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
         context_chunk_seq_tot = None
         context_chunk_max_seq_lens = None
 
-        if self.chunked_prefill_enabled and self.num_prefills > 0 \
+        if (self.chunked_prefill_enabled or self.enable_prefix_caching) \
+            and self.num_prefills > 0 \
             and context_lens_tensor is not None \
             and context_lens_tensor[:self.num_prefills].max() > 0:
 
@@ -936,7 +940,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
             # algorithm here and allocate more workspace to prefills with
             # longer context lengths
             max_context_chunk = \
-                self.chunked_prefill_workspace_size // num_prefills_with_context
+                self.context_chunk_workspace_size // num_prefills_with_context
 
             # align max_context_chunk to page_size by rounding down,
             # currently the `gather_cache` kernel cannot handle
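The `//` split above divides the shared workspace evenly among the prefills that have prior context, and the following context lines round the per-request chunk down to a multiple of the page size (per the comment, `gather_cache` cannot handle chunks that are not page-aligned). A worked example under assumed numbers, not taken from any particular configuration:

```python
import math

# Worked example of the chunk sizing above; the concrete numbers (workspace
# size, number of prefills with context, page size, context length) are
# assumptions for illustration only.
context_chunk_workspace_size = 128 * 1024     # workspace capacity in tokens
num_prefills_with_context = 3
page_size = 16                                # cache block size in tokens

# Split the shared workspace evenly across prefills that have prior context.
max_context_chunk = context_chunk_workspace_size // num_prefills_with_context
# Round down to a whole number of pages so every chunk starts on a page
# boundary.
max_context_chunk = (max_context_chunk // page_size) * page_size
print(max_context_chunk)                      # 43680

# A request with 100_000 cached context tokens is then processed in
# ceil(100_000 / 43_680) = 3 chunk iterations.
print(math.ceil(100_000 / max_context_chunk))  # 3
```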
@@ -965,7 +969,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
                 chunk_seq_lens.max(dim=1).values.tolist()
             context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
             assert max(context_chunk_seq_tot) <= \
-                self.chunked_prefill_workspace_size
+                self.context_chunk_workspace_size
 
         return self.runner.attn_backend.make_metadata(
             # Required by ModelRunner
@@ -1288,8 +1292,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         # Fetch from attn_metadata directly, since it late bound by
         # MLAAttentionState, grabbing it directly `attn_metadata` can avoid
         # any weirdness around prefill_metadata caching
-        assert attn_metadata.chunked_prefill_workspace is not None
-        workspace = attn_metadata.chunked_prefill_workspace
+        assert attn_metadata.context_chunk_workspace is not None
+        workspace = attn_metadata.context_chunk_workspace
 
         for i in range(iters):
             toks = prefill_metadata.context_chunk_seq_tot[i]
@@ -1502,12 +1506,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                 "output is not yet supported for MLAImplBase")
 
         if attn_metadata.is_profile_run and \
-            attn_metadata.chunked_prefill_workspace is not None:
+            attn_metadata.context_chunk_workspace is not None:
             # During the profile run try to simulate to worse case output size
             # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
             # since this can be large
             _ = torch.empty(
-                (attn_metadata.chunked_prefill_workspace.shape[0],
+                (attn_metadata.context_chunk_workspace.shape[0],
                  self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
                 device=k_c_normed.device,
                 dtype=k_c_normed.dtype,
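The last hunk keeps the profile-run trick unchanged apart from the rename: during memory profiling a dummy tensor with the worst-case `kv_b_proj` output shape is allocated so that peak-memory measurement reserves room for it. A minimal sketch of the idea, with sizes and names that are assumptions for illustration rather than vLLM's:

```python
import torch

# Sketch of the profile-run trick above: allocate (and immediately discard) a
# tensor shaped like the worst-case up-projection output so that peak-memory
# profiling accounts for it. All concrete sizes below are illustrative; the
# real run uses the workspace row count and the model's head dimensions.
def simulate_worst_case_upprojection(workspace_tokens: int, num_heads: int,
                                     qk_nope_head_dim: int, v_head_dim: int,
                                     device: torch.device,
                                     dtype: torch.dtype) -> None:
    _ = torch.empty(
        (workspace_tokens, num_heads, qk_nope_head_dim + v_head_dim),
        device=device,
        dtype=dtype,
    )
    # The tensor is never used; it only inflates peak memory during profiling.

simulate_worst_case_upprojection(1024, 8, 64, 64,
                                 torch.device("cpu"), torch.float16)
```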