[Bugfix][V1] Handle MLA in kv_cache_interface (#14462)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Authored by Tyler Michael Smith on 2025-03-08 01:18:25 -05:00; committed by GitHub
parent ef64044079
commit 333681408f
3 changed files with 15 additions and 10 deletions


@@ -59,6 +59,7 @@ class FullAttentionSpec(KVCacheSpecBase):
     num_kv_heads: int
     head_size: int
     dtype: torch.dtype
+    use_mla: bool
 
     @property
     def type_id(self) -> str:
@@ -66,7 +67,9 @@ class FullAttentionSpec(KVCacheSpecBase):
     @property
     def page_size_bytes(self) -> int:
-        return 2 * self.block_size * self.num_kv_heads * self.head_size \
+        # For MLA we only store a single latent vector
+        coef = 1 if self.use_mla else 2
+        return coef * self.block_size * self.num_kv_heads * self.head_size \
             * get_dtype_size(self.dtype)
 
     def bytes_for_tokens(self, num_tokens: int) -> int:
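For reference, a minimal standalone sketch of the new page-size math (shapes and the helper below are illustrative stand-ins, not taken from the PR): with use_mla the per-block factor of 2 drops to 1, because MLA caches one latent vector per token instead of separate key and value tensors.

import torch

def get_dtype_size(dtype: torch.dtype) -> int:
    # Bytes per element; a stand-in for the helper used by FullAttentionSpec.
    return torch.tensor([], dtype=dtype).element_size()

def page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                    dtype: torch.dtype, use_mla: bool) -> int:
    # MLA stores a single latent vector per token, so the K/V factor is 1.
    coef = 1 if use_mla else 2
    return coef * block_size * num_kv_heads * head_size * get_dtype_size(dtype)

# Standard attention: 16-token blocks, 8 KV heads, head_size 128, bf16.
print(page_size_bytes(16, 8, 128, torch.bfloat16, use_mla=False))  # 65536
# MLA: a single latent "head" of hypothetical width 576, bf16.
print(page_size_bytes(16, 1, 576, torch.bfloat16, use_mla=True))   # 18432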


@@ -1460,13 +1460,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         forward_ctx = self.vllm_config.compilation_config.static_forward_context
         block_size = self.vllm_config.cache_config.block_size
+        use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: KVCacheSpec = {}
         for layer_name, attn_module in forward_ctx.items():
             if isinstance(attn_module, FusedMoE):
                 continue
 
             # TODO: Support other attention modules, e.g., sliding window,
-            # cross-attention, MLA.
+            # cross-attention
             assert isinstance(attn_module, Attention)
             if attn_module.attn_type == AttentionType.DECODER:
                 kv_cache_spec[layer_name] = FullAttentionSpec(
@@ -1474,7 +1475,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
                     dtype=attn_module.dtype,
-                )
+                    use_mla=use_mla)
             elif attn_module.attn_type in (AttentionType.ENCODER,
                                            AttentionType.ENCODER_ONLY):
                 # encoder-only attention does not need KV cache.
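A rough, hypothetical sizing check of why threading the flag through matters (the budget and shapes below are invented for illustration, not from the PR): the old formula always counted separate key and value tensors, so for an MLA layer it reported twice the true page size, and any block count derived from a fixed memory budget came out roughly half of what actually fits.

kv_budget = 8 * 1024**3          # hypothetical per-layer KV-cache budget
page_old = 2 * 16 * 1 * 576 * 2  # coefficient fixed at 2: 36864 bytes/block
page_new = 1 * 16 * 1 * 576 * 2  # coefficient 1 with use_mla=True: 18432 bytes/block
print(kv_budget // page_old)     # ~233k blocks under the old formula
print(kv_budget // page_new)     # ~466k blocks, roughly twice the capacity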


@@ -323,6 +323,7 @@ class TPUModelRunner:
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
                     dtype=attn_module.dtype,
+                    use_mla=False,
                 )
             elif attn_module.attn_type in (AttentionType.ENCODER,
                                            AttentionType.ENCODER_ONLY):
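Because use_mla is added as a plain, non-defaulted field on FullAttentionSpec, every construction site now has to pass it explicitly, which is why the TPU runner pins it to False. A trimmed, hypothetical stand-in for the dataclass (not the real class) makes the constraint visible:

from dataclasses import dataclass
import torch

@dataclass
class FullAttentionSpecSketch:   # illustrative stand-in, not vLLM's class
    block_size: int
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    use_mla: bool                # no default: omitting it raises TypeError

# GPU path: flag read from the model config; TPU path: hard-coded to False.
spec = FullAttentionSpecSketch(16, 8, 128, torch.bfloat16, use_mla=False)
# FullAttentionSpecSketch(16, 8, 128, torch.bfloat16)  # TypeError: missing use_mla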