[Bugfix][V1] Handle MLA in kv_cache_interface (#14462)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
parent ef64044079
commit 333681408f
@@ -23,9 +23,9 @@ class KVCacheSpecBase:
     def type_id(self) -> str:
         """
         The type identifier of this KV cache.
         Return different strings for layers with different KV cache type (e.g.,
         different number of tokens like full attention vs sliding window
         attention, different KV cache size per token like layers with different
         number of heads)

         Returns:
@@ -59,6 +59,7 @@ class FullAttentionSpec(KVCacheSpecBase):
     num_kv_heads: int
     head_size: int
     dtype: torch.dtype
+    use_mla: bool

     @property
     def type_id(self) -> str:
@@ -66,7 +67,9 @@ class FullAttentionSpec(KVCacheSpecBase):

     @property
     def page_size_bytes(self) -> int:
-        return 2 * self.block_size * self.num_kv_heads * self.head_size \
+        # For MLA we only store a single latent vector
+        coef = 1 if self.use_mla else 2
+        return coef * self.block_size * self.num_kv_heads * self.head_size \
             * get_dtype_size(self.dtype)

     def bytes_for_tokens(self, num_tokens: int) -> int:
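Why the coefficient changes: with standard attention every cached token stores both a key and a value tensor per KV head, hence the old hard-coded factor of 2, while MLA caches a single compressed latent vector per token, so the factor drops to 1. A minimal sketch of what the new sizing yields for an MLA layer, assuming the module path vllm/v1/kv_cache_interface.py and illustrative DeepSeek-style dimensions (16-token blocks, one latent "KV head" of width 576, bf16); the constructor fields follow the diff above:

import torch

from vllm.v1.kv_cache_interface import FullAttentionSpec

# Illustrative MLA decoder layer: the single "KV head" holds the latent vector.
mla_spec = FullAttentionSpec(block_size=16,
                             num_kv_heads=1,
                             head_size=576,
                             dtype=torch.bfloat16,
                             use_mla=True)

# coef = 1  ->  1 * 16 * 1 * 576 * 2 bytes = 18432 bytes per block.
# The old code reported twice that (36864 bytes) for the same layer.
print(mla_spec.page_size_bytes)

If a concrete spec folds page_size_bytes into its type_id string, an MLA layer and an otherwise identical dense layer also get different type identifiers, so they are never treated as sharing the same KV cache format.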
@@ -104,7 +107,7 @@ class KVCacheConfig:
     2. (not implemented yet) A model with the same number of full attention
        layers and sliding window attention layers: two groups, one for full
        attention layers and one for sliding window attention layers.
     3. (not implemented yet) A model with 2 full attention layers and 4 sliding
        window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
     """
     groups: list[list[str]]

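For reference, groups is simply a list of layer-name lists, one inner list per KV cache group. A hedged illustration of the docstring's cases with made-up layer names:

# Case 1 (supported today): every layer uses full attention -> a single group.
groups_all_full_attention = [
    ["model.layers.0.self_attn.attn", "model.layers.1.self_attn.attn"],
]

# Case 3 (not implemented yet): 2 full-attention layers and 4 sliding-window
# layers -> three equally sized groups: (full * 2), (sw * 2), (sw * 2).
groups_hybrid = [
    ["full_attn.0", "full_attn.1"],
    ["sliding.0", "sliding.1"],
    ["sliding.2", "sliding.3"],
]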
@@ -1460,13 +1460,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):

         forward_ctx = self.vllm_config.compilation_config.static_forward_context
         block_size = self.vllm_config.cache_config.block_size
+        use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: KVCacheSpec = {}
         for layer_name, attn_module in forward_ctx.items():
             if isinstance(attn_module, FusedMoE):
                 continue

             # TODO: Support other attention modules, e.g., sliding window,
-            # cross-attention, MLA.
+            # cross-attention
             assert isinstance(attn_module, Attention)
             if attn_module.attn_type == AttentionType.DECODER:
                 kv_cache_spec[layer_name] = FullAttentionSpec(
@@ -1474,7 +1475,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
                     dtype=attn_module.dtype,
-                )
+                    use_mla=use_mla)
             elif attn_module.attn_type in (AttentionType.ENCODER,
                                            AttentionType.ENCODER_ONLY):
                 # encoder-only attention does not need KV cache.

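The practical effect of the fix: the engine sizes the KV cache by dividing its memory budget by the per-block page size summed over layers, so overstating page_size_bytes by 2x for an MLA model roughly halves the number of blocks (and therefore cacheable tokens) it can hand out. A back-of-the-envelope sketch with purely illustrative numbers:

# Illustrative budget: 8 GiB reserved for KV cache across 60 MLA layers,
# 16-token blocks, one 576-wide latent vector per token, bf16 (2 bytes each).
budget_bytes = 8 * 1024**3
num_layers = 60
page_fixed = 1 * 16 * 1 * 576 * 2   # use_mla=True after this commit
page_buggy = 2 * 16 * 1 * 576 * 2   # what the old hard-coded 2x reported

blocks_fixed = budget_bytes // (page_fixed * num_layers)   # 7767 blocks
blocks_buggy = budget_bytes // (page_buggy * num_layers)   # 3883 blocks
print(blocks_fixed, blocks_buggy)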
@@ -303,10 +303,10 @@ class TPUModelRunner:

     def get_kv_cache_spec(self) -> KVCacheSpec:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
         Attention module in the static forward context.
         Returns:
             KVCacheSpec: A dictionary mapping layer names to their KV cache
             format. Layers that do not need KV cache are not included.
         """

@@ -323,6 +323,7 @@ class TPUModelRunner:
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
                     dtype=attn_module.dtype,
+                    use_mla=False,
                 )
             elif attn_module.attn_type in (AttentionType.ENCODER,
                                            AttentionType.ENCODER_ONLY):
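On the TPU path the new argument is pinned to False, so this commit leaves TPU page sizing untouched. A quick hedged check of that invariant, under the same module-path assumption as above and with illustrative dense-attention dimensions:

import torch

from vllm.v1.kv_cache_interface import FullAttentionSpec

tpu_spec = FullAttentionSpec(block_size=16,
                             num_kv_heads=8,
                             head_size=128,
                             dtype=torch.bfloat16,
                             use_mla=False)

# With use_mla=False the coefficient stays 2 (separate K and V tensors),
# matching the pre-commit behaviour: 2 * 16 * 8 * 128 * 2 = 65536 bytes.
assert tpu_spec.page_size_bytes == 65536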
@@ -764,7 +765,7 @@ class TPUModelRunner:
         """
         Initialize KV cache based on `kv_cache_config`.
         Args:
             kv_cache_config: Configuration for the KV cache, including the KV
                 cache size of each layer
         """
         if len(kv_cache_config.groups) > 1:
