diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index e912b1e9757a5..fc5f3420e394d 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -1327,21 +1327,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                                                 [0, q.shape[-1] - v.shape[-1]],
                                                 value=0)
 
-            if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
-                attn_output, attn_softmax_lse = self.triton_fa_func(
-                    q,
-                    k,
-                    v_padded,
-                    None,
-                    prefill_metadata.query_start_loc,
-                    prefill_metadata.context_chunk_cu_seq_lens[i],
-                    prefill_metadata.max_query_len,
-                    prefill_metadata.context_chunk_max_seq_lens[i],
-                    False,  # causal
-                    self.scale,
-                    None,  # attn_mask is None unless applying ALiBi mask
-                )
-            elif is_vllm_fa:
+            if is_vllm_fa:
                 attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
                     q=q,
                     k=k,
@@ -1416,7 +1402,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         v_padded = torch.nn.functional.pad(v,
                                            [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
-        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context:
             output = self.triton_fa_func(
                 q,
                 k,
diff --git a/vllm/config.py b/vllm/config.py
index 2ee45f1837c4e..b61d1a22c8a08 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3450,9 +3450,9 @@ class VllmConfig:
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
 
         if self.model_config and self.model_config.use_mla and \
-            not current_platform.is_cuda():
+            not (current_platform.is_cuda() or current_platform.is_rocm()):
             logger.info(
-                "MLA is enabled on a non-cuda platform; forcing chunked "
+                "MLA is enabled on a non-GPU platform; forcing chunked "
                 "prefill and prefix caching to be disabled.")
            self.scheduler_config.enable_chunked_prefill = False
            self.scheduler_config.chunked_prefill_enabled = False
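
The net effect of the attention-backend hunks can be summarized with the sketch below. This is not code from the diff: pick_prefill_backend is a hypothetical helper, and its boolean parameters stand in for the is_hip, envs.VLLM_USE_TRITON_FLASH_ATTN, is_vllm_fa, and has_context values used inside MLACommonImpl.

# Minimal illustrative sketch (not part of the diff): how the MLA prefill
# attention kernel is effectively selected after this change. The function
# name and bare boolean parameters are hypothetical stand-ins for the real
# flags in MLACommonImpl._forward_prefill / _compute_prefill_context.

def pick_prefill_backend(is_hip: bool, use_triton_fa: bool,
                         is_vllm_fa: bool, has_context: bool) -> str:
    """Name the kernel used for the prefill attention call."""
    # After this diff the ROCm Triton kernel is only taken for prefills with
    # no cached context; chunked-context prefills always go through a
    # flash-attn varlen kernel, whose softmax LSE output is what the
    # chunked-context path merges across chunks.
    if is_hip and use_triton_fa and not has_context:
        return "triton_fa_func"
    if is_vllm_fa:
        return "flash_attn_varlen_func (vLLM flash-attn fork)"
    return "flash_attn_varlen_func (upstream flash-attn)"


# Example: on ROCm with VLLM_USE_TRITON_FLASH_ATTN=1, a prefill that has
# cached context now falls back to flash-attn instead of the Triton kernel.
print(pick_prefill_backend(True, True, True, has_context=True))   # flash-attn
print(pick_prefill_backend(True, True, True, has_context=False))  # triton_fa_func

The config.py hunk complements this: ROCm now counts as a GPU platform for MLA, so chunked prefill and prefix caching are no longer force-disabled there, which is exactly the situation the new has_context guard handles.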