diff --git a/docs/design/cuda_graphs.md b/docs/design/cuda_graphs.md
index 177a581587d02..aac7b76eea265 100644
--- a/docs/design/cuda_graphs.md
+++ b/docs/design/cuda_graphs.md
@@ -177,8 +177,9 @@ The following table lists backends that support full CUDA Graphs at the time of
 | FlashAttention v3 | `ALWAYS` | has unified routine for both batches, so `FULL` mode is good |
 | Triton Attention | `ALWAYS` | prefer `FULL_AND_PIECEWISE` since it has different kernels for prefill/mixed and pure decode batches |
 | AITER FlashAttention | `UNIFORM_BATCH`| |
-| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | |
+| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | Will be set to `UNIFORM_BATCH` when using TRTLLM attention on Blackwell |
 | FlashMLA | `UNIFORM_BATCH` | |
+| FlashInferMLA | `UNIFORM_BATCH` | |
 | AITER MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
 | CUTLASS MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
 | Mamba attention| `UNIFORM_SINGLE_TOKEN_DECODE` | |
diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py
index 5532ce80d7f15..f144e8435b6cf 100644
--- a/vllm/attention/layers/chunked_local_attention.py
+++ b/vllm/attention/layers/chunked_local_attention.py
@@ -32,7 +32,7 @@ def create_chunked_local_attention_backend(
     underlying_builder = underlying_attn_backend.get_builder_cls()
 
     class ChunkedLocalAttentionBuilder(underlying_builder):  # type: ignore
-        cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
+        _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
 
         def build(
             self,
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 9cec623814c9f..d9bd52d8f9800 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -207,7 +207,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
     # to FULL_AND_PIECEWISE.
     # TODO(luka, lucas): audit FA2 as part of:
     # https://github.com/vllm-project/vllm/issues/22945
-    cudagraph_support = (
+    _cudagraph_support = (
         AttentionCGSupport.ALWAYS
         if get_flash_attn_version() == 3
         else AttentionCGSupport.UNIFORM_BATCH
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 18bbc3cc3c12b..1ce8e6f3d89f8 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -15,6 +15,7 @@ from flashinfer import (
 from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from flashinfer.utils import FP4Tensor
+from typing_extensions import override
 
 from vllm import envs
 from vllm.attention.backends.abstract import (
@@ -274,10 +275,6 @@ class FlashInferMetadata:
 
 
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = (
-        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
-    )
-
     reorder_batch_threshold: int = 1
 
     def __init__(
@@ -355,6 +352,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         else:
             self.q_data_type = self.model_config.dtype
 
+        # Prefer TRTLLM attention for decoding in all cases.
+        # This allows us to use AttentionCGSupport.UNIFORM_BATCH mode.
+        self.use_trtllm_decode_attention = can_use_trtllm
         self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)
 
         self._cascade_wrapper = None  # Wrapper for cascade attention
@@ -412,6 +412,24 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 "passing --block-size 32 or --block-size 64."
             )
 
+    @classmethod
+    @override
+    def get_cudagraph_support(
+        cls: type["FlashInferMetadataBuilder"],
+        vllm_config: VllmConfig,
+        kv_cache_spec: AttentionSpec,
+    ) -> AttentionCGSupport:
+        has_trtllm_support = can_use_trtllm_attention(
+            num_qo_heads=vllm_config.model_config.get_num_attention_heads(
+                vllm_config.parallel_config
+            ),
+            num_kv_heads=kv_cache_spec.num_kv_heads,
+        )
+        if has_trtllm_support:
+            return AttentionCGSupport.UNIFORM_BATCH
+        else:
+            return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
             buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
@@ -573,17 +591,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             has_sinks=self.has_sinks,
             has_spec=uses_spec_reorder,
         )
-        decode_use_trtllm = use_trtllm_attention(
-            self.num_qo_heads,
-            self.num_kv_heads,
-            num_decode_tokens,
-            max_seq_len,
-            self.cache_dtype,
-            self.q_data_type,
-            is_prefill=False,
-            has_sinks=self.has_sinks,
-            has_spec=uses_spec_reorder,
-        )
+        decode_use_trtllm = self.use_trtllm_decode_attention
 
         if not (prefill_use_trtllm and decode_use_trtllm):
             if self.has_sinks:
diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index 2ca19646911ec..69b5a6fb48564 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -59,7 +59,7 @@ class GDNAttentionMetadata:
 
 
 class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
-    cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
+    _cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
 
     reorder_batch_threshold: int = 1
 
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 49d7d6c31b9a0..0d875565fc99a 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -20,7 +20,7 @@ M = TypeVar("M")
 
 class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
     reorder_batch_threshold: int = 1
-    cudagraph_support: ClassVar[AttentionCGSupport] = (
+    _cudagraph_support: ClassVar[AttentionCGSupport] = (
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     )
 
diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py
index 0a10ce74cd1d4..60cb5022a55eb 100644
--- a/vllm/v1/attention/backends/mla/cutlass_mla.py
+++ b/vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -29,7 +29,7 @@ logger = init_logger(__name__)
 
 class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
     # enable full CUDA Graph support for decode-only capture
-    cudagraph_support: ClassVar[AttentionCGSupport] = (
+    _cudagraph_support: ClassVar[AttentionCGSupport] = (
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     )
 
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 5662acbe32c29..7794e89cc0a94 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -92,7 +92,7 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
 
 
 class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
     query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
     reorder_batch_threshold: int = 512  # process small prefills with decode pathway
 
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index b0f514ba44513..52bb19e039e45 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -29,7 +29,7 @@ FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
 
 
 class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
     query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
 
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 8f0364cd58def..3aab1f9bb7fb6 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -96,7 +96,7 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
 
 
 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
     query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
     reorder_batch_threshold: int = 128  # process small prefills with decode pathway
     # ^ TODO(matt): tune this
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index 4794312eb96ef..5fe9c69d35007 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -248,7 +248,7 @@ def triton_convert_req_index_to_global_index(
 
 @dataclass
 class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
 
     def __init__(
         self,
diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index 4f071145625fc..37aa5dad89a0e 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -206,7 +206,7 @@ def split_prefill_chunks(
 
 
 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
-    cudagraph_support: ClassVar[AttentionCGSupport] = (
+    _cudagraph_support: ClassVar[AttentionCGSupport] = (
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     )
 
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 5757aeadba056..e1864526f02cc 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -55,7 +55,7 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
     # TODO(luka, lucas): audit this as part of:
     # https://github.com/vllm-project/vllm/issues/22945
-    cudagraph_support: ClassVar[AttentionCGSupport] = (
+    _cudagraph_support: ClassVar[AttentionCGSupport] = (
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     )
 
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index 81991244f5d90..4888ae51d1d3e 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -251,7 +251,7 @@ class AiterFlashAttentionMetadata:
 class AiterFlashAttentionMetadataBuilder(
     AttentionMetadataBuilder[AiterFlashAttentionMetadata]
 ):
-    cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+    _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     reorder_batch_threshold: int = 1
 
     def __init__(
diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py
index 1d2c70f65d0f5..6dfdfc19ccba1 100644
--- a/vllm/v1/attention/backends/rocm_attn.py
+++ b/vllm/v1/attention/backends/rocm_attn.py
@@ -63,7 +63,7 @@ class RocmAttentionMetadata:
 
 
 class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
 
     def __init__(
         self,
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 37c0ae61e65d0..889c79db18ef5 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -67,7 +67,7 @@ class TritonAttentionMetadata:
 
 
 class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
 
     def __init__(
         self,
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 751c5c15a4c98..fd37a665cf05f 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -244,7 +244,8 @@ class AttentionCGSupport(enum.Enum):
 
 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention (default: no).
-    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
+    # Do not access directly. Call get_cudagraph_support() instead.
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
     # Does this backend/builder reorder the batch?
     # If not, set this to None. Otherwise set it to the query
     # length that will be pulled into the front of the batch.
@@ -263,6 +264,15 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
         self.vllm_config = vllm_config
         self.device = device
 
+    @classmethod
+    def get_cudagraph_support(
+        cls: type["AttentionMetadataBuilder"],
+        vllm_config: VllmConfig,
+        kv_cache_spec: AttentionSpec,
+    ) -> AttentionCGSupport:
+        """Get the cudagraph support level of this builder class."""
+        return cls._cudagraph_support
+
     def _init_reorder_batch_threshold(
         self,
         reorder_batch_threshold: int | None = 1,
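With the base-class change above, per-builder CUDA graph support becomes a classmethod that can inspect the model and KV-cache configuration instead of a fixed class attribute; the FlashInfer override earlier in this patch is the real example of using the hook. The sketch below is a hypothetical builder (the class name and the num_kv_heads condition are invented for illustration) showing the minimal shape of such an override:

# Illustrative sketch only, not part of the patch: a hypothetical builder
# overriding the new get_cudagraph_support() hook. The class name and the
# num_kv_heads condition are made up for demonstration.
from typing import ClassVar

from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
)
from vllm.v1.kv_cache_interface import AttentionSpec


class HypotheticalMetadataBuilder(AttentionMetadataBuilder):
    # Static default, still returned by the base classmethod when the hook
    # is not overridden.
    _cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
    )

    @classmethod
    def get_cudagraph_support(
        cls,
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        # Relax the support level only when the kernel's (hypothetical)
        # precondition on the KV-cache layout holds.
        if kv_cache_spec.num_kv_heads >= 8:
            return AttentionCGSupport.UNIFORM_BATCH
        return cls._cudagraph_support

The model runner no longer reads `_cudagraph_support` directly; it always goes through the classmethod, which is what makes per-config decisions like the FlashInfer TRTLLM case possible.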
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index b14b6b1c3f52e..987d451fd6baf 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4167,14 +4167,16 @@ class GPUModelRunner(
             return attn_groups
 
         attention_backend_maps = []
-        attention_backend_set: set[type[AttentionBackend]] = set()
+        attention_backend_list = []
         for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
             attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
             attention_backend_maps.append(attn_backends[0])
-            attention_backend_set.update(attn_backends[1])
+            attention_backend_list.append(attn_backends[1])
 
         # Resolve cudagraph_mode before actually initialize metadata_builders
-        self._check_and_update_cudagraph_mode(attention_backend_set)
+        self._check_and_update_cudagraph_mode(
+            attention_backend_list, kv_cache_config.kv_cache_groups
+        )
 
         for i, attn_backend_map in enumerate(attention_backend_maps):
             self.attn_groups.append(create_attn_groups(attn_backend_map, i))
@@ -4203,22 +4205,31 @@ class GPUModelRunner(
         self.calculate_reorder_batch_threshold()
 
     def _check_and_update_cudagraph_mode(
-        self, attention_backends: set[type[AttentionBackend]]
+        self,
+        attention_backends: list[set[type[AttentionBackend]]],
+        kv_cache_groups: list[KVCacheGroupSpec],
     ) -> None:
         """
         Resolve the cudagraph_mode when there are multiple attention
-        backends with potential conflicting CUDA graph support.
+        groups with potential conflicting CUDA graph support.
         Then initialize the cudagraph_dispatcher based on the resolved
         cudagraph_mode.
         """
         min_cg_support = AttentionCGSupport.ALWAYS
         min_cg_backend_name = None
-        for attn_backend in attention_backends:
-            builder_cls = attn_backend.get_builder_cls()
-            if builder_cls.cudagraph_support.value < min_cg_support.value:
-                min_cg_support = builder_cls.cudagraph_support
-                min_cg_backend_name = attn_backend.__name__
+        for attn_backend_set, kv_cache_group in zip(
+            attention_backends, kv_cache_groups
+        ):
+            for attn_backend in attn_backend_set:
+                builder_cls = attn_backend.get_builder_cls()
+
+                cg_support = builder_cls.get_cudagraph_support(
+                    self.vllm_config, kv_cache_group.kv_cache_spec
+                )
+                if cg_support.value < min_cg_support.value:
+                    min_cg_support = cg_support
+                    min_cg_backend_name = attn_backend.__name__
 
         # Flexible resolve the cudagraph mode
         cudagraph_mode = self.compilation_config.cudagraph_mode
         # check cudagraph for mixed batch is supported
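For reference, the resolution performed in `_check_and_update_cudagraph_mode` is a minimum over the support levels reported by every backend in every KV-cache group. The following is a standalone sketch of that reduction; the enum values and the two stand-in builder classes are illustrative, not the real vLLM classes:

# Standalone sketch of the per-group minimum reduction; the enum values and
# builder classes are stand-ins, not the real vLLM implementations.
import enum


class CGSupport(enum.IntEnum):
    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    UNIFORM_BATCH = 2
    ALWAYS = 3


class AlwaysBuilder:
    @classmethod
    def get_cudagraph_support(cls, vllm_config, kv_cache_spec):
        return CGSupport.ALWAYS


class DecodeOnlyBuilder:
    @classmethod
    def get_cudagraph_support(cls, vllm_config, kv_cache_spec):
        return CGSupport.UNIFORM_SINGLE_TOKEN_DECODE


def resolve_min_support(groups, vllm_config=None):
    """groups is a list of (builder_classes, kv_cache_spec) pairs,
    one entry per KV-cache group."""
    min_support, min_name = CGSupport.ALWAYS, None
    for builders, kv_cache_spec in groups:
        for builder in builders:
            support = builder.get_cudagraph_support(vllm_config, kv_cache_spec)
            if support < min_support:
                min_support, min_name = support, builder.__name__
    return min_support, min_name


# The least-capable builder wins, and its name is kept for diagnostics.
print(resolve_min_support([({AlwaysBuilder}, None), ({DecodeOnlyBuilder}, None)]))
# -> (<CGSupport.UNIFORM_SINGLE_TOKEN_DECODE: 1>, 'DecodeOnlyBuilder')

Evaluating support per KV-cache group, rather than over one flat set of backends, is what lets models that mix attention backends still resolve to the strictest CUDA graph mode every group can execute.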