From 072d7e53e534d337b41262dd44ded9b44aa699ef Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Date: Thu, 18 Sep 2025 18:27:49 +0400 Subject: [PATCH] [PERF] Add `conv1d` metadata to GDN attn (#25105) Signed-off-by: Vadim Gimpelson --- vllm/model_executor/layers/mamba/mamba2_metadata.py | 8 +++++--- vllm/model_executor/models/qwen3_next.py | 10 +++++++++- vllm/v1/attention/backends/gdn_attn.py | 6 ++++++ vllm/v1/attention/backends/mamba2_attn.py | 4 ++-- vllm/v1/attention/backends/short_conv_attn.py | 4 ++-- 5 files changed, 24 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py index 368bfe3af1d3f..c926e17a2c197 100644 --- a/vllm/model_executor/layers/mamba/mamba2_metadata.py +++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py @@ -11,6 +11,7 @@ from vllm.attention.backends.placeholder_attn import ( PlaceholderAttentionMetadata) from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.platforms import current_platform +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from vllm.v1.attention.backends.mamba2_attn import ( Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets) @@ -45,8 +46,8 @@ class Mamba2Metadata: """ nums_dict: Optional[dict] = None cu_seqlen: Optional[int] = None - batch_ptr: Optional[torch.tensor] = None - token_chunk_offset_ptr: Optional[torch.tensor] = None + batch_ptr: Optional[torch.Tensor] = None + token_chunk_offset_ptr: Optional[torch.Tensor] = None def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: @@ -117,7 +118,8 @@ def prepare_mamba2_metadata( def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor, mamba2_metadata: Union[Mamba2Metadata, - Mamba2AttentionMetadata]): + Mamba2AttentionMetadata, + GDNAttentionMetadata]): """ this is triggered upon handling a new input at the first layer """ diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index eb060cb90f44c..0c974ee44eee2 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -35,6 +35,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata from vllm.model_executor.layers.mamba.mamba_mixer2 import ( mamba_v2_sharded_weight_loader) from vllm.model_executor.layers.mamba.mamba_utils import ( @@ -414,6 +415,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] + conv_metadata = attn_metadata assert isinstance(attn_metadata, GDNAttentionMetadata) has_initial_state = attn_metadata.has_initial_state spec_query_start_loc = attn_metadata.spec_query_start_loc @@ -475,10 +477,15 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): # 2.2: process the remaining part if attn_metadata.num_prefills > 0: + mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) + if conv_metadata.cu_seqlen is None: + conv_metadata = update_metadata(mixed_qkv_non_spec_T, + non_spec_query_start_loc, + conv_metadata) # - "cache_indices" updates the conv_state cache in positions # pointed to by "mamba_cache_params.state_indices_tensor" mixed_qkv_non_spec = causal_conv1d_fn( - mixed_qkv_non_spec.transpose(0, 1), + mixed_qkv_non_spec_T, conv_weights, self.conv1d.bias, activation=self.activation, @@ -486,6 +493,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): has_initial_state=has_initial_state, cache_indices=non_spec_state_indices_tensor, query_start_loc=non_spec_query_start_loc, + metadata=conv_metadata, ).transpose(0, 1) elif attn_metadata.num_decodes > 0: mixed_qkv_non_spec = causal_conv1d_update( diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index ba89f93e8b56d..5dadc52d0fb1c 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -50,6 +50,12 @@ class GDNAttentionMetadata: Tensor] = None # shape: [num_prefill_tokens + num_decode_tokens,] num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,] + # The following attributes are for triton implementation of causal_conv1d + nums_dict: Optional[dict] = None + cu_seqlen: Optional[int] = None + batch_ptr: Optional[torch.Tensor] = None + token_chunk_offset_ptr: Optional[torch.Tensor] = None + class GDNAttentionMetadataBuilder( AttentionMetadataBuilder[GDNAttentionMetadata]): diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 359bad1ea9dee..2fe1f14ca1db0 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -132,8 +132,8 @@ class Mamba2AttentionMetadata: # The following attributes are for triton implementation of causal_conv1d nums_dict: Optional[dict] = None cu_seqlen: Optional[int] = None - batch_ptr: Optional[torch.tensor] = None - token_chunk_offset_ptr: Optional[torch.tensor] = None + batch_ptr: Optional[torch.Tensor] = None + token_chunk_offset_ptr: Optional[torch.Tensor] = None class Mamba2AttentionMetadataBuilder( diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index f5ad65b02b4d4..717c40b37ecfb 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -34,8 +34,8 @@ class ShortConvAttentionMetadata: # For causal_conv1d nums_dict: Optional[dict] = None cu_seqlen: Optional[int] = None - batch_ptr: Optional[torch.tensor] = None - token_chunk_offset_ptr: Optional[torch.tensor] = None + batch_ptr: Optional[torch.Tensor] = None + token_chunk_offset_ptr: Optional[torch.Tensor] = None class ShortConvAttentionMetadataBuilder(