Support FlashAttention Backend for Hybrid SSM Models (#23299)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
commit 2b4fc9bd9b
parent ebd5a77bb5
@@ -110,9 +110,6 @@ def test_models(
     if model in V1_SUPPORTED_MODELS:
         with monkeypatch.context() as m:
             m.setenv("VLLM_USE_V1", "1")
-            if model in HYBRID_MODELS:
-                # required due to reorder_batch behaviour
-                m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
             with vllm_runner(model,
                              max_num_seqs=MAX_NUM_SEQS,
                              enable_prefix_caching=False) as vllm_model:
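With the FLASHINFER pin removed, the hybrid models in HYBRID_MODELS are expected to run this test path on the default V1 attention backend (FlashAttention on most GPUs). A minimal sketch of the equivalent user-facing call, not part of this diff; the model name, prompt, and limits are illustrative assumptions:

import os

from vllm import LLM, SamplingParams

os.environ.setdefault("VLLM_USE_V1", "1")  # mirrors the test's monkeypatch; no backend override needed

# Assumed hybrid attention/mamba checkpoint; any model from HYBRID_MODELS would do.
llm = LLM(model="ibm-ai-platform/Bamba-9B",
          max_num_seqs=4,
          enable_prefix_caching=False)  # prefix caching stays disabled, as in the test
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)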
@@ -3023,40 +3023,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     raise NotImplementedError

         if has_attn and has_mamba:
-            self._verify_hybrid_attention_mamba_layout(kv_cache_config,
-                                                       kv_cache_raw_tensors)
+            self._update_hybrid_attention_mamba_layout(kv_caches)

         return kv_caches

-    def _verify_hybrid_attention_mamba_layout(
-            self, kv_cache_config: KVCacheConfig,
-            kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None:
+    def _update_hybrid_attention_mamba_layout(
+            self, kv_caches: dict[str, torch.Tensor]) -> None:
         """
-        Verify that the KV cache memory layout is compatible for
-        models with both attention and mamba KV cache groups.
+        Update the layout of attention layers from (2, num_blocks, ...) to
+        (num_blocks, 2, ...).

         Args:
-            kv_cache_config: The KV cache config
-            kv_cache_raw_tensors: The KV cache buffer of each layer.
+            kv_caches: The KV cache buffer of each layer.
         """

         for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
             for layer_name in group.layer_names:
-                raw_tensor = kv_cache_raw_tensors[layer_name]
-                num_blocks = (raw_tensor.numel() //
-                              kv_cache_spec.page_size_bytes)
-                if isinstance(kv_cache_spec, AttentionSpec):
-
-                    kv_cache_shape = group.backend.get_kv_cache_shape(
-                        num_blocks, kv_cache_spec.block_size,
-                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
-                    if kv_cache_shape[0] != num_blocks or kv_cache_shape[
-                            1] != 2:
-                        raise ValueError(
-                            "Hybrid models in V1 require an attention "
-                            "backend with kv_cache_shape="
-                            "(num_blocks, 2, ...). Please try setting "
-                            "VLLM_ATTENTION_BACKEND=FLASHINFER")
+                kv_cache = kv_caches[layer_name]
+                if (isinstance(kv_cache_spec, AttentionSpec)
+                        and kv_cache.shape[0] == 2):
+                    assert kv_cache.shape[1] != 2, \
+                        "Fail to determine whether the layout is " \
+                        "(2, num_blocks, ...) or (num_blocks, 2, ...) for " \
+                        f"a tensor of shape {kv_cache.shape}"
+                    hidden_size = kv_cache.shape[2:].numel()
+                    kv_cache.as_strided_(size=kv_cache.shape,
+                                         stride=(hidden_size, 2 * hidden_size,
+                                                 *kv_cache.stride()[2:]))

     def initialize_kv_cache_tensors(
             self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
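The as_strided_ call above is the core of the change: the attention backend keeps addressing the cache through its native (2, num_blocks, ...) shape, while the swapped strides make the underlying buffer block-major, so that, in effect, each block's K and V pages sit next to each other in memory as the hybrid (num_blocks, 2, ...) layout expects. A standalone sketch of the trick in plain PyTorch (toy sizes, not vLLM code):

import torch

num_blocks, hidden = 4, 8  # hidden stands in for block_size * num_kv_heads * head_size

# One contiguous buffer, physically block-major: (num_blocks, 2, hidden).
raw = torch.arange(num_blocks * 2 * hidden, dtype=torch.float32)
physical = raw.view(num_blocks, 2, hidden)

# Logical (2, num_blocks, hidden) view over the same memory, with the first
# two strides swapped exactly as in _update_hybrid_attention_mamba_layout.
logical = raw.view(2, num_blocks, hidden)
logical.as_strided_(size=(2, num_blocks, hidden),
                    stride=(hidden, 2 * hidden, 1))

# logical[k, b] now aliases physical[b, k]: K and V of a block are adjacent,
# and no data was copied.
assert torch.equal(logical[0, 2], physical[2, 0])  # K page of block 2
assert torch.equal(logical[1, 2], physical[2, 1])  # V page of block 2
assert logical.data_ptr() == physical.data_ptr()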