From c6f384dafdf08fbec7284fbfc133fdaee03c8be2 Mon Sep 17 00:00:00 2001
From: Paul Pak <52512091+paulpak58@users.noreply.github.com>
Date: Fri, 3 Oct 2025 21:59:48 +0900
Subject: [PATCH] [backends][short_conv] CUDA graph piecewise edits (#24215)

Signed-off-by: Paul Pak
Signed-off-by: yewentao256
---
 .../model_executor/layers/mamba/short_conv.py |  2 +-
 vllm/v1/attention/backends/short_conv_attn.py | 40 +++++++++----------
 2 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py
index cc424760e229f..eb4223ade5f0d 100644
--- a/vllm/model_executor/layers/mamba/short_conv.py
+++ b/vllm/model_executor/layers/mamba/short_conv.py
@@ -115,7 +115,7 @@ class ShortConv(MambaBase, CustomOp):
         self_kv_cache = self.kv_cache[forward_context.virtual_engine]
         conv_state = self_kv_cache[0].transpose(-1, -2)
         state_indices_tensor = attn_metadata.state_indices_tensor
-        has_initial_states_p = attn_metadata.has_initial_states
+        has_initial_states_p = attn_metadata.has_initial_states_p
 
         BCx, _ = self.in_proj(hidden_states)
 
diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py
index df7f0d2310ab4..ba0fba4281e57 100644
--- a/vllm/v1/attention/backends/short_conv_attn.py
+++ b/vllm/v1/attention/backends/short_conv_attn.py
@@ -6,12 +6,12 @@ from typing import Optional
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.mamba_attn import (
+    BaseMambaAttentionMetadataBuilder)
+from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
                                               CommonAttentionMetadata,
                                               compute_causal_conv1d_metadata,
                                               split_decodes_and_prefills)
-from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
 
 
 class ShortConvAttentionBackend(AttentionBackend):
@@ -29,8 +29,8 @@ class ShortConvAttentionMetadata:
     num_decode_tokens: int
 
     query_start_loc: torch.Tensor
-    has_initial_states: torch.Tensor
-    state_indices_tensor: torch.Tensor  # shape: [batch,]
+    state_indices_tensor: torch.Tensor
+    has_initial_states_p: Optional[torch.Tensor]
 
     # For causal_conv1d
     nums_dict: Optional[dict] = None
@@ -39,14 +39,7 @@ class ShortConvAttentionMetadata:
 
 
 class ShortConvAttentionMetadataBuilder(
-        AttentionMetadataBuilder[ShortConvAttentionMetadata]):
-
-    reorder_batch_threshold: int = 1
-
-    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
-                 vllm_config: VllmConfig, device: torch.device):
-        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
-        assert isinstance(kv_cache_spec, MambaSpec)
+        BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]):
 
     def build(self,
               common_prefix_len: int,
@@ -54,7 +47,6 @@ class ShortConvAttentionMetadataBuilder(
               fast_build: bool = False) -> ShortConvAttentionMetadata:
         num_reqs = common_attn_metadata.num_reqs
         query_start_loc = common_attn_metadata.query_start_loc
-
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
 
         # for causal_conv1d
@@ -64,13 +56,13 @@ class ShortConvAttentionMetadataBuilder(
             split_decodes_and_prefills(
                 common_attn_metadata,
                 decode_threshold=self.reorder_batch_threshold))
-        has_initial_states = None
+
+        has_initial_states_p = None
         if num_prefills > 0:
-            #[batch,]
             has_initial_states_cpu = (
                 common_attn_metadata.
                 num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
-            has_initial_states = has_initial_states_cpu.to(
+            has_initial_states_p = has_initial_states_cpu.to(
                 query_start_loc.device)
 
             query_start_loc_p = common_attn_metadata.query_start_loc[
@@ -79,14 +71,22 @@ class ShortConvAttentionMetadataBuilder(
 
             nums_dict, batch_ptr, token_chunk_offset_ptr = \
                 compute_causal_conv1d_metadata(query_start_loc_p)
+        elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs
+              and self.compilation_config.full_cuda_graph):
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
+            self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor,
+                                                          non_blocking=True)
+            state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
+            state_indices_tensor[num_decodes:] = PAD_SLOT_ID
+
         attn_metadata = ShortConvAttentionMetadata(
+            query_start_loc=query_start_loc,
+            state_indices_tensor=state_indices_tensor,
+            has_initial_states_p=has_initial_states_p,
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             num_decodes=num_decodes,
             num_decode_tokens=num_decode_tokens,
-            query_start_loc=query_start_loc,
-            has_initial_states=has_initial_states,
-            state_indices_tensor=state_indices_tensor,
             nums_dict=nums_dict,
             batch_ptr=batch_ptr,
             token_chunk_offset_ptr=token_chunk_offset_ptr,
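Note (illustrative sketch, not part of the patch): the new elif branch pads decode-only batches so a captured CUDA graph always sees a fixed-size state-index tensor backed by a persistent buffer. The minimal, self-contained Python sketch below mimics that padding step; pad_for_cudagraph is stubbed here and PAD_SLOT_ID is assumed to be the padding sentinel, so names and sizes are stand-ins rather than vLLM's actual API.

    import torch

    PAD_SLOT_ID = -1  # stand-in for vLLM's padding sentinel


    def pad_for_cudagraph(batch_size, capture_sizes=(1, 2, 4, 8, 16)):
        # Stub: round the decode batch up to the nearest captured graph size.
        return next(s for s in capture_sizes if s >= batch_size)


    decode_cudagraph_max_bs = 16
    # Persistent buffer reused across steps so the captured graph always
    # reads from the same memory, regardless of the live batch size.
    persistent_state_indices = torch.empty(decode_cudagraph_max_bs,
                                           dtype=torch.int32)

    # A decode-only batch of 3 requests, each pointing at a conv-state slot.
    state_indices_tensor = torch.tensor([7, 2, 5], dtype=torch.int32)
    num_decodes = state_indices_tensor.numel()

    num_input_tokens = pad_for_cudagraph(num_decodes)  # -> 4
    persistent_state_indices[:num_decodes].copy_(state_indices_tensor)
    state_indices_tensor = persistent_state_indices[:num_input_tokens]
    state_indices_tensor[num_decodes:] = PAD_SLOT_ID  # mask the padded slots

    print(state_indices_tensor)  # tensor([ 7,  2,  5, -1], dtype=torch.int32)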