diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 50819bb2bb943..03cbc56e3f4a6 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -459,6 +459,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> FlashInferMetadata:
+        # For full cudagraph capture, ensure decode-only mode
+        if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
+            # This is likely a cudagraph capture scenario
+            assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
+                "FlashInfer only supports decode-only full CUDAGraph capture. " \
+                "Make sure all cudagraph capture sizes <= max_num_seq."
+
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
@@ -577,22 +584,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
         return attn_metadata
 
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata):
-        """
-        This method builds the metadata for full cudagraph capture.
-        Currently, only decode is supported for full cudagraphs with FlashInfer.
-        """
-        m = common_attn_metadata
-
-        assert m.num_reqs == m.num_actual_tokens, \
-            "FlashInfer only supports decode-only full CUDAGraph capture. " \
-            "Make sure all cudagraph capture sizes <= max_num_seq."
-
-        m.max_query_len = 1  # decode-only
-
-        return self.build(0, m)
-
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
             # TODO: The cascade wrapper currently does not support setting
diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py
index 97a1aa86dda0d..0fc8601e48940 100644
--- a/vllm/v1/attention/backends/mamba1_attn.py
+++ b/vllm/v1/attention/backends/mamba1_attn.py
@@ -43,6 +43,13 @@ class Mamba1AttentionMetadataBuilder(
         common_attn_metadata: CommonAttentionMetadata,
         fast_build: bool = False,
     ) -> Mamba1AttentionMetadata:
+        # For full cudagraph capture, ensure decode-only mode
+        if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
+            # This is likely a cudagraph capture scenario
+            assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
+                "Mamba only supports decode-only full CUDAGraph capture. " \
+                "Make sure all cudagraph capture sizes <= max_num_seq."
+
         query_start_loc = common_attn_metadata.query_start_loc
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
 
diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py
index ed30884fdbc94..e9d0ae6b3eeab 100644
--- a/vllm/v1/attention/backends/mamba2_attn.py
+++ b/vllm/v1/attention/backends/mamba2_attn.py
@@ -101,6 +101,13 @@ class Mamba2AttentionMetadataBuilder(
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> Mamba2AttentionMetadata:
+        # For full cudagraph capture, ensure decode-only mode
+        if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
+            # This is likely a cudagraph capture scenario
+            assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
+                "Mamba only supports decode-only full CUDAGraph capture. " \
+                "Make sure all cudagraph capture sizes <= max_num_seq."
+
         num_reqs = common_attn_metadata.num_reqs
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 07ef7cb69a160..43f9c69124d5a 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -38,18 +38,3 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
             device=device,
         )
 
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata) -> M:
-        """
-        This method builds the metadata for full cudagraph capture.
-        Currently, only decode is supported for full cudagraphs with Mamba.
-        """
-        m = common_attn_metadata
-
-        assert m.num_reqs == m.num_actual_tokens, \
-            "Mamba only supports decode-only full CUDAGraph capture. " \
-            "Make sure all cudagraph capture sizes <= max_num_seq."
-
-        m.max_query_len = 1  # decode-only
-
-        return self.build(0, m)
\ No newline at end of file
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index ce45b34f64355..3ded260296f25 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -561,25 +561,17 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             seq_lens=seq_lens,
         )
 
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata) -> M:
-        """
-        This method builds the metadata for full cudagraph capture.
-        Currently, only decode is supported for full cudagraphs with MLA.
-        """
-        m = common_attn_metadata
-        assert m.num_reqs == m.num_actual_tokens, \
-            "MLA only supports decode-only full CUDAGraph capture. " \
-            "Make sure all cudagraph capture sizes <= max_num_seq."
-
-        assert m.max_query_len == 1  # decode-only
-
-        return self.build(0, m)
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> M:
+        # For full cudagraph capture, ensure decode-only mode
+        if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
+            # This is likely a cudagraph capture scenario
+            assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
+                "MLA only supports decode-only full CUDAGraph capture. " \
+                "Make sure all cudagraph capture sizes <= max_num_seq."
+
         num_reqs = common_attn_metadata.num_reqs
         num_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index 403ad8e88a958..48e26b017c5e2 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -254,19 +254,19 @@ class AiterFlashAttentionMetadataBuilder(
         self.aot_sliding_window: Optional[tuple[int, int]] = None
         self.total_tokens: int = 0
 
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata):
-        self.total_tokens = self.model_config.max_model_len \
-            * self.vllm_config.scheduler_config.max_num_partial_prefills
-        res = self.build(common_prefix_len=0,
-                         common_attn_metadata=common_attn_metadata)
-        self.total_tokens = 0
-        return res
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> 'AiterFlashAttentionMetadata':
+        # Handle total_tokens for cudagraph capture scenarios
+        is_cudagraph_capture = (common_prefix_len == 0 and
+                                common_attn_metadata.max_query_len == 1)
+        if is_cudagraph_capture:
+            original_total_tokens = self.total_tokens
+            self.total_tokens = self.model_config.max_model_len \
+                * self.vllm_config.scheduler_config.max_num_partial_prefills
+        else:
+            original_total_tokens = None
 
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
@@ -310,6 +310,11 @@ class AiterFlashAttentionMetadataBuilder(
             common_prefix_len=common_prefix_len,
             total_tokens=self.total_tokens,
         )
+
+        # Restore total_tokens value if this was a cudagraph capture
+        if is_cudagraph_capture:
+            self.total_tokens = original_total_tokens
+
         return attn_metadata
 
     def use_cascade_attention(self, *args, **kwargs) -> bool:
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index b12036c599799..e509dc9329c6b 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -73,16 +73,6 @@ class TritonAttentionMetadataBuilder(
                                                  vllm_config.parallel_config)
         self.headdim = model_config.get_head_size()
 
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata
-    ) -> TritonAttentionMetadata:
-        attn_metadata = self.build(0, common_attn_metadata)
-        # When doing full graph capture, setting seq_lens to
-        # max_model_len will cause graph capture to be extremely
-        # slow, so here we set it to 1.
-        attn_metadata.seq_lens.fill_(1)
-        return attn_metadata
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
@@ -129,6 +119,14 @@ class TritonAttentionMetadataBuilder(
             suffix_kv_lens=suffix_kv_lens,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
         )
+
+        # Handle cudagraph capture optimizations
+        if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
+            # When doing full graph capture, setting seq_lens to
+            # max_model_len will cause graph capture to be extremely
+            # slow, so here we set it to 1.
+            attn_metadata.seq_lens.fill_(1)
+
         return attn_metadata
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 39bdbe125635b..00134bab27fb5 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -205,16 +205,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
         """
         raise NotImplementedError
 
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata) -> M:
-        """
-        Build attention metadata for CUDA graph capture. Uses build by default.
-        Subclasses that override this method should call self.build or
-        super().build_for_cudagraph_capture.
-        """
-        return self.build(common_prefix_len=0,
-                          common_attn_metadata=common_attn_metadata)
-
     def build_for_drafting(
         self,
         common_attn_metadata: CommonAttentionMetadata,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 5d49bbaf270bb..2dfe075ab4557 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2306,7 +2306,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 attn_metadata_i = attn_group.metadata_builder\
-                    .build_for_cudagraph_capture(common_attn_metadata)
+                    .build(common_prefix_len=0,
+                           common_attn_metadata=common_attn_metadata)
                 for layer_name in kv_cache_group_spec.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i
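
Taken together, these hunks fold the old build_for_cudagraph_capture hook into build itself: every backend now treats a call with common_prefix_len == 0 and max_query_len == 1 as a likely decode-only full-CUDAGraph capture and asserts num_reqs == num_actual_tokens. The sketch below restates that predicate in isolation; the Meta dataclass is a toy stand-in for vLLM's CommonAttentionMetadata, not the real class.

    from dataclasses import dataclass

    @dataclass
    class Meta:
        # Toy stand-in for CommonAttentionMetadata (illustration only).
        num_reqs: int
        num_actual_tokens: int
        max_query_len: int

    def is_decode_only_capture(common_prefix_len: int, m: Meta) -> bool:
        # Mirrors the guard each builder now runs at the top of build().
        if common_prefix_len == 0 and m.max_query_len == 1:
            assert m.num_reqs == m.num_actual_tokens, (
                "decode-only full CUDAGraph capture requires "
                "num_reqs == num_actual_tokens")
            return True
        return False

    # A pure-decode batch: 8 requests, 8 tokens, one token per request.
    assert is_decode_only_capture(0, Meta(8, 8, 1))
    # A mixed prefill/decode batch is simply not treated as capture.
    assert not is_decode_only_capture(0, Meta(4, 12, 3))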
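One design note on the rocm_aiter_fa.py hunk: total_tokens is saved and restored by hand around the body of build, so an exception raised in between would leave the inflated value behind. A hypothetical hardening (not part of this diff) would scope the override with a context manager:

    import contextlib

    @contextlib.contextmanager
    def scoped_total_tokens(builder, capture_value: int):
        # Hypothetical helper: temporarily override builder.total_tokens and
        # restore the previous value even if metadata building raises.
        saved = builder.total_tokens
        builder.total_tokens = capture_value
        try:
            yield
        finally:
            builder.total_tokens = saved

Note also that the restore differs slightly from the removed hook, which reset total_tokens to 0 rather than to its prior value; the two coincide in practice because total_tokens is initialized to 0.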