more cleanups

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson 2025-01-30 03:11:43 +00:00
parent f23d126a07
commit 3895bba85a
3 changed files with 8 additions and 36 deletions

View File

@ -7,7 +7,7 @@ import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
AttentionMetadata, AttentionType)
AttentionMetadata)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@ -159,21 +159,6 @@ class MLAImplCommon(AttentionImpl):
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"FlashInferMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferMLAImpl")
def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
return self.o_proj_absored(
@ -225,7 +210,7 @@ class MLAImplCommon(AttentionImpl):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
#
# Perform matrix-absorbtion following
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
# for decode, as a result we end up with absorbed weights for decode
# and another copy of raw weights for prefill.
@ -292,14 +277,14 @@ class MLAImplCommon(AttentionImpl):
) -> torch.Tensor:
if output is not None:
raise NotImplementedError(
"output is not yet supported for TritonMLAImpl")
"output is not yet supported for MLAImplBase")
is_decode = attn_metadata.decode_metadata is not None
is_prefill = attn_metadata.prefill_metadata is not None
if (is_decode and is_prefill):
raise NotImplementedError(
"chunked prefill is not supported for FlashInferMLAImpl")
"chunked prefill is not supported for MLAImplBase")
# Restore head dim (for rotary embedding)
k_pe = k_pe.unsqueeze(1)
@ -355,7 +340,8 @@ class MLAImplCommon(AttentionImpl):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
# For MLA the v head dim is smaller than the
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0)

View File

@ -653,7 +653,7 @@ class TritonMLAImpl(MLAImplCommon):
dtype=q.dtype,
device=q.device)
# TODO(lucas) Allocate ahead of prefill
# TODO(lucas) Allocate ahead of time
attn_logits = torch.empty(
(
B,

View File

@ -75,20 +75,6 @@ HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
PretrainedConfig]]
def _is_flashinfer_available() -> bool:
"""Check if FlashInfer is available.
Returns:
bool: True if FlashInfer is installed and available, False otherwise.
"""
try:
from flashinfer import ( # noqa:F401
BatchDecodeMlaWithPagedKVCacheWrapper)
return True
except ImportError:
return False
class SupportsHash(Protocol):
def compute_hash(self) -> str:
@ -832,7 +818,7 @@ class ModelConfig:
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU."""
if self.should_use_mla:
# TODO(simon): feature flag MLA
# When using MLA during decode it becomes MQA
return 1
total_num_kv_heads = self.get_total_num_kv_heads()