Remove chunked_prefill_enabled flag in V1 MLA (#23183)

Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Author: Matthew Bonanni, 2025-08-20 17:43:17 -04:00 (committed by GitHub)
parent 1b125004be
commit a4fbb32fab

@@ -416,7 +416,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         self.model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         parallel_config = vllm_config.parallel_config
-        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
         self.num_heads = self.model_config.get_num_attention_heads(
             parallel_config)
         self.mla_dims = get_mla_dims(self.model_config)
@@ -426,30 +425,28 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         if self.aot_schedule:
             self.page_size = self.kv_cache_spec.block_size
 
-        if self.chunked_prefill_enabled:
-            self.chunked_prefill_workspace_size = min(
-                # Max sure there is enough for 8 full length request or at least
-                # 4 pages of cache per request
-                max(
-                    8 * self.model_config.max_model_len, 4 *
-                    scheduler_config.max_num_seqs * cache_config.block_size),
-                # For long-context models try not to over-allocate limiting
-                # kv-cache space, limiting it to 64k tokens,
-                # which would result in the workspace being:
-                # 2*(576)*(64*1024) = 144mb
-                # (assuming 576 MLA head dim, and fp16)
-                # which would result in up-projected context being
-                # 2*(192*128)*(64*1024) = 3gb
-                # (assuming 192 QK head dim, 128 heads, and fp16)
-                128 * 1024)
-            assert self.chunked_prefill_workspace_size >= \
-                scheduler_config.max_num_seqs * cache_config.block_size
-            self.chunked_prefill_workspace = torch.empty(
-                (self.chunked_prefill_workspace_size,
-                 self.model_config.get_head_size()),
-                dtype=self.model_config.dtype,
-                device=device,
-            )
+        self.chunked_prefill_workspace_size = min(
+            # Max sure there is enough for 8 full length request or at least
+            # 4 pages of cache per request
+            max(8 * self.model_config.max_model_len,
+                4 * scheduler_config.max_num_seqs * cache_config.block_size),
+            # For long-context models try not to over-allocate limiting
+            # kv-cache space, limiting it to 64k tokens,
+            # which would result in the workspace being:
+            # 2*(576)*(64*1024) = 144mb
+            # (assuming 576 MLA head dim, and fp16)
+            # which would result in up-projected context being
+            # 2*(192*128)*(64*1024) = 3gb
+            # (assuming 192 QK head dim, 128 heads, and fp16)
+            128 * 1024)
+        assert self.chunked_prefill_workspace_size >= \
+            scheduler_config.max_num_seqs * cache_config.block_size
+        self.chunked_prefill_workspace = torch.empty(
+            (self.chunked_prefill_workspace_size,
+             self.model_config.get_head_size()),
+            dtype=self.model_config.dtype,
+            device=device,
+        )
 
         self._use_cudnn_prefill = use_cudnn_prefill()
         self._use_fi_prefill = use_flashinfer_prefill()
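
For reference, here is a minimal standalone sketch of the workspace sizing that the hunk above now computes unconditionally. Only the min/max formula, the 128 * 1024 token cap, and the 576-dim fp16 figure are taken from the diff and its in-code comments; the function name and the example config values below are illustrative assumptions, not vLLM APIs.

# A minimal sketch (not vLLM code) of the chunked-prefill workspace sizing.
def chunked_prefill_workspace_tokens(max_model_len: int,
                                     max_num_seqs: int,
                                     block_size: int) -> int:
    # Enough room for 8 full-length requests, or at least 4 KV-cache
    # pages per request, capped for long-context models.
    return min(max(8 * max_model_len, 4 * max_num_seqs * block_size),
               128 * 1024)

if __name__ == "__main__":
    # Hypothetical config: 4k context, 256 max sequences, 16-token blocks.
    tokens = chunked_prefill_workspace_tokens(max_model_len=4096,
                                              max_num_seqs=256,
                                              block_size=16)
    # 8 * 4096 = 32768 tokens wins over 4 * 256 * 16 = 16384 and stays
    # under the 128 * 1024 cap. At the cap, an fp16 workspace with a
    # 576-dim MLA head would be 2 * 576 * 128 * 1024 bytes = 144 MiB,
    # the figure quoted in the in-code comment.
    print(tokens, 2 * 576 * tokens)  # 32768 tokens, ~36 MiB of workspace
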
@@ -620,8 +617,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 reqs_start:] - query_start_loc[reqs_start]
 
             chunked_context_metadata = None
-            if self.chunked_prefill_enabled and num_prefills > 0 \
-                    and max_context_len_cpu > 0:
+            if max_context_len_cpu > 0:
                 # NOTE: it is recommend you read the `Chunked Prefill` section
                 # in the comment at the top of the file before trying to
                 # understand the following code
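
Finally, a sketch of the behavioral change in the last hunk, using the predicate names from the diff (an illustration, not the vLLM implementation): chunked-context metadata is now built whenever there is cached context, without consulting a chunked-prefill flag.

# Illustrative sketch of the gating change (not vLLM code).
def builds_chunked_context_before(chunked_prefill_enabled: bool,
                                  num_prefills: int,
                                  max_context_len_cpu: int) -> bool:
    return (chunked_prefill_enabled and num_prefills > 0
            and max_context_len_cpu > 0)

def builds_chunked_context_after(max_context_len_cpu: int) -> bool:
    # The flag and the num_prefills check are dropped; presumably the
    # surrounding code path already restricts this to prefill requests,
    # so only the cached context length matters.
    return max_context_len_cpu > 0
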