mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-18 06:35:01 +08:00

[Bugfix][V1] Handle MLA in kv_cache_interface (#14462)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

This commit is contained in:
parent ef64044079
commit 333681408f
@@ -59,6 +59,7 @@ class FullAttentionSpec(KVCacheSpecBase):
     num_kv_heads: int
     head_size: int
     dtype: torch.dtype
+    use_mla: bool
 
     @property
     def type_id(self) -> str:
@@ -66,7 +67,9 @@ class FullAttentionSpec(KVCacheSpecBase):
 
     @property
     def page_size_bytes(self) -> int:
-        return 2 * self.block_size * self.num_kv_heads * self.head_size \
+        # For MLA we only store a single latent vector
+        coef = 1 if self.use_mla else 2
+        return coef * self.block_size * self.num_kv_heads * self.head_size \
             * get_dtype_size(self.dtype)
 
     def bytes_for_tokens(self, num_tokens: int) -> int:
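For context, a minimal runnable sketch of the patched property, outside the diff: the field names, the coefficient logic, and the get_dtype_size helper mirror the hunk above, while the FullAttentionSpecSketch class and the example numbers (bf16, block_size=16, a 576-wide MLA latent) are illustrative assumptions, not taken from the commit.

import torch
from dataclasses import dataclass

def get_dtype_size(dtype: torch.dtype) -> int:
    # Stand-in for vLLM's helper of the same name.
    return torch.tensor([], dtype=dtype).element_size()

@dataclass
class FullAttentionSpecSketch:
    # Mirrors the fields touched by this commit.
    block_size: int
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    use_mla: bool

    @property
    def page_size_bytes(self) -> int:
        # With MLA only a single latent vector is stored per token,
        # so one buffer is needed instead of separate K and V.
        coef = 1 if self.use_mla else 2
        return (coef * self.block_size * self.num_kv_heads *
                self.head_size * get_dtype_size(self.dtype))

mha = FullAttentionSpecSketch(16, 8, 128, torch.bfloat16, use_mla=False)
mla = FullAttentionSpecSketch(16, 1, 576, torch.bfloat16, use_mla=True)
print(mha.page_size_bytes)  # 2 * 16 * 8 * 128 * 2 = 65536
print(mla.page_size_bytes)  # 1 * 16 * 1 * 576 * 2 = 18432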
@@ -1460,13 +1460,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         forward_ctx = self.vllm_config.compilation_config.static_forward_context
         block_size = self.vllm_config.cache_config.block_size
+        use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: KVCacheSpec = {}
         for layer_name, attn_module in forward_ctx.items():
             if isinstance(attn_module, FusedMoE):
                 continue
 
             # TODO: Support other attention modules, e.g., sliding window,
-            # cross-attention, MLA.
+            # cross-attention
             assert isinstance(attn_module, Attention)
             if attn_module.attn_type == AttentionType.DECODER:
                 kv_cache_spec[layer_name] = FullAttentionSpec(
@@ -1474,7 +1475,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
                     dtype=attn_module.dtype,
-                )
+                    use_mla=use_mla)
             elif attn_module.attn_type in (AttentionType.ENCODER,
                                            AttentionType.ENCODER_ONLY):
                 # encoder-only attention does not need KV cache.
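A back-of-the-envelope sketch of why threading use_mla into the spec matters: before this commit the spec always used the factor 2 (separate K and V buffers), so an MLA layer's per-block footprint was counted at twice its actual size, which shrinks the number of KV-cache blocks that fit in a given budget. All figures below are hypothetical; only the 1-vs-2 coefficient comes from the diff.

block_size, num_kv_heads, head_size, dtype_size = 16, 1, 576, 2  # assumed bf16 MLA-style layer

old_page_bytes = 2 * block_size * num_kv_heads * head_size * dtype_size  # pre-fix accounting
new_page_bytes = 1 * block_size * num_kv_heads * head_size * dtype_size  # use_mla=True

budget_per_layer = 256 * 1024 * 1024  # assumed per-layer KV-cache budget in bytes
print(budget_per_layer // old_page_bytes)  # 7281 blocks
print(budget_per_layer // new_page_bytes)  # 14563 blocks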
@@ -323,6 +323,7 @@ class TPUModelRunner:
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
                     dtype=attn_module.dtype,
+                    use_mla=False,
                 )
             elif attn_module.attn_type in (AttentionType.ENCODER,
                                            AttentionType.ENCODER_ONLY):
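One note on the TPU hunk, as a sketch: assuming FullAttentionSpec is a dataclass-style spec (as the first hunk suggests), making use_mla a required field means every construction site must now supply it, which is presumably why the TPU runner passes an explicit use_mla=False. A toy illustration with a stand-in class:

import torch
from dataclasses import dataclass

@dataclass
class SpecSketch:  # stand-in, not the real FullAttentionSpec
    block_size: int
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    use_mla: bool

try:
    SpecSketch(block_size=16, num_kv_heads=8, head_size=128,
               dtype=torch.bfloat16)
except TypeError as e:
    print(e)  # missing required argument: 'use_mla'

spec = SpecSketch(block_size=16, num_kv_heads=8, head_size=128,
                  dtype=torch.bfloat16, use_mla=False)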