[Bugfix][V1] Handle MLA in kv_cache_interface (#14462)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Tyler Michael Smith 2025-03-08 01:18:25 -05:00 committed by GitHub
parent ef64044079
commit 333681408f
3 changed files with 15 additions and 10 deletions

View File

@@ -23,9 +23,9 @@ class KVCacheSpecBase:
def type_id(self) -> str:
"""
The type identifier of this KV cache.
Return different strings for layers with different KV cache type (e.g.,
different number of tokens like full attention vs sliding window
attention, different KV cache size per token like layers with different
number of heads)
Returns:
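As a note on the contract this docstring describes: the identifier only has to distinguish layers whose caches cannot share a layout. A hypothetical, stand-alone sketch of such an encoding (illustrative only, not quoted from the patch) could fold in the layout-affecting fields; if it includes the page size, MLA layers, whose pages are smaller, naturally get a distinct type_id:

    def full_attention_type_id(block_size: int, page_size_bytes: int) -> str:
        # Illustrative encoding: fold every field that affects the cache
        # layout into the identifier so unlike layers never share a pool.
        return f"full_attention_{block_size}_{page_size_bytes}"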
@@ -59,6 +59,7 @@ class FullAttentionSpec(KVCacheSpecBase):
num_kv_heads: int
head_size: int
dtype: torch.dtype
+ use_mla: bool
@property
def type_id(self) -> str:
@@ -66,7 +67,9 @@ class FullAttentionSpec(KVCacheSpecBase):
@property
def page_size_bytes(self) -> int:
- return 2 * self.block_size * self.num_kv_heads * self.head_size \
+ # For MLA we only store a single latent vector
+ coef = 1 if self.use_mla else 2
+ return coef * self.block_size * self.num_kv_heads * self.head_size \
* get_dtype_size(self.dtype)
def bytes_for_tokens(self, num_tokens: int) -> int:
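The hunk above is the core of the fix: MLA caches a single latent vector per token instead of separate K and V planes, so the per-block byte count halves. A stand-alone sketch of that arithmetic, with a local stand-in for vLLM's dtype-size helper and purely illustrative shapes (not taken from the patch):

    import torch

    def page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                        dtype: torch.dtype, use_mla: bool) -> int:
        # Regular attention stores both K and V (factor 2); MLA caches one
        # latent vector per token, so the factor drops to 1.
        coef = 1 if use_mla else 2
        dtype_size = torch.tensor([], dtype=dtype).element_size()
        return coef * block_size * num_kv_heads * head_size * dtype_size

    # Illustrative shapes: a 16-token block in bfloat16.
    page_size_bytes(16, 8, 128, torch.bfloat16, use_mla=False)  # 65536
    page_size_bytes(16, 1, 576, torch.bfloat16, use_mla=True)   # 18432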
@@ -104,7 +107,7 @@ class KVCacheConfig:
2. (not implemented yet) A model with the same number of full attention
layers and sliding window attention layers: two groups, one for full
attention layers and one for sliding window attention layers.
3. (not implemented yet) A model with 2 full attention layers and 4 sliding
window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
"""
groups: list[list[str]]
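To make the third example in the docstring concrete, a hypothetical value of `groups` (the layer names are invented for illustration) would be three equally sized lists:

    # 2 full-attention layers and 4 sliding-window layers ->
    # (full * 2), (sw * 2), (sw * 2)
    groups = [
        ["layers.0.attn", "layers.3.attn"],  # full attention
        ["layers.1.attn", "layers.2.attn"],  # sliding window
        ["layers.4.attn", "layers.5.attn"],  # sliding window
    ]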

View File

@@ -1460,13 +1460,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
forward_ctx = self.vllm_config.compilation_config.static_forward_context
block_size = self.vllm_config.cache_config.block_size
+ use_mla = self.vllm_config.model_config.use_mla
kv_cache_spec: KVCacheSpec = {}
for layer_name, attn_module in forward_ctx.items():
if isinstance(attn_module, FusedMoE):
continue
# TODO: Support other attention modules, e.g., sliding window,
- # cross-attention, MLA.
+ # cross-attention
assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
@@ -1474,7 +1475,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
- )
+ use_mla=use_mla)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.
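Condensed, the pattern this hunk implements is: read `use_mla` once from the model config and thread it into every decoder layer's spec. The sketch below uses a hypothetical helper and a stand-in dataclass rather than the real vLLM classes, purely to illustrate the shape of the code:

    from dataclasses import dataclass
    import torch

    @dataclass
    class FullAttentionSpec:  # stand-in mirroring the fields in the hunk above
        block_size: int
        num_kv_heads: int
        head_size: int
        dtype: torch.dtype
        use_mla: bool

    def build_kv_cache_spec(attn_layers: dict, block_size: int,
                            use_mla: bool) -> dict[str, FullAttentionSpec]:
        # `attn_layers` stands in for the static forward context: a mapping
        # of layer name -> attention module exposing num_kv_heads, head_size
        # and dtype.
        spec: dict[str, FullAttentionSpec] = {}
        for name, attn in attn_layers.items():
            spec[name] = FullAttentionSpec(
                block_size=block_size,
                num_kv_heads=attn.num_kv_heads,
                head_size=attn.head_size,
                dtype=attn.dtype,
                use_mla=use_mla,  # the flag this commit threads through
            )
        return spec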

View File

@@ -303,10 +303,10 @@ class TPUModelRunner:
def get_kv_cache_spec(self) -> KVCacheSpec:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""
@@ -323,6 +323,7 @@
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
+ use_mla=False,
)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
@@ -764,7 +765,7 @@ class TPUModelRunner:
"""
Initialize KV cache based on `kv_cache_config`.
Args:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if len(kv_cache_config.groups) > 1: