Support FlashAttention Backend for Hybrid SSM Models (#23299)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Chen Zhang 2025-08-26 05:41:52 -07:00 committed by GitHub
parent ebd5a77bb5
commit 2b4fc9bd9b
2 changed files with 17 additions and 27 deletions


@@ -110,9 +110,6 @@ def test_models(
     if model in V1_SUPPORTED_MODELS:
         with monkeypatch.context() as m:
             m.setenv("VLLM_USE_V1", "1")
-            if model in HYBRID_MODELS:
-                # required due to reorder_batch behaviour
-                m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
             with vllm_runner(model,
                              max_num_seqs=MAX_NUM_SEQS,
                              enable_prefix_caching=False) as vllm_model:
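
Editor's note: for context, a minimal usage sketch (not part of the commit; the model name is a placeholder, and it assumes a vLLM build that includes this change). It shows that hybrid SSM models can now run on the default FlashAttention backend under V1 without forcing the FlashInfer backend, mirroring what the removed test lines used to do.

import os

os.environ["VLLM_USE_V1"] = "1"
# Before this commit, hybrid SSM models additionally needed:
# os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM

llm = LLM(model="your-org/your-hybrid-ssm-model",  # placeholder model name
          enable_prefix_caching=False)
print(llm.generate(["Hello, my name is"])[0].outputs[0].text)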


@@ -3023,40 +3023,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     raise NotImplementedError
         if has_attn and has_mamba:
-            self._verify_hybrid_attention_mamba_layout(kv_cache_config,
-                                                       kv_cache_raw_tensors)
+            self._update_hybrid_attention_mamba_layout(kv_caches)

         return kv_caches

-    def _verify_hybrid_attention_mamba_layout(
-            self, kv_cache_config: KVCacheConfig,
-            kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None:
+    def _update_hybrid_attention_mamba_layout(
+            self, kv_caches: dict[str, torch.Tensor]) -> None:
         """
-        Verify that the KV cache memory layout is compatible for
-        models with both attention and mamba KV cache groups.
+        Update the layout of attention layers from (2, num_blocks, ...) to
+        (num_blocks, 2, ...).

         Args:
-            kv_cache_config: The KV cache config
-            kv_cache_raw_tensors: The KV cache buffer of each layer.
+            kv_caches: The KV cache buffer of each layer.
         """
         for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
             for layer_name in group.layer_names:
-                raw_tensor = kv_cache_raw_tensors[layer_name]
-                num_blocks = (raw_tensor.numel() //
-                              kv_cache_spec.page_size_bytes)
-                if isinstance(kv_cache_spec, AttentionSpec):
-                    kv_cache_shape = group.backend.get_kv_cache_shape(
-                        num_blocks, kv_cache_spec.block_size,
-                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
-                    if kv_cache_shape[0] != num_blocks or kv_cache_shape[
-                            1] != 2:
-                        raise ValueError(
-                            "Hybrid models in V1 require an attention "
-                            "backend with kv_cache_shape="
-                            "(num_blocks, 2, ...). Please try setting "
-                            "VLLM_ATTENTION_BACKEND=FLASHINFER")
+                kv_cache = kv_caches[layer_name]
+                if (isinstance(kv_cache_spec, AttentionSpec)
+                        and kv_cache.shape[0] == 2):
+                    assert kv_cache.shape[1] != 2, \
+                        "Fail to determine whether the layout is " \
+                        "(2, num_blocks, ...) or (num_blocks, 2, ...) for " \
+                        f"a tensor of shape {kv_cache.shape}"
+                    hidden_size = kv_cache.shape[2:].numel()
+                    kv_cache.as_strided_(size=kv_cache.shape,
+                                         stride=(hidden_size, 2 * hidden_size,
+                                                 *kv_cache.stride()[2:]))

     def initialize_kv_cache_tensors(
             self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
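
Editor's note: to make the stride trick above concrete, here is a small standalone sketch (not part of the commit; the tensor sizes are made up for illustration). It shows how as_strided_ lets a tensor with nominal shape (2, num_blocks, ...) address memory that is physically laid out as (num_blocks, 2, ...), so each block's K and V stay in one contiguous page and no data is copied.

import torch

# Toy sizes, chosen only for illustration.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
hidden_size = block_size * num_kv_heads * head_size  # elements per K (or V) block

# Physical buffer laid out block-major, as hybrid models require:
# (num_blocks, 2, block_size, num_kv_heads, head_size).
physical = torch.arange(num_blocks * 2 * hidden_size, dtype=torch.float32)
physical = physical.view(num_blocks, 2, block_size, num_kv_heads, head_size)

# The attention backend addresses the same memory through a tensor whose
# nominal shape is (2, num_blocks, ...).
kv_cache = physical.view(2, num_blocks, block_size, num_kv_heads, head_size)

# Swap the strides of the first two dims so kv_cache[0, b] / kv_cache[1, b]
# resolve to the K / V halves of block b's contiguous page.
kv_cache.as_strided_(size=kv_cache.shape,
                     stride=(hidden_size, 2 * hidden_size,
                             *kv_cache.stride()[2:]))

assert torch.equal(kv_cache[0, 1], physical[1, 0])  # K of block 1
assert torch.equal(kv_cache[1, 1], physical[1, 1])  # V of block 1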