[Perf] Warmup FlashInfer attention during startup (#23439)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni001@gmail.com>
Michael Goin 2025-09-10 18:03:17 -04:00 committed by GitHub
parent b5e383cd8b
commit fba7856581
3 changed files with 55 additions and 20 deletions

View File

@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
@@ -19,6 +20,8 @@ if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.gpu_worker import Worker
logger = init_logger(__name__)
def kernel_warmup(worker: "Worker"):
# Deep GEMM warmup
@@ -30,10 +33,33 @@ def kernel_warmup(worker: "Worker"):
max_tokens = worker.scheduler_config.max_num_batched_tokens
deep_gemm_warmup(model, max_tokens)
# FlashInfer autotune for Blackwell (SM 10.0) GPUs
# FlashInfer kernel autotune for Blackwell (SM 10.0) GPUs
if has_flashinfer() and current_platform.is_device_capability(100):
flashinfer_autotune(worker.model_runner)
# FlashInfer attention warmup
# Only warm up if every attention group uses the FlashInfer backend
# and the model is not a pooling model
def _is_flashinfer_backend(backend):
try:
return backend.get_name() == "FLASHINFER_VLLM_V1"
except NotImplementedError:
return False
if not worker.model_runner.is_pooling_model and all(
_is_flashinfer_backend(group.backend)
for groups in worker.model_runner.attn_groups for group in groups):
logger.info("Warming up FlashInfer attention.")
# Warm up with a mixed batch containing both prefill and decode tokens
# so that both the prefill and decode attention kernels are exercised
worker.model_runner._dummy_run(
num_tokens=16,
skip_eplb=True,
is_profile=True,
force_attention=True,
create_mixed_batch=True,
)
def flashinfer_autotune(runner: "GPUModelRunner") -> None:
"""

View File

@@ -549,22 +549,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
return attn_metadata
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata):
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with FlashInfer.
"""
m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, \
"FlashInfer only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
m.max_query_len = 1 # decode-only
return self.build(0, m)
def use_cascade_attention(self, *args, **kwargs) -> bool:
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting
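
An aside on the removed build_for_cudagraph_capture: its assert encoded the decode-only invariant, i.e. every request in a pure-decode batch schedules exactly one token, so the request count must equal the actual token count. A tiny illustration with made-up numbers:

# Illustrative decode-only capture: every request schedules exactly 1 token.
query_lens = [1, 1, 1, 1]
num_reqs = len(query_lens)              # 4
num_actual_tokens = sum(query_lens)     # 4 -> invariant holds
assert num_reqs == num_actual_tokens
max_query_len = 1                       # decode-only, as the removed method set

# A batch containing a prefill (query_len > 1) breaks the invariant and would
# have tripped the assert in the removed method.
query_lens = [1, 1, 8]
assert len(query_lens) != sum(query_lens)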

View File

@@ -2578,6 +2578,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
uniform_decode: bool = False,
skip_eplb: bool = False,
is_profile: bool = False,
create_mixed_batch: bool = False,
remove_lora: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -2596,6 +2597,8 @@
uniform_decode: If True, the batch is a uniform decode batch.
skip_eplb: If True, skip EPLB state update.
is_profile: If True, this is a profile run.
create_mixed_batch: If True, create a mixed batch with both decode
(1 token) and prefill (multiple tokens) requests.
remove_lora: If False, dummy LoRAs are not destroyed after the run
"""
assert cudagraph_runtime_mode in {
@@ -2627,7 +2630,22 @@
# has num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
max_num_reqs = self.scheduler_config.max_num_seqs
if uniform_decode:
if create_mixed_batch:
assert not uniform_decode
# Create a mixed batch: the first half of the tokens become decode
# requests (1 token each), the second half a single prefill request
num_decode_tokens = num_tokens // 2
num_prefill_tokens = num_tokens - num_decode_tokens
num_reqs = num_decode_tokens + 1
# Create decode requests (1 token each) followed by prefill request
num_scheduled_tokens_list = [1] * num_decode_tokens + [
num_prefill_tokens
]
# Note: override max_query_len to the prefill request's token count
max_query_len = num_prefill_tokens
elif uniform_decode:
assert not create_mixed_batch
num_reqs = cdiv(num_tokens, max_query_len)
assert num_reqs <= max_num_reqs, \
"Do not capture num_reqs > max_num_reqs for uniform batch"
@@ -2652,8 +2670,15 @@
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
attn_metadata = {}
# Make sure max_model_len is used at the graph capture time.
self.seq_lens.np[:num_reqs] = self.max_model_len
if create_mixed_batch:
# In the mixed-batch mode (used for FlashInfer warmup), use
# shorter sequence lengths so the dummy run finishes faster.
# TODO(luka) better system for describing dummy batches
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
else:
# Make sure max_model_len is used at the graph capture time.
seq_lens = self.max_model_len
self.seq_lens.np[:num_reqs] = seq_lens
self.seq_lens.np[num_reqs:] = 0
self.seq_lens.copy_to_gpu()
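
Continuing the num_tokens=16 example, the dummy seq_lens written for the mixed batch stay deliberately short rather than using max_model_len: each decode request sees a 1-token context and the prefill request sees num_prefill_tokens + 1.

num_decode_tokens, num_prefill_tokens, num_reqs = 8, 8, 9
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
assert seq_lens == [1, 1, 1, 1, 1, 1, 1, 1, 9]
# These 9 entries land in self.seq_lens.np[:num_reqs]; entries past num_reqs
# are zeroed before the buffer is copied to the GPU.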