diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py
index 6ad41ceee23e8..f650a854e4d50 100644
--- a/vllm/attention/backends/mla/utils.py
+++ b/vllm/attention/backends/mla/utils.py
@@ -7,7 +7,7 @@ import torch
 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata)
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -159,21 +159,6 @@ class MLAImplCommon(AttentionImpl):
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
 
-        unsupported_features = [
-            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
-        ]
-        if any(unsupported_features):
-            raise NotImplementedError(
-                "FlashInferMLAImpl does not support one of the following: "
-                "alibi_slopes, sliding_window, blocksparse_params, "
-                "logits_soft_cap")
-
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "FlashInferMLAImpl")
-
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             return self.o_proj_absored(
@@ -225,7 +210,7 @@ class MLAImplCommon(AttentionImpl):
 
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             #
-            # Perform matrix-absorbtion following
+            # Perform matrix-absorption following
             # https://github.com/flashinfer-ai/flashinfer/pull/551
             # for decode, as a result we end up with absorbed weights for decode
             # and another copy of raw weights for prefill.
@@ -292,14 +277,14 @@ class MLAImplCommon(AttentionImpl):
     ) -> torch.Tensor:
         if output is not None:
             raise NotImplementedError(
-                "output is not yet supported for TritonMLAImpl")
+                "output is not yet supported for MLAImplBase")
 
         is_decode = attn_metadata.decode_metadata is not None
         is_prefill = attn_metadata.prefill_metadata is not None
 
         if (is_decode and is_prefill):
             raise NotImplementedError(
-                "chunked prefill is not supported for FlashInferMLAImpl")
+                "chunked prefill is not supported for MLAImplBase")
 
         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
@@ -355,7 +340,8 @@ class MLAImplCommon(AttentionImpl):
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                       dim=-1)
 
-        # For MLA the v head dim is smaller than the
+        # For MLA the v head dim is smaller than qk head dim so we pad out
+        # v with 0s to match the qk head dim
         v_padded = torch.nn.functional.pad(v,
                                            [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py
index 43f5caf338b1f..cf9151cd2b30a 100644
--- a/vllm/attention/backends/triton_mla.py
+++ b/vllm/attention/backends/triton_mla.py
@@ -653,7 +653,7 @@ class TritonMLAImpl(MLAImplCommon):
                         dtype=q.dtype,
                         device=q.device)
 
-        # TODO(lucas) Allocate ahead of prefill
+        # TODO(lucas) Allocate ahead of time
        attn_logits = torch.empty(
             (
                 B,
diff --git a/vllm/config.py b/vllm/config.py
index f22dde291ae1b..7cdc7b1adf9ab 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -75,20 +75,6 @@ HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
                                              PretrainedConfig]]
 
 
-def _is_flashinfer_available() -> bool:
-    """Check if FlashInfer is available.
-
-    Returns:
-        bool: True if FlashInfer is installed and available, False otherwise.
-    """
-    try:
-        from flashinfer import (  # noqa:F401
-            BatchDecodeMlaWithPagedKVCacheWrapper)
-        return True
-    except ImportError:
-        return False
-
-
 class SupportsHash(Protocol):
 
     def compute_hash(self) -> str:
@@ -832,7 +818,7 @@ class ModelConfig:
     def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         """Returns the number of KV heads per GPU."""
         if self.should_use_mla:
-            # TODO(simon): feature flag MLA
+            # When using MLA during decode it becomes MQA
            return 1
 
         total_num_kv_heads = self.get_total_num_kv_heads()