[Bugfix][V1] Handle MLA in kv_cache_interface (#14462)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Tyler Michael Smith authored on 2025-03-08 01:18:25 -05:00; committed by GitHub
parent ef64044079
commit 333681408f
3 changed files with 15 additions and 10 deletions

vllm/v1/kv_cache_interface.py

@@ -23,9 +23,9 @@ class KVCacheSpecBase:
     def type_id(self) -> str:
         """
         The type identifier of this KV cache.
         Return different strings for layers with different KV cache type (e.g.,
         different number of tokens like full attention vs sliding window
         attention, different KV cache size per token like layers with different
         number of heads)

         Returns:
@@ -59,6 +59,7 @@ class FullAttentionSpec(KVCacheSpecBase):
     num_kv_heads: int
     head_size: int
     dtype: torch.dtype
+    use_mla: bool

     @property
     def type_id(self) -> str:
@@ -66,7 +67,9 @@ class FullAttentionSpec(KVCacheSpecBase):

     @property
     def page_size_bytes(self) -> int:
-        return 2 * self.block_size * self.num_kv_heads * self.head_size \
+        # For MLA we only store a single latent vector
+        coef = 1 if self.use_mla else 2
+        return coef * self.block_size * self.num_kv_heads * self.head_size \
             * get_dtype_size(self.dtype)

     def bytes_for_tokens(self, num_tokens: int) -> int:
@@ -104,7 +107,7 @@ class KVCacheConfig:
     2. (not implemented yet) A model with the same number of full attention
        layers and sliding window attention layers: two groups, one for full
        attention layers and one for sliding window attention layers.
     3. (not implemented yet) A model with 2 full attention layers and 4 sliding
        window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
     """
     groups: list[list[str]]
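
As a sanity check on the new formula, here is a minimal, self-contained sketch of the updated page_size_bytes property. The concrete numbers (16-token blocks, a single KV head, a 576-wide latent vector, bfloat16) and the local get_dtype_size helper are illustrative assumptions rather than values taken from this diff:

# Minimal sketch of the page_size_bytes change; block_size from the base
# class is flattened into the same dataclass for brevity.
from dataclasses import dataclass

import torch


def get_dtype_size(dtype: torch.dtype) -> int:
    # Stand-in for vLLM's helper: bytes per element of the given dtype.
    return torch.tensor([], dtype=dtype).element_size()


@dataclass
class FullAttentionSpecSketch:
    block_size: int
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    use_mla: bool

    @property
    def page_size_bytes(self) -> int:
        # MLA caches one latent vector per token instead of separate K and V,
        # so the usual factor of 2 drops to 1.
        coef = 1 if self.use_mla else 2
        return (coef * self.block_size * self.num_kv_heads * self.head_size *
                get_dtype_size(self.dtype))


mla = FullAttentionSpecSketch(16, 1, 576, torch.bfloat16, use_mla=True)
mha = FullAttentionSpecSketch(16, 1, 576, torch.bfloat16, use_mla=False)
print(mla.page_size_bytes)  # 18432 bytes per block
print(mha.page_size_bytes)  # 36864 bytes -- the old formula's result

With use_mla=True the per-block footprint is halved, which is exactly what the coef change in the hunk above encodes.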

vllm/v1/worker/gpu_model_runner.py

@@ -1460,13 +1460,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         forward_ctx = self.vllm_config.compilation_config.static_forward_context
         block_size = self.vllm_config.cache_config.block_size
+        use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: KVCacheSpec = {}
         for layer_name, attn_module in forward_ctx.items():
             if isinstance(attn_module, FusedMoE):
                 continue

             # TODO: Support other attention modules, e.g., sliding window,
-            # cross-attention, MLA.
+            # cross-attention
             assert isinstance(attn_module, Attention)
             if attn_module.attn_type == AttentionType.DECODER:
                 kv_cache_spec[layer_name] = FullAttentionSpec(
@@ -1474,7 +1475,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
                     dtype=attn_module.dtype,
-                )
+                    use_mla=use_mla)
             elif attn_module.attn_type in (AttentionType.ENCODER,
                                            AttentionType.ENCODER_ONLY):
                 # encoder-only attention does not need KV cache.
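
The reason the flag has to reach the spec: downstream, the number of KV cache blocks is derived by dividing the memory reserved for the cache by page_size_bytes, so keeping the old factor of 2 for an MLA model would roughly halve the usable cache. A simplified sketch under assumed numbers (an 8 GiB KV budget and the per-block sizes from the example above); this is not vLLM's actual allocator code:

def num_kv_cache_blocks(available_kv_bytes: int, page_size_bytes: int) -> int:
    # Simplified: real allocation also accounts for multiple layers, etc.
    return available_kv_bytes // page_size_bytes


available = 8 * 1024**3       # assumed 8 GiB reserved for the KV cache
per_block_mla = 18_432        # coef = 1, correct for MLA
per_block_old = 36_864        # coef = 2, the pre-fix formula

print(num_kv_cache_blocks(available, per_block_mla))  # 466033 blocks
print(num_kv_cache_blocks(available, per_block_old))  # 233016 blocks, ~half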

vllm/v1/worker/tpu_model_runner.py

@@ -303,10 +303,10 @@ class TPUModelRunner:
     def get_kv_cache_spec(self) -> KVCacheSpec:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
         Attention module in the static forward context.

         Returns:
             KVCacheSpec: A dictionary mapping layer names to their KV cache
             format. Layers that do not need KV cache are not included.
         """
@@ -323,6 +323,7 @@ class TPUModelRunner:
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
                     dtype=attn_module.dtype,
+                    use_mla=False,
                 )
             elif attn_module.attn_type in (AttentionType.ENCODER,
                                            AttentionType.ENCODER_ONLY):
@@ -764,7 +765,7 @@ class TPUModelRunner:
         """
         Initialize KV cache based on `kv_cache_config`.

         Args:
             kv_cache_config: Configuration for the KV cache, including the KV
             cache size of each layer
         """
         if len(kv_cache_config.groups) > 1:
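
For reference, the groups field described in the KVCacheConfig docstring earlier is just a list of layer-name lists, and the check in the last hunk only accepts the single-group case. A small illustrative sketch with made-up layer names:

# Illustrative values for KVCacheConfig.groups (layer names are invented).

# Docstring case 1: every layer uses the same full-attention spec -> one group.
groups_case_1 = [["layer_0", "layer_1", "layer_2", "layer_3"]]

# Docstring case 3 (not implemented yet): 2 full-attention and 4 sliding-window
# layers -> three equal-sized groups: (full * 2), (sw * 2), (sw * 2).
groups_case_3 = [["full_0", "full_1"], ["sw_0", "sw_1"], ["sw_2", "sw_3"]]

# The TPU runner only handles a single group today.
assert len(groups_case_1) == 1   # supported
assert len(groups_case_3) > 1    # would hit the multi-group guard shown above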