[BugFix] Fix prefix caching V0 MLA (#14255)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: Ying Zhong <zhongyingmatrix@gmail.com>
Author: Lucas Wilkinson
Date: 2025-03-05 20:07:42 -05:00 (committed by GitHub)
Parent: a7ea35aa67
Commit: 4dacaa4a83


@@ -313,9 +313,10 @@ class MLACommonState(AttentionState, Generic[T]):
         cache_config = runner.cache_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        self.enable_prefix_caching = cache_config.enable_prefix_caching
-        if self.chunked_prefill_enabled:
-            self.chunked_prefill_workspace_size = min(
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
+            self.context_chunk_workspace_size = min(
                 # Max sure there is enough for 8 full length request or at least
                 # 4 pages of cache per request
                 max(
@@ -330,7 +331,7 @@ class MLACommonState(AttentionState, Generic[T]):
                 # 2*(192*128)*(64*1024) = 3gb
                 # (assuming 192 QK head dim, 128 heads, and fp16)
                 128 * 1024)
-            assert self.chunked_prefill_workspace_size >= \
+            assert self.context_chunk_workspace_size >= \
                 scheduler_config.max_num_seqs * cache_config.block_size
     @contextmanager
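The two hunks above size a single shared workspace that is now allocated whenever request context can exist, i.e. with chunked prefill or with prefix caching enabled, instead of only for chunked prefill. A minimal sketch of the sizing policy the comments describe; the exact terms inside max(...) are elided from this excerpt, so the expressions below are assumptions rather than the literal source:

def context_chunk_workspace_rows(max_model_len: int, max_num_seqs: int,
                                 block_size: int) -> int:
    # "enough for 8 full length request or at least 4 pages of cache per
    # request", capped at 128 * 1024 rows for long-context models
    return min(
        max(8 * max_model_len,
            4 * max_num_seqs * block_size),
        128 * 1024)

rows = context_chunk_workspace_rows(max_model_len=32768,
                                    max_num_seqs=256,
                                    block_size=16)
assert rows >= 256 * 16  # mirrors the assert in the hunk above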
@@ -430,23 +431,23 @@ class MLACommonState(AttentionState, Generic[T]):
                 "TritonMLAState does not support encoder/decoder yet")
     def begin_forward(self, model_input):
-        if self.chunked_prefill_enabled:
-            if not hasattr(self, "chunked_prefill_workspace"):
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
+            if not hasattr(self, "context_chunk_workspace"):
                 # not self.runner.device does not return the correct device
                 # for this process, (init_device sets the correct device but
                 # only on the Worker). The only way Ive figured out to get the
                 # correct device is to allocate the workspace on the first call
                 # to begin_forward and use the device of the input tokens
                 assert model_input.input_tokens is not None
-                self.chunked_prefill_workspace = torch.empty(
-                    (self.chunked_prefill_workspace_size,
+                self.context_chunk_workspace = torch.empty(
+                    (self.context_chunk_workspace_size,
                      self.model_config.get_head_size()),
                     dtype=self.model_config.dtype,
                     device=model_input.input_tokens.device,
                 )
-            model_input.attn_metadata.chunked_prefill_workspace = \
-                self.chunked_prefill_workspace
+            model_input.attn_metadata.context_chunk_workspace = \
+                self.context_chunk_workspace
 @dataclass
@@ -537,7 +538,7 @@ class MLACommonMetadata(AttentionMetadata):
     context_chunk_seq_tot: Optional[List[int]] = None
     context_chunk_max_seq_lens: Optional[List[int]] = None
     # Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted
-    chunked_prefill_workspace: Optional[torch.Tensor] = None
+    context_chunk_workspace: Optional[torch.Tensor] = None
     def __post_init__(self):
         supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
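These two hunks rename the late-bound workspace on both the state object and the metadata dataclass. The allocation is deliberately deferred to the first begin_forward call so the tensor lands on the device of the actual input tokens, and it is attached to attn_metadata there so it never needs to be broadcast. A minimal standalone sketch of that pattern (simplified names, not vLLM's API):

import torch

class WorkspaceHolder:
    def __init__(self, rows: int, head_size: int, dtype: torch.dtype):
        self.rows, self.head_size, self.dtype = rows, head_size, dtype

    def get(self, input_tokens: torch.Tensor) -> torch.Tensor:
        # Allocate once, on the device of the first real batch; the runner's
        # configured device is not reliable in this process.
        if not hasattr(self, "context_chunk_workspace"):
            self.context_chunk_workspace = torch.empty(
                (self.rows, self.head_size),
                dtype=self.dtype,
                device=input_tokens.device)
        return self.context_chunk_workspace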
@@ -747,11 +748,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
         self.block_size = input_builder.block_size
         self.chunked_prefill_enabled = \
             self.runner.scheduler_config.chunked_prefill_enabled
+        self.enable_prefix_caching = \
+            self.runner.cache_config.enable_prefix_caching
-        if self.chunked_prefill_enabled:
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
             attn_state = self.input_builder.runner.attn_state
-            self.chunked_prefill_workspace_size = \
-                attn_state.chunked_prefill_workspace_size
+            self.context_chunk_workspace_size = \
+                attn_state.context_chunk_workspace_size
             self.page_size = self.runner.block_size
     def prepare(self):
@@ -920,7 +923,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
         context_chunk_seq_tot = None
         context_chunk_max_seq_lens = None
-        if self.chunked_prefill_enabled and self.num_prefills > 0 \
+        if (self.chunked_prefill_enabled or self.enable_prefix_caching) \
+                and self.num_prefills > 0 \
                 and context_lens_tensor is not None \
                 and context_lens_tensor[:self.num_prefills].max() > 0:
@@ -936,7 +940,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
             # algorithm here and allocate more workspace to prefills with
             # longer context lengths
             max_context_chunk = \
-                self.chunked_prefill_workspace_size // num_prefills_with_context
+                self.context_chunk_workspace_size // num_prefills_with_context
             # align max_context_chunk to page_size by rounding down,
             # currently the `gather_cache` kernel cannot handle
@@ -965,7 +969,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
                 chunk_seq_lens.max(dim=1).values.tolist()
             context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
             assert max(context_chunk_seq_tot) <= \
-                self.chunked_prefill_workspace_size
+                self.context_chunk_workspace_size
         return self.runner.attn_backend.make_metadata(
             # Required by ModelRunner
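The builder hunks above split every prefill's cached context into equal-budget chunks so each pass over the context fits in the shared workspace; the per-chunk budget is rounded down to a page boundary (per the gather_cache comment), and the tokens gathered in any one iteration must fit in the workspace. A rough sketch of that scheme with illustrative names (the real builder also records the per-chunk totals and maxima seen above as context_chunk_seq_tot and context_chunk_max_seq_lens):

import torch

def plan_context_chunks(context_lens: torch.Tensor, workspace_rows: int,
                        page_size: int) -> list:
    num_with_context = int((context_lens > 0).sum())
    max_context_chunk = workspace_rows // num_with_context
    # round down to a multiple of the page size
    max_context_chunk = (max_context_chunk // page_size) * page_size
    iters = -(-int(context_lens.max()) // max_context_chunk)  # ceil division
    chunks = []
    for i in range(iters):
        start = i * max_context_chunk
        chunk_seq_lens = (context_lens - start).clamp(min=0,
                                                      max=max_context_chunk)
        # every iteration's total must fit in the shared workspace
        assert int(chunk_seq_lens.sum()) <= workspace_rows
        chunks.append(chunk_seq_lens)
    return chunks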
@@ -1288,8 +1292,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         # Fetch from attn_metadata directly, since it late bound by
         # MLAAttentionState, grabbing it directly `attn_metadata` can avoid
         # any weirdness around prefill_metadata caching
-        assert attn_metadata.chunked_prefill_workspace is not None
-        workspace = attn_metadata.chunked_prefill_workspace
+        assert attn_metadata.context_chunk_workspace is not None
+        workspace = attn_metadata.context_chunk_workspace
         for i in range(iters):
             toks = prefill_metadata.context_chunk_seq_tot[i]
@@ -1502,12 +1506,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                 "output is not yet supported for MLAImplBase")
         if attn_metadata.is_profile_run and \
-                attn_metadata.chunked_prefill_workspace is not None:
+                attn_metadata.context_chunk_workspace is not None:
             # During the profile run try to simulate to worse case output size
             # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
             # since this can be large
             _ = torch.empty(
-                (attn_metadata.chunked_prefill_workspace.shape[0],
+                (attn_metadata.context_chunk_workspace.shape[0],
                  self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
                 device=k_c_normed.device,
                 dtype=k_c_normed.dtype,
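The final hunk keeps the profile-run safeguard working under the new field name: a throwaway tensor shaped like the worst-case kv_b_proj output is allocated so memory profiling accounts for it. A back-of-envelope with assumed head dims, loosely based on the "192 QK head dim, 128 heads, and fp16" example in the sizing comments and not read from any real model config, shows why "this can be large":

workspace_rows = 64 * 1024          # example row count from the sizing comments
num_heads = 128                     # assumed
qk_nope_head_dim = 128              # assumed split of the 192 QK head dim
v_head_dim = 128                    # assumed
elem_bytes = 2                      # fp16
worst_case = (workspace_rows * num_heads
              * (qk_nope_head_dim + v_head_dim) * elem_bytes)
print(f"{worst_case / 2**30:.1f} GiB")  # ~4.0 GiB under these assumptions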