[Perf] Refactor cudagraph_support to enable full CUDA graphs for spec decoding with FlashInfer (#28479)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Parent: a742134cc5
Commit: 304419576a
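In code terms, the refactor replaces the public `cudagraph_support` class attribute on every attention metadata builder with a private `_cudagraph_support` default plus a `get_cudagraph_support()` classmethod, so a builder can report a support level that depends on the model and KV-cache configuration; FlashInfer uses the hook to report `UNIFORM_BATCH` when TRTLLM attention is usable. A minimal sketch of the shape of the change, using a stand-in enum rather than the real vLLM definitions:

# Sketch only: stand-in enum and simplified signatures, not the vLLM source.
import enum
from typing import ClassVar


class AttentionCGSupport(enum.Enum):
    # Values chosen only to preserve the ordering the model runner compares on.
    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    UNIFORM_BATCH = 2
    ALWAYS = 3


# Before: the support level was a static, public class attribute.
class OldStyleBuilder:
    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER


# After: the default moves to _cudagraph_support and callers go through the
# classmethod, which a backend may override to inspect the config.
class NewStyleBuilder:
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER

    @classmethod
    def get_cudagraph_support(cls, vllm_config, kv_cache_spec) -> AttentionCGSupport:
        return cls._cudagraph_support

The hunks below apply exactly this rename-plus-hook pattern to each backend builder and teach the GPU model runner to query the classmethod per KV-cache group.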
@@ -177,8 +177,9 @@ The following table lists backends that support full CUDA Graphs at the time of
 | FlashAttention v3 | `ALWAYS` | has unified routine for both batches, so `FULL` mode is good |
 | Triton Attention | `ALWAYS` | prefer `FULL_AND_PIECEWISE` since it has different kernels for prefill/mixed and pure decode batches |
 | AITER FlashAttention | `UNIFORM_BATCH`| |
-| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | |
+| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | Will be set to `UNIFORM_BATCH` when using TRTLLM attention on Blackwell |
 | FlashMLA | `UNIFORM_BATCH` | |
 | FlashInferMLA | `UNIFORM_BATCH` | |
 | AITER MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
 | CUTLASS MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
 | Mamba attention| `UNIFORM_SINGLE_TOKEN_DECODE` | |
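For reference, which full-graph mode actually gets used also depends on the `cudagraph_mode` compilation setting; a hypothetical offline-inference snippet (the model name is a placeholder and the `CompilationConfig`/`CUDAGraphMode` import path is assumed from current vLLM, not something this commit adds):

# Hypothetical usage sketch; model name and exact import locations are assumptions.
from vllm import LLM
from vllm.config import CompilationConfig, CUDAGraphMode

llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
    compilation_config=CompilationConfig(
        # FULL_AND_PIECEWISE captures full graphs for supported batch shapes
        # and falls back to piecewise graphs elsewhere.
        cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE,
    ),
)
print(llm.generate("Hello")[0].outputs[0].text)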
@@ -32,7 +32,7 @@ def create_chunked_local_attention_backend(
     underlying_builder = underlying_attn_backend.get_builder_cls()

     class ChunkedLocalAttentionBuilder(underlying_builder):  # type: ignore
-        cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
+        _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER

         def build(
             self,
@@ -207,7 +207,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
     # to FULL_AND_PIECEWISE.
     # TODO(luka, lucas): audit FA2 as part of:
     # https://github.com/vllm-project/vllm/issues/22945
-    cudagraph_support = (
+    _cudagraph_support = (
         AttentionCGSupport.ALWAYS
         if get_flash_attn_version() == 3
         else AttentionCGSupport.UNIFORM_BATCH
@@ -15,6 +15,7 @@ from flashinfer import (
 from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from flashinfer.utils import FP4Tensor
+from typing_extensions import override

 from vllm import envs
 from vllm.attention.backends.abstract import (
@@ -274,10 +275,6 @@ class FlashInferMetadata:


 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = (
-        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
-    )
-
     reorder_batch_threshold: int = 1

     def __init__(
@@ -355,6 +352,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         else:
             self.q_data_type = self.model_config.dtype

+        # Prefer TRTLLM attention for decoding in all cases.
+        # This allows us to use AttentionCGSupport.UNIFORM_BATCH mode.
+        self.use_trtllm_decode_attention = can_use_trtllm
         self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)

         self._cascade_wrapper = None  # Wrapper for cascade attention
@@ -412,6 +412,24 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 "passing --block-size 32 or --block-size 64."
             )

+    @classmethod
+    @override
+    def get_cudagraph_support(
+        cls: type["FlashInferMetadataBuilder"],
+        vllm_config: VllmConfig,
+        kv_cache_spec: AttentionSpec,
+    ) -> AttentionCGSupport:
+        has_trtllm_support = can_use_trtllm_attention(
+            num_qo_heads=vllm_config.model_config.get_num_attention_heads(
+                vllm_config.parallel_config
+            ),
+            num_kv_heads=kv_cache_spec.num_kv_heads,
+        )
+        if has_trtllm_support:
+            return AttentionCGSupport.UNIFORM_BATCH
+        else:
+            return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
             buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
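The distinction matters for speculative decoding: `UNIFORM_SINGLE_TOKEN_DECODE` only lets full CUDA graphs cover batches where every request contributes exactly one query token, while `UNIFORM_BATCH` also covers the uniform multi-token batches that spec-decode verification produces. A standalone illustration (the helper below is hypothetical, not vLLM code):

# Hypothetical illustration of which batch shapes each support level can capture.
def can_capture_full_graph(support: str, query_lens: list[int]) -> bool:
    uniform = len(set(query_lens)) == 1
    if support == "UNIFORM_SINGLE_TOKEN_DECODE":
        return uniform and query_lens[0] == 1
    if support == "UNIFORM_BATCH":
        return uniform  # e.g. 1 + num_speculative_tokens per request
    return support == "ALWAYS"


# Plain decode: one query token per request.
assert can_capture_full_graph("UNIFORM_SINGLE_TOKEN_DECODE", [1, 1, 1])
# Spec-decode verification with 3 draft tokens: 4 query tokens per request.
assert not can_capture_full_graph("UNIFORM_SINGLE_TOKEN_DECODE", [4, 4, 4])
assert can_capture_full_graph("UNIFORM_BATCH", [4, 4, 4])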
@@ -573,17 +591,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             has_sinks=self.has_sinks,
             has_spec=uses_spec_reorder,
         )
-        decode_use_trtllm = use_trtllm_attention(
-            self.num_qo_heads,
-            self.num_kv_heads,
-            num_decode_tokens,
-            max_seq_len,
-            self.cache_dtype,
-            self.q_data_type,
-            is_prefill=False,
-            has_sinks=self.has_sinks,
-            has_spec=uses_spec_reorder,
-        )
+        decode_use_trtllm = self.use_trtllm_decode_attention

         if not (prefill_use_trtllm and decode_use_trtllm):
             if self.has_sinks:
@@ -59,7 +59,7 @@ class GDNAttentionMetadata:


 class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
-    cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
+    _cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

     reorder_batch_threshold: int = 1

@@ -20,7 +20,7 @@ M = TypeVar("M")

 class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
     reorder_batch_threshold: int = 1
-    cudagraph_support: ClassVar[AttentionCGSupport] = (
+    _cudagraph_support: ClassVar[AttentionCGSupport] = (
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     )

@@ -29,7 +29,7 @@ logger = init_logger(__name__)

 class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
     # enable full CUDA Graph support for decode-only capture
-    cudagraph_support: ClassVar[AttentionCGSupport] = (
+    _cudagraph_support: ClassVar[AttentionCGSupport] = (
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     )

@@ -92,7 +92,7 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):


 class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
     query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
     reorder_batch_threshold: int = 512  # process small prefills with decode pathway

@@ -29,7 +29,7 @@ FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024


 class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
     query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM

@@ -96,7 +96,7 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):


 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
     query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
     reorder_batch_threshold: int = 128  # process small prefills with decode pathway
     # ^ TODO(matt): tune this
@@ -248,7 +248,7 @@ def triton_convert_req_index_to_global_index(

 @dataclass
 class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

     def __init__(
         self,
@@ -206,7 +206,7 @@ def split_prefill_chunks(


 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
-    cudagraph_support: ClassVar[AttentionCGSupport] = (
+    _cudagraph_support: ClassVar[AttentionCGSupport] = (
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     )

@@ -55,7 +55,7 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
     # TODO(luka, lucas): audit this as part of:
     # https://github.com/vllm-project/vllm/issues/22945
-    cudagraph_support: ClassVar[AttentionCGSupport] = (
+    _cudagraph_support: ClassVar[AttentionCGSupport] = (
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     )

@@ -251,7 +251,7 @@ class AiterFlashAttentionMetadata:
 class AiterFlashAttentionMetadataBuilder(
     AttentionMetadataBuilder[AiterFlashAttentionMetadata]
 ):
-    cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+    _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     reorder_batch_threshold: int = 1

     def __init__(
@@ -63,7 +63,7 @@ class RocmAttentionMetadata:


 class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS

     def __init__(
         self,
@@ -67,7 +67,7 @@ class TritonAttentionMetadata:


 class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS

     def __init__(
         self,
@@ -244,7 +244,8 @@ class AttentionCGSupport(enum.Enum):

 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention (default: no).
-    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
+    # Do not access directly. Call get_cudagraph_support() instead.
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
     # Does this backend/builder reorder the batch?
     # If not, set this to None. Otherwise set it to the query
     # length that will be pulled into the front of the batch.
@@ -263,6 +264,15 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
         self.vllm_config = vllm_config
         self.device = device

+    @classmethod
+    def get_cudagraph_support(
+        cls: type["AttentionMetadataBuilder"],
+        vllm_config: VllmConfig,
+        kv_cache_spec: AttentionSpec,
+    ) -> AttentionCGSupport:
+        """Get the cudagraph support level of this builder class."""
+        return cls._cudagraph_support
+
     def _init_reorder_batch_threshold(
         self,
         reorder_batch_threshold: int | None = 1,
@@ -4167,14 +4167,16 @@ class GPUModelRunner(
             return attn_groups

         attention_backend_maps = []
-        attention_backend_set: set[type[AttentionBackend]] = set()
+        attention_backend_list = []
         for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
             attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
             attention_backend_maps.append(attn_backends[0])
-            attention_backend_set.update(attn_backends[1])
+            attention_backend_list.append(attn_backends[1])

         # Resolve cudagraph_mode before actually initialize metadata_builders
-        self._check_and_update_cudagraph_mode(attention_backend_set)
+        self._check_and_update_cudagraph_mode(
+            attention_backend_list, kv_cache_config.kv_cache_groups
+        )

         for i, attn_backend_map in enumerate(attention_backend_maps):
             self.attn_groups.append(create_attn_groups(attn_backend_map, i))
@@ -4203,22 +4205,31 @@ class GPUModelRunner(
         self.calculate_reorder_batch_threshold()

     def _check_and_update_cudagraph_mode(
-        self, attention_backends: set[type[AttentionBackend]]
+        self,
+        attention_backends: list[set[type[AttentionBackend]]],
+        kv_cache_groups: list[KVCacheGroupSpec],
     ) -> None:
         """
         Resolve the cudagraph_mode when there are multiple attention
-        backends with potential conflicting CUDA graph support.
+        groups with potential conflicting CUDA graph support.
         Then initialize the cudagraph_dispatcher based on the resolved
         cudagraph_mode.
         """
         min_cg_support = AttentionCGSupport.ALWAYS
         min_cg_backend_name = None

-        for attn_backend in attention_backends:
-            builder_cls = attn_backend.get_builder_cls()
-            if builder_cls.cudagraph_support.value < min_cg_support.value:
-                min_cg_support = builder_cls.cudagraph_support
-                min_cg_backend_name = attn_backend.__name__
+        for attn_backend_set, kv_cache_group in zip(
+            attention_backends, kv_cache_groups
+        ):
+            for attn_backend in attn_backend_set:
+                builder_cls = attn_backend.get_builder_cls()
+
+                cg_support = builder_cls.get_cudagraph_support(
+                    self.vllm_config, kv_cache_group.kv_cache_spec
+                )
+                if cg_support.value < min_cg_support.value:
+                    min_cg_support = cg_support
+                    min_cg_backend_name = attn_backend.__name__
         # Flexible resolve the cudagraph mode
         cudagraph_mode = self.compilation_config.cudagraph_mode
         # check cudagraph for mixed batch is supported
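The resolution above reduces to taking the minimum support level over every (backend, KV-cache group) pair, now obtained from each builder class's `get_cudagraph_support()` rather than from a static attribute. A self-contained sketch of that reduction with stand-in types (not the GPUModelRunner code itself):

# Standalone sketch; every name here is a stand-in, not vLLM's.
import enum


class AttentionCGSupport(enum.Enum):
    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    UNIFORM_BATCH = 2
    ALWAYS = 3


def resolve_min_support(vllm_config, groups):
    """groups: list of (builder_classes, kv_cache_spec), one entry per KV-cache group."""
    min_support, min_name = AttentionCGSupport.ALWAYS, None
    for builder_classes, kv_cache_spec in groups:
        for builder_cls in builder_classes:
            support = builder_cls.get_cudagraph_support(vllm_config, kv_cache_spec)
            if support.value < min_support.value:
                min_support, min_name = support, builder_cls.__name__
    return min_support, min_name


class _AlwaysBuilder:
    @classmethod
    def get_cudagraph_support(cls, vllm_config, kv_cache_spec):
        return AttentionCGSupport.ALWAYS


class _DecodeOnlyBuilder:
    @classmethod
    def get_cudagraph_support(cls, vllm_config, kv_cache_spec):
        return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE


support, name = resolve_min_support(
    vllm_config=None,  # the real runner passes self.vllm_config
    groups=[({_AlwaysBuilder}, "spec_a"), ({_DecodeOnlyBuilder}, "spec_b")],
)
assert support is AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
assert name == "_DecodeOnlyBuilder"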