[Perf] Refactor cudagraph_support to enable full CUDA graphs for spec decoding with FlashInfer (#28479)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Benjamin Chislett 2025-11-12 11:56:40 -05:00 committed by GitHub
parent a742134cc5
commit 304419576a
18 changed files with 71 additions and 41 deletions


@@ -177,8 +177,9 @@ The following table lists backends that support full CUDA Graphs at the time of
| FlashAttention v3 | `ALWAYS` | has unified routine for both batches, so `FULL` mode is good |
| Triton Attention | `ALWAYS` | prefer `FULL_AND_PIECEWISE` since it has different kernels for prefill/mixed and pure decode batches |
| AITER FlashAttention | `UNIFORM_BATCH`| |
| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | |
| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | Will be set to `UNIFORM_BATCH` when using TRTLLM attention on Blackwell |
| FlashMLA | `UNIFORM_BATCH` | |
| FlashInferMLA | `UNIFORM_BATCH` | |
| AITER MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
| CUTLASS MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
| Mamba attention| `UNIFORM_SINGLE_TOKEN_DECODE` | |
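
As background for this table: the support levels form an ordered scale, and the runner later keeps the most restrictive level across backends. A minimal sketch of that ordering, with illustrative numeric values (not copied from the source):

```python
import enum

class AttentionCGSupport(enum.Enum):
    """Illustrative ordering: a higher value means broader full-CUDA-graph support."""
    NEVER = 0                        # attention never runs inside a full CUDA graph
    UNIFORM_SINGLE_TOKEN_DECODE = 1  # full graphs only for pure single-token decode batches
    UNIFORM_BATCH = 2                # full graphs for uniform batches (e.g. spec-decode verification)
    ALWAYS = 3                       # full graphs for any batch, including prefill/mixed
```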


@@ -32,7 +32,7 @@ def create_chunked_local_attention_backend(
underlying_builder = underlying_attn_backend.get_builder_cls()
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
def build(
self,


@@ -207,7 +207,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
# to FULL_AND_PIECEWISE.
# TODO(luka, lucas): audit FA2 as part of:
# https://github.com/vllm-project/vllm/issues/22945
cudagraph_support = (
_cudagraph_support = (
AttentionCGSupport.ALWAYS
if get_flash_attn_version() == 3
else AttentionCGSupport.UNIFORM_BATCH


@@ -15,6 +15,7 @@ from flashinfer import (
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor
from typing_extensions import override
from vllm import envs
from vllm.attention.backends.abstract import (
@@ -274,10 +275,6 @@ class FlashInferMetadata:
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)
reorder_batch_threshold: int = 1
def __init__(
@@ -355,6 +352,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
else:
self.q_data_type = self.model_config.dtype
# Prefer TRTLLM attention for decoding in all cases.
# This allows us to use AttentionCGSupport.UNIFORM_BATCH mode.
self.use_trtllm_decode_attention = can_use_trtllm
self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)
self._cascade_wrapper = None # Wrapper for cascade attention
@@ -412,6 +412,24 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
"passing --block-size 32 or --block-size 64."
)
@classmethod
@override
def get_cudagraph_support(
cls: type["FlashInferMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
has_trtllm_support = can_use_trtllm_attention(
num_qo_heads=vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config
),
num_kv_heads=kv_cache_spec.num_kv_heads,
)
if has_trtllm_support:
return AttentionCGSupport.UNIFORM_BATCH
else:
return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
def _get_workspace_buffer(self):
if self._workspace_buffer is None:
buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
@@ -573,17 +591,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder,
)
decode_use_trtllm = use_trtllm_attention(
self.num_qo_heads,
self.num_kv_heads,
num_decode_tokens,
max_seq_len,
self.cache_dtype,
self.q_data_type,
is_prefill=False,
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder,
)
decode_use_trtllm = self.use_trtllm_decode_attention
if not (prefill_use_trtllm and decode_use_trtllm):
if self.has_sinks:
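
The net effect of the FlashInfer changes above: when TRTLLM decode attention is usable, full CUDA graphs can cover uniform speculative-decode batches instead of only single-token decode. A toy sketch of the uniformity condition, assuming each spec-decode request carries one bonus token plus `num_spec_tokens` draft tokens (the helper name and token layout are illustrative, not vLLM API):

```python
def is_uniform_spec_decode_batch(query_lens: list[int], num_spec_tokens: int) -> bool:
    """Toy check: a batch is 'uniform' when every request has the same query length."""
    expected = 1 + num_spec_tokens  # assumed layout: 1 bonus token + k draft tokens
    return all(q == expected for q in query_lens)

print(is_uniform_spec_decode_batch([3, 3, 3], num_spec_tokens=2))  # True  -> full-graph capturable
print(is_uniform_spec_decode_batch([3, 1, 3], num_spec_tokens=2))  # False -> fall back to piecewise
```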


@@ -59,7 +59,7 @@ class GDNAttentionMetadata:
class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
_cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: int = 1


@@ -20,7 +20,7 @@ M = TypeVar("M")
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
reorder_batch_threshold: int = 1
cudagraph_support: ClassVar[AttentionCGSupport] = (
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)


@@ -29,7 +29,7 @@ logger = init_logger(__name__)
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
# enable full CUDA Graph support for decode-only capture
cudagraph_support: ClassVar[AttentionCGSupport] = (
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)


@@ -92,7 +92,7 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
reorder_batch_threshold: int = 512 # process small prefills with decode pathway


@@ -29,7 +29,7 @@ FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM


@@ -96,7 +96,7 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
reorder_batch_threshold: int = 128 # process small prefills with decode pathway
# ^ TODO(matt): tune this


@@ -248,7 +248,7 @@ def triton_convert_req_index_to_global_index(
@dataclass
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
def __init__(
self,


@@ -206,7 +206,7 @@ def split_prefill_chunks(
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
cudagraph_support: ClassVar[AttentionCGSupport] = (
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)


@@ -55,7 +55,7 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
# TODO(luka, lucas): audit this as part of:
# https://github.com/vllm-project/vllm/issues/22945
cudagraph_support: ClassVar[AttentionCGSupport] = (
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)


@@ -251,7 +251,7 @@ class AiterFlashAttentionMetadata:
class AiterFlashAttentionMetadataBuilder(
AttentionMetadataBuilder[AiterFlashAttentionMetadata]
):
cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
_cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: int = 1
def __init__(


@@ -63,7 +63,7 @@ class RocmAttentionMetadata:
class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
def __init__(
self,


@@ -67,7 +67,7 @@ class TritonAttentionMetadata:
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
def __init__(
self,


@@ -244,7 +244,8 @@ class AttentionCGSupport(enum.Enum):
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention (default: no).
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
# Do not access directly. Call get_cudagraph_support() instead.
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
@@ -263,6 +264,15 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
self.vllm_config = vllm_config
self.device = device
@classmethod
def get_cudagraph_support(
cls: type["AttentionMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
"""Get the cudagraph support level of this builder class."""
return cls._cudagraph_support
def _init_reorder_batch_threshold(
self,
reorder_batch_threshold: int | None = 1,
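
The core of the refactor in this file: the static `cudagraph_support` class attribute becomes a private default plus a `get_cudagraph_support()` classmethod, so support can depend on runtime configuration. A self-contained toy of the pattern (the names `Builder`, `CGSupport`, and `use_trtllm_decode` are placeholders, not vLLM API):

```python
import enum
from typing import ClassVar

class CGSupport(enum.Enum):
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    UNIFORM_BATCH = 2

class Builder:
    # Static default, kept private; callers go through the classmethod.
    _cudagraph_support: ClassVar[CGSupport] = CGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    @classmethod
    def get_cudagraph_support(cls, use_trtllm_decode: bool) -> CGSupport:
        # Mirrors the FlashInfer override above: promote the support level
        # when the faster decode path is available.
        if use_trtllm_decode:
            return CGSupport.UNIFORM_BATCH
        return cls._cudagraph_support

print(Builder.get_cudagraph_support(use_trtllm_decode=True))   # CGSupport.UNIFORM_BATCH
print(Builder.get_cudagraph_support(use_trtllm_decode=False))  # CGSupport.UNIFORM_SINGLE_TOKEN_DECODE
```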


@@ -4167,14 +4167,16 @@ class GPUModelRunner(
return attn_groups
attention_backend_maps = []
attention_backend_set: set[type[AttentionBackend]] = set()
attention_backend_list = []
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
attention_backend_maps.append(attn_backends[0])
attention_backend_set.update(attn_backends[1])
attention_backend_list.append(attn_backends[1])
# Resolve cudagraph_mode before actually initializing metadata_builders
self._check_and_update_cudagraph_mode(attention_backend_set)
self._check_and_update_cudagraph_mode(
attention_backend_list, kv_cache_config.kv_cache_groups
)
for i, attn_backend_map in enumerate(attention_backend_maps):
self.attn_groups.append(create_attn_groups(attn_backend_map, i))
@@ -4203,22 +4205,31 @@ class GPUModelRunner(
self.calculate_reorder_batch_threshold()
def _check_and_update_cudagraph_mode(
self, attention_backends: set[type[AttentionBackend]]
self,
attention_backends: list[set[type[AttentionBackend]]],
kv_cache_groups: list[KVCacheGroupSpec],
) -> None:
"""
Resolve the cudagraph_mode when there are multiple attention
backends with potentially conflicting CUDA graph support.
groups with potentially conflicting CUDA graph support.
Then initialize the cudagraph_dispatcher based on the resolved
cudagraph_mode.
"""
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_backend_name = None
for attn_backend in attention_backends:
builder_cls = attn_backend.get_builder_cls()
if builder_cls.cudagraph_support.value < min_cg_support.value:
min_cg_support = builder_cls.cudagraph_support
min_cg_backend_name = attn_backend.__name__
for attn_backend_set, kv_cache_group in zip(
attention_backends, kv_cache_groups
):
for attn_backend in attn_backend_set:
builder_cls = attn_backend.get_builder_cls()
cg_support = builder_cls.get_cudagraph_support(
self.vllm_config, kv_cache_group.kv_cache_spec
)
if cg_support.value < min_cg_support.value:
min_cg_support = cg_support
min_cg_backend_name = attn_backend.__name__
# Flexibly resolve the cudagraph mode
cudagraph_mode = self.compilation_config.cudagraph_mode
# check cudagraph for mixed batch is supported
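
A worked toy example of the per-group resolution above, where the minimum support level across attention groups caps the whole model (enum values and the two groups are illustrative):

```python
import enum

class CGSupport(enum.Enum):
    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    UNIFORM_BATCH = 2
    ALWAYS = 3

# e.g. group 0: FlashInfer with TRTLLM decode; group 1: a Mamba attention layer
group_support = [CGSupport.UNIFORM_BATCH, CGSupport.UNIFORM_SINGLE_TOKEN_DECODE]

min_cg_support = min(group_support, key=lambda s: s.value)
print(min_cg_support)  # CGSupport.UNIFORM_SINGLE_TOKEN_DECODE
```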