commit 3895bba85a
parent f23d126a07

more cleanups

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
@@ -7,7 +7,7 @@ import torch
 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata)
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -159,21 +159,6 @@ class MLAImplCommon(AttentionImpl):
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
 
-        unsupported_features = [
-            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
-        ]
-        if any(unsupported_features):
-            raise NotImplementedError(
-                "FlashInferMLAImpl does not support one of the following: "
-                "alibi_slopes, sliding_window, blocksparse_params, "
-                "logits_soft_cap")
-
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "FlashInferMLAImpl")
-
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             return self.o_proj_absored(
@@ -225,7 +210,7 @@ class MLAImplCommon(AttentionImpl):
 
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             #
-            # Perform matrix-absorbtion following
+            # Perform matrix-absorption following
             #  https://github.com/flashinfer-ai/flashinfer/pull/551
             # for decode, as a result we end up with absorbed weights for decode
             # and another copy of raw weights for prefill.
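
Note on the matrix absorption the comment above refers to: at decode time MLA can fold the k up-projection into the query (and, symmetrically, the v up-projection into the output projection), so attention runs directly over the compressed latent cache instead of up-projecting every cached entry. A minimal self-contained sketch of the query-side identity; the shapes and names (W_UK, c_kv) follow DeepSeek-V2 paper notation and are illustrative, not vllm's actual code:

import torch

# Hypothetical sizes, for illustration only
num_heads, qk_nope_dim, kv_lora_rank, seq_len = 4, 16, 8, 32

q = torch.randn(num_heads, qk_nope_dim)                   # per-token query
W_UK = torch.randn(num_heads, qk_nope_dim, kv_lora_rank)  # k up-projection
c_kv = torch.randn(seq_len, kv_lora_rank)                 # compressed KV cache

# Naive: up-project every cached latent to a full key, then dot with q.
k = torch.einsum("sr,hdr->hsd", c_kv, W_UK)
scores_naive = torch.einsum("hd,hsd->hs", q, k)

# Absorbed: fold W_UK into the query once, then attend directly over the
# compressed latents; no per-token up-projection of the cache.
q_absorbed = torch.einsum("hd,hdr->hr", q, W_UK)
scores_absorbed = torch.einsum("hr,sr->hs", q_absorbed, c_kv)

torch.testing.assert_close(scores_naive, scores_absorbed, rtol=1e-4, atol=1e-4)

The value side uses the same identity in reverse, which is what _v_up_proj_and_o_proj and o_proj_absored exploit when VLLM_MLA_PERFORM_MATRIX_ABSORPTION is set.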
@@ -292,14 +277,14 @@ class MLAImplCommon(AttentionImpl):
     ) -> torch.Tensor:
         if output is not None:
             raise NotImplementedError(
-                "output is not yet supported for TritonMLAImpl")
+                "output is not yet supported for MLAImplBase")
 
         is_decode = attn_metadata.decode_metadata is not None
         is_prefill = attn_metadata.prefill_metadata is not None
 
         if (is_decode and is_prefill):
             raise NotImplementedError(
-                "chunked prefill is not supported for FlashInferMLAImpl")
+                "chunked prefill is not supported for MLAImplBase")
 
         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
@@ -355,7 +340,8 @@ class MLAImplCommon(AttentionImpl):
 
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
-        # For MLA the v head dim is smaller than the
+        # For MLA the v head dim is smaller than qk head dim so we pad out
+        # v with 0s to match the qk head dim
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
 
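
Why the padding in the hunk above: MLA's v head dim is smaller than the qk head dim, and the attention kernels here assume a single head dim for q, k, and v. Zero-padding v up to the qk head dim lets it flow through the same kernel, and the zeros contribute nothing to the weighted sum. A small sketch with illustrative dims (192/128 are DeepSeek-style values, not taken from this diff):

import torch
import torch.nn.functional as F

qk_head_dim, v_head_dim = 192, 128
v = torch.randn(10, 4, v_head_dim)  # [tokens, heads, v_head_dim]

# Pad the last dim on the right with zeros up to the qk head dim.
v_padded = F.pad(v, [0, qk_head_dim - v_head_dim], value=0)
assert v_padded.shape[-1] == qk_head_dim

# After attention, slicing the pad back off recovers v exactly.
torch.testing.assert_close(v_padded[..., :v_head_dim], v)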
@@ -653,7 +653,7 @@ class TritonMLAImpl(MLAImplCommon):
             dtype=q.dtype,
             device=q.device)
 
-        # TODO(lucas) Allocate ahead of prefill
+        # TODO(lucas) Allocate ahead of time
         attn_logits = torch.empty(
             (
                 B,
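
The TODO above names a common optimization: attn_logits is scratch space allocated on every decode call, and allocating it once up front (sized for the maximum batch) then slicing per call keeps allocations out of the hot path. A minimal sketch of that pattern; the class name and buffer shape are hypothetical, not vllm's implementation:

import torch

class DecodeScratch:
    """Allocate decode scratch buffers once, hand out views per call."""

    def __init__(self, max_batch: int, num_heads: int, num_splits: int,
                 head_dim: int, device: str = "cpu"):
        self._attn_logits = torch.empty(
            (max_batch, num_heads, num_splits, head_dim + 1), device=device)

    def attn_logits(self, batch_size: int) -> torch.Tensor:
        # A view into the preallocated buffer; no allocation per step.
        return self._attn_logits[:batch_size]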
@@ -75,20 +75,6 @@ HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
                                                  PretrainedConfig]]
 
 
-def _is_flashinfer_available() -> bool:
-    """Check if FlashInfer is available.
-
-    Returns:
-        bool: True if FlashInfer is installed and available, False otherwise.
-    """
-    try:
-        from flashinfer import (  # noqa: F401
-            BatchDecodeMlaWithPagedKVCacheWrapper)
-        return True
-    except ImportError:
-        return False
-
-
 class SupportsHash(Protocol):
 
     def compute_hash(self) -> str:
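
The deleted _is_flashinfer_available is the usual optional-dependency probe: import something, return False on ImportError. Importing a concrete symbol rather than the bare package also verifies that the installed FlashInfer is new enough to expose BatchDecodeMlaWithPagedKVCacheWrapper. A generic sketch of the pattern, with a hypothetical helper name:

from functools import cache

@cache
def _has_symbol(module: str, symbol: str) -> bool:
    """Probe an optional dependency for a specific symbol."""
    try:
        mod = __import__(module, fromlist=[symbol])
        return hasattr(mod, symbol)
    except ImportError:
        return False

# Roughly what the deleted helper checked:
# _has_symbol("flashinfer", "BatchDecodeMlaWithPagedKVCacheWrapper")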
@@ -832,7 +818,7 @@ class ModelConfig:
     def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         """Returns the number of KV heads per GPU."""
         if self.should_use_mla:
-            # TODO(simon): feature flag MLA
+            # When using MLA during decode it becomes MQA
             return 1
 
         total_num_kv_heads = self.get_total_num_kv_heads()
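
The new comment explains the early return: with MLA, every query head attends over the same per-token compressed latent, so the KV cache holds a single logical KV head, the MQA layout, regardless of tensor parallelism. A back-of-the-envelope per-token cache footprint with DeepSeek-V2-like numbers (illustrative, not read from this diff):

# Per-token KV cache elements: MHA layout vs MLA's MQA-like layout.
num_kv_heads, head_dim = 128, 128  # hypothetical MHA config
kv_lora_rank, rope_dim = 512, 64   # DeepSeek-V2-style MLA dims

mha_elems = 2 * num_kv_heads * head_dim  # full K and V per head
mla_elems = kv_lora_rank + rope_dim      # one shared latent + rope key

print(mha_elems, mla_elems)  # 32768 vs 576 elements per token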