diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index 4dd562be3838..225fee8d2a0d 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -232,6 +232,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.multimodal import MultiModalPlaceholderMap
+from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
 
 try:
@@ -1371,18 +1372,35 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
 
-        output = self.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v_padded,
-            cu_seqlens_q=prefill_metadata.query_start_loc,
-            cu_seqlens_k=prefill_metadata.query_start_loc,
-            max_seqlen_q=prefill_metadata.max_prefill_seq_len,
-            max_seqlen_k=prefill_metadata.max_prefill_seq_len,
-            softmax_scale=self.scale,
-            causal=True,
-            return_softmax_lse=has_context,
-        )
+        if has_context:
+            if not current_platform.is_cuda():
+                raise NotImplementedError(
+                    "Chunked Prefill for MLA is not currently supported on "
+                    "non-cuda platforms")
+            output = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v_padded,
+                cu_seqlens_q=prefill_metadata.query_start_loc,
+                cu_seqlens_k=prefill_metadata.query_start_loc,
+                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
+                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
+                softmax_scale=self.scale,
+                causal=True,
+                return_softmax_lse=True,
+            )
+        else:
+            output = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v_padded,
+                cu_seqlens_q=prefill_metadata.query_start_loc,
+                cu_seqlens_k=prefill_metadata.query_start_loc,
+                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
+                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
+                softmax_scale=self.scale,
+                causal=True,
+            )
 
         if has_context:
             suffix_output, suffix_lse = output
diff --git a/vllm/config.py b/vllm/config.py
index 8e1ce87438af..a5d8ee9303d0 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3422,6 +3422,20 @@ class VllmConfig:
                 "Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
 
+        if self.model_config and self.model_config.use_mla and \
+            not current_platform.is_cuda():
+            logger.info(
+                "MLA is enabled on a non-cuda platform; forcing chunked "
+                "prefill and prefix caching to be disabled.")
+            self.scheduler_config.enable_chunked_prefill = False
+            self.scheduler_config.chunked_prefill_enabled = False
+            self.scheduler_config.max_num_batched_tokens = max(
+                self.scheduler_config.max_model_len,
+                _DEFAULT_MAX_NUM_BATCHED_TOKENS)
+
+            if self.cache_config is not None:
+                self.cache_config.enable_prefix_caching = False
+
         current_platform.check_and_update_config(self)
 
         if not self.instance_id:
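
Illustrative sketch (not part of the patch): the vllm/config.py hunk above turns chunked prefill and prefix caching off whenever MLA is used on a non-CUDA platform, and raises the token budget so a full prompt still fits in one batch. The standalone snippet below mirrors that gate with simplified stand-ins; SchedulerConfig, CacheConfig, DEFAULT_MAX_NUM_BATCHED_TOKENS and disable_mla_chunked_prefill are hypothetical placeholders for the real vLLM config objects, not their actual definitions.

from dataclasses import dataclass
from typing import Optional


@dataclass
class SchedulerConfig:
    # Simplified stand-in for vllm.config.SchedulerConfig.
    enable_chunked_prefill: bool = True
    chunked_prefill_enabled: bool = True
    max_num_batched_tokens: int = 2048
    max_model_len: int = 16384


@dataclass
class CacheConfig:
    # Simplified stand-in for vllm.config.CacheConfig.
    enable_prefix_caching: bool = True


# Stand-in for _DEFAULT_MAX_NUM_BATCHED_TOKENS in vllm/config.py.
DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048


def disable_mla_chunked_prefill(scheduler: SchedulerConfig,
                                cache: Optional[CacheConfig],
                                use_mla: bool,
                                is_cuda: bool) -> None:
    """Mirror of the gate added to VllmConfig.__post_init__ (simplified)."""
    if use_mla and not is_cuda:
        scheduler.enable_chunked_prefill = False
        scheduler.chunked_prefill_enabled = False
        # With chunked prefill off, a whole prompt must be scheduled at once,
        # so the token budget is raised to at least the model length.
        scheduler.max_num_batched_tokens = max(
            scheduler.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
        if cache is not None:
            cache.enable_prefix_caching = False


# Example: MLA on a non-CUDA platform disables both features.
sched = SchedulerConfig()
disable_mla_chunked_prefill(sched, CacheConfig(), use_mla=True, is_cuda=False)
assert not sched.enable_chunked_prefill
assert sched.max_num_batched_tokens == sched.max_model_len

Together with the attention-backend hunk, which raises NotImplementedError when a context (chunked) prefill is requested off CUDA, this keeps non-CUDA MLA on the plain prefill path where flash_attn_varlen_func is called without return_softmax_lse.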