diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py
index 761172e4d3616..e42e34ebc77b9 100644
--- a/vllm/model_executor/warmup/kernel_warmup.py
+++ b/vllm/model_executor/warmup/kernel_warmup.py
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING
 import torch
 
 import vllm.envs as envs
+from vllm.logger import init_logger
 from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
 from vllm.platforms import current_platform
 from vllm.utils.deep_gemm import is_deep_gemm_supported
@@ -19,6 +20,8 @@ if TYPE_CHECKING:
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner
     from vllm.v1.worker.gpu_worker import Worker
 
+logger = init_logger(__name__)
+
 
 def kernel_warmup(worker: "Worker"):
     # Deep GEMM warmup
@@ -30,10 +33,33 @@ def kernel_warmup(worker: "Worker"):
         max_tokens = worker.scheduler_config.max_num_batched_tokens
         deep_gemm_warmup(model, max_tokens)
 
-    # FlashInfer autotune for Blackwell (SM 10.0) GPUs
+    # FlashInfer kernel autotune for Blackwell (SM 10.0) GPUs
     if has_flashinfer() and current_platform.is_device_capability(100):
         flashinfer_autotune(worker.model_runner)
 
+    # FlashInfer attention warmup
+    # Only warmup if the model has FlashInfer attention groups
+    # and is not a pooling model
+    def _is_flashinfer_backend(backend):
+        try:
+            return backend.get_name() == "FLASHINFER_VLLM_V1"
+        except NotImplementedError:
+            return False
+
+    if not worker.model_runner.is_pooling_model and all(
+            _is_flashinfer_backend(group.backend)
+            for groups in worker.model_runner.attn_groups for group in groups):
+        logger.info("Warming up FlashInfer attention.")
+        # Warmup with mixed batch containing both prefill and decode tokens
+        # This is to warm up both prefill and decode attention kernels
+        worker.model_runner._dummy_run(
+            num_tokens=16,
+            skip_eplb=True,
+            is_profile=True,
+            force_attention=True,
+            create_mixed_batch=True,
+        )
+
 
 def flashinfer_autotune(runner: "GPUModelRunner") -> None:
     """
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index afa5a7c14d4d0..9e05cc8ab2f18 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -549,22 +549,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )
         return attn_metadata
 
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata):
-        """
-        This method builds the metadata for full cudagraph capture.
-        Currently, only decode is supported for full cudagraphs with FlashInfer.
-        """
-        m = common_attn_metadata
-
-        assert m.num_reqs == m.num_actual_tokens, \
-            "FlashInfer only supports decode-only full CUDAGraph capture. " \
-            "Make sure all cudagraph capture sizes <= max_num_seq."
-
-        m.max_query_len = 1  # decode-only
-
-        return self.build(0, m)
-
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
             # TODO: The cascade wrapper currently does not support setting
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index d3822251b5b67..b75756fbdae85 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2578,6 +2578,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         uniform_decode: bool = False,
         skip_eplb: bool = False,
         is_profile: bool = False,
+        create_mixed_batch: bool = False,
         remove_lora: bool = True,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
@@ -2596,6 +2597,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             uniform_decode: If True, the batch is a uniform decode batch.
             skip_eplb: If True, skip EPLB state update.
             is_profile: If True, this is a profile run.
+            create_mixed_batch: If True, create a mixed batch with both decode
+                (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
         assert cudagraph_runtime_mode in {
@@ -2627,7 +2630,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # has num_tokens in total.
         assert num_tokens <= self.scheduler_config.max_num_batched_tokens
         max_num_reqs = self.scheduler_config.max_num_seqs
-        if uniform_decode:
+        if create_mixed_batch:
+            assert not uniform_decode
+            # Create mixed batch:
+            # first half decode tokens, second half one prefill
+            num_decode_tokens = num_tokens // 2
+            num_prefill_tokens = num_tokens - num_decode_tokens
+            num_reqs = num_decode_tokens + 1
+
+            # Create decode requests (1 token each) followed by prefill request
+            num_scheduled_tokens_list = [1] * num_decode_tokens + [
+                num_prefill_tokens
+            ]
+            # Note: Overriding max_query_len to be the prefill tokens
+            max_query_len = num_prefill_tokens
+        elif uniform_decode:
+            assert not create_mixed_batch
             num_reqs = cdiv(num_tokens, max_query_len)
             assert num_reqs <= max_num_reqs, \
                 "Do not capture num_reqs > max_num_reqs for uniform batch"
@@ -2652,8 +2670,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
             attn_metadata = {}
 
-            # Make sure max_model_len is used at the graph capture time.
-            self.seq_lens.np[:num_reqs] = self.max_model_len
+            if create_mixed_batch:
+                # In the mixed batch mode (used for FI warmup), we use
+                # shorter sequence lengths to run faster.
+                # TODO(luka) better system for describing dummy batches
+                seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
+            else:
+                # Make sure max_model_len is used at the graph capture time.
+                seq_lens = self.max_model_len
+            self.seq_lens.np[:num_reqs] = seq_lens
             self.seq_lens.np[num_reqs:] = 0
             self.seq_lens.copy_to_gpu()
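
Reviewer note: below is a minimal, standalone sketch (not part of the patch) of the batch layout that create_mixed_batch=True produces in _dummy_run, based only on the hunks above. The helper name mixed_batch_layout is hypothetical and exists purely for illustration.

def mixed_batch_layout(num_tokens: int) -> tuple[list[int], list[int]]:
    """Return (num_scheduled_tokens_list, seq_lens) for a mixed dummy batch."""
    # First half of the tokens become single-token decode requests;
    # the remaining tokens form one multi-token prefill request.
    num_decode_tokens = num_tokens // 2
    num_prefill_tokens = num_tokens - num_decode_tokens
    num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens]
    # Short sequence lengths keep the warmup run fast.
    seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
    return num_scheduled_tokens_list, seq_lens

# With num_tokens=16 (the value the FlashInfer warmup passes to _dummy_run):
# 8 decode requests of 1 token each plus one prefill request of 8 tokens.
assert mixed_batch_layout(16) == ([1] * 8 + [8], [1] * 8 + [9])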