diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index f2610671f769e..646e4fec836bd 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -416,7 +416,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         self.model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         parallel_config = vllm_config.parallel_config
-        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
         self.num_heads = self.model_config.get_num_attention_heads(
             parallel_config)
         self.mla_dims = get_mla_dims(self.model_config)
@@ -426,30 +425,28 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         if self.aot_schedule:
             self.page_size = self.kv_cache_spec.block_size
 
-        if self.chunked_prefill_enabled:
-            self.chunked_prefill_workspace_size = min(
-                # Max sure there is enough for 8 full length request or at least
-                # 4 pages of cache per request
-                max(
-                    8 * self.model_config.max_model_len, 4 *
-                    scheduler_config.max_num_seqs * cache_config.block_size),
-                # For long-context models try not to over-allocate limiting
-                # kv-cache space, limiting it to 64k tokens,
-                # which would result in the workspace being:
-                # 2*(576)*(64*1024) = 144mb
-                # (assuming 576 MLA head dim, and fp16)
-                # which would result in up-projected context being
-                # 2*(192*128)*(64*1024) = 3gb
-                # (assuming 192 QK head dim, 128 heads, and fp16)
-                128 * 1024)
-            assert self.chunked_prefill_workspace_size >= \
-                scheduler_config.max_num_seqs * cache_config.block_size
-            self.chunked_prefill_workspace = torch.empty(
-                (self.chunked_prefill_workspace_size,
-                 self.model_config.get_head_size()),
-                dtype=self.model_config.dtype,
-                device=device,
-            )
+        self.chunked_prefill_workspace_size = min(
+            # Make sure there is enough for 8 full length requests or at least
+            # 4 pages of cache per request
+            max(8 * self.model_config.max_model_len,
+                4 * scheduler_config.max_num_seqs * cache_config.block_size),
+            # For long-context models try not to over-allocate limiting
+            # kv-cache space, limiting it to 64k tokens,
+            # which would result in the workspace being:
+            # 2*(576)*(64*1024) = 144mb
+            # (assuming 576 MLA head dim, and fp16)
+            # which would result in up-projected context being
+            # 2*(192*128)*(64*1024) = 3gb
+            # (assuming 192 QK head dim, 128 heads, and fp16)
+            128 * 1024)
+        assert self.chunked_prefill_workspace_size >= \
+            scheduler_config.max_num_seqs * cache_config.block_size
+        self.chunked_prefill_workspace = torch.empty(
+            (self.chunked_prefill_workspace_size,
+             self.model_config.get_head_size()),
+            dtype=self.model_config.dtype,
+            device=device,
+        )
 
         self._use_cudnn_prefill = use_cudnn_prefill()
         self._use_fi_prefill = use_flashinfer_prefill()
@@ -620,8 +617,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 reqs_start:] - query_start_loc[reqs_start]
 
             chunked_context_metadata = None
-            if self.chunked_prefill_enabled and num_prefills > 0 \
-                    and max_context_len_cpu > 0:
+            if max_context_len_cpu > 0:
                 # NOTE: it is recommend you read the `Chunked Prefill` section
                 # in the comment at the top of the file before trying to
                 # understand the following code
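For reference, the sizing expression that the patch now evaluates unconditionally can be summarized as a small standalone sketch. This is illustrative only and not part of the diff: the function name and the example values for max_model_len, max_num_seqs, and block_size are hypothetical, and the byte math assumes fp16 (2 bytes per element) with the 576-dim MLA latent head mentioned in the comment.

# Illustrative sketch of the workspace sizing now done unconditionally in
# MLACommonMetadataBuilder.__init__ (function name and values are
# hypothetical examples, not taken from the patch).
def chunked_prefill_workspace_tokens(max_model_len: int,
                                     max_num_seqs: int,
                                     block_size: int) -> int:
    # Enough room for 8 full-length requests, or at least 4 KV-cache pages
    # per request, capped at 128k tokens for long-context models.
    return min(max(8 * max_model_len, 4 * max_num_seqs * block_size),
               128 * 1024)


# Long-context example: the 128k-token cap dominates.
tokens = chunked_prefill_workspace_tokens(max_model_len=163_840,
                                          max_num_seqs=256,
                                          block_size=16)
assert tokens == 128 * 1024

# Workspace footprint at the cap: 2 bytes (fp16) * 576 (MLA head dim)
# * 131072 tokens = 144 MiB.
workspace_bytes = 2 * 576 * tokens
assert workspace_bytes == 144 * 1024 * 1024

Because the workspace is now allocated regardless of whether chunked prefill is enabled, the only remaining gate in build() is max_context_len_cpu > 0, i.e. whether any prefill request actually has cached context to gather.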