From 2b4fc9bd9b8321265ff54065ea47bd9e327c6b6f Mon Sep 17 00:00:00 2001
From: Chen Zhang
Date: Tue, 26 Aug 2025 05:41:52 -0700
Subject: [PATCH] Support FlashAttention Backend for Hybrid SSM Models (#23299)

Signed-off-by: Chen Zhang
---
 .../models/language/generation/test_hybrid.py |  3 --
 vllm/v1/worker/gpu_model_runner.py            | 41 ++++++++-----------
 2 files changed, 17 insertions(+), 27 deletions(-)

diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index 2055c44c83cda..7e7cc893ec8aa 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -110,9 +110,6 @@ def test_models(
     if model in V1_SUPPORTED_MODELS:
         with monkeypatch.context() as m:
             m.setenv("VLLM_USE_V1", "1")
-            if model in HYBRID_MODELS:
-                # required due to reorder_batch behaviour
-                m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
             with vllm_runner(model,
                              max_num_seqs=MAX_NUM_SEQS,
                              enable_prefix_caching=False) as vllm_model:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4f6cf9a350706..14f2305dadc54 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3023,40 +3023,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 raise NotImplementedError

         if has_attn and has_mamba:
-            self._verify_hybrid_attention_mamba_layout(kv_cache_config,
-                                                       kv_cache_raw_tensors)
+            self._update_hybrid_attention_mamba_layout(kv_caches)

         return kv_caches

-    def _verify_hybrid_attention_mamba_layout(
-            self, kv_cache_config: KVCacheConfig,
-            kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None:
+    def _update_hybrid_attention_mamba_layout(
+            self, kv_caches: dict[str, torch.Tensor]) -> None:
         """
-        Verify that the KV cache memory layout is compatible for
-        models with both attention and mamba KV cache groups.
+        Update the layout of attention layers from (2, num_blocks, ...) to
+        (num_blocks, 2, ...).

         Args:
-            kv_cache_config: The KV cache config
-            kv_cache_raw_tensors: The KV cache buffer of each layer.
+            kv_caches: The KV cache buffer of each layer.
         """
         for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
             for layer_name in group.layer_names:
-                raw_tensor = kv_cache_raw_tensors[layer_name]
-                num_blocks = (raw_tensor.numel() //
-                              kv_cache_spec.page_size_bytes)
-                if isinstance(kv_cache_spec, AttentionSpec):
-
-                    kv_cache_shape = group.backend.get_kv_cache_shape(
-                        num_blocks, kv_cache_spec.block_size,
-                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
-                    if kv_cache_shape[0] != num_blocks or kv_cache_shape[
-                            1] != 2:
-                        raise ValueError(
-                            "Hybrid models in V1 require an attention "
-                            "backend with kv_cache_shape="
-                            "(num_blocks, 2, ...). Please try setting "
-                            "VLLM_ATTENTION_BACKEND=FLASHINFER")
+                kv_cache = kv_caches[layer_name]
+                if (isinstance(kv_cache_spec, AttentionSpec)
+                        and kv_cache.shape[0] == 2):
+                    assert kv_cache.shape[1] != 2, \
+                        "Fail to determine whether the layout is " \
+                        "(2, num_blocks, ...) or (num_blocks, 2, ...) for " \
+                        f"a tensor of shape {kv_cache.shape}"
+                    hidden_size = kv_cache.shape[2:].numel()
+                    kv_cache.as_strided_(size=kv_cache.shape,
+                                         stride=(hidden_size, 2 * hidden_size,
+                                                 *kv_cache.stride()[2:]))

     def initialize_kv_cache_tensors(
             self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
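
Illustration (not part of the patch): the as_strided_ call above changes only how indices map into the underlying buffer, not the logical shape the attention backend indexes with. The sketch below, using made-up block and hidden sizes, shows how a buffer first viewed FlashAttention-style as (2, num_blocks, hidden) is reinterpreted so the K and V pages of each block become adjacent in memory, i.e. the (num_blocks, 2, ...) physical layout that the removed ValueError said hybrid models require. Since the patch applies this during KV cache initialization, before any tokens are cached, reinterpreting strides instead of copying data is sufficient.

# Minimal standalone sketch of the stride trick in
# _update_hybrid_attention_mamba_layout; sizes are hypothetical.
import torch

num_blocks, hidden = 4, 8                   # made-up values for illustration
raw = torch.arange(2 * num_blocks * hidden)

# FlashAttention-style view: logical shape (2, num_blocks, hidden) over a
# contiguous buffer, i.e. all K pages first, then all V pages.
kv_cache = raw.view(2, num_blocks, hidden)
assert kv_cache.stride() == (num_blocks * hidden, hidden, 1)

# Reinterpret the same buffer without moving any data: keep the logical shape
# (2, num_blocks, hidden) but make the physical layout (num_blocks, 2, hidden),
# so the K and V pages of one block sit next to each other in memory.
kv_cache.as_strided_(size=kv_cache.shape,
                     stride=(hidden, 2 * hidden, 1))

# Element [kv, block] now lives at flat offset block * 2 * hidden + kv * hidden.
assert kv_cache[0, 1, 0].item() == 2 * hidden   # K page of block 1
assert kv_cache[1, 1, 0].item() == 3 * hidden   # V page of block 1, right after it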