[Perf] Warmup FlashInfer attention during startup (#23439)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni001@gmail.com>
commit fba7856581
parent b5e383cd8b
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING
 import torch
 
 import vllm.envs as envs
+from vllm.logger import init_logger
 from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
 from vllm.platforms import current_platform
 from vllm.utils.deep_gemm import is_deep_gemm_supported
@@ -19,6 +20,8 @@ if TYPE_CHECKING:
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner
     from vllm.v1.worker.gpu_worker import Worker
 
+logger = init_logger(__name__)
+
 
 def kernel_warmup(worker: "Worker"):
     # Deep GEMM warmup
@@ -30,10 +33,33 @@ def kernel_warmup(worker: "Worker"):
         max_tokens = worker.scheduler_config.max_num_batched_tokens
         deep_gemm_warmup(model, max_tokens)
 
-    # FlashInfer autotune for Blackwell (SM 10.0) GPUs
+    # FlashInfer kernel autotune for Blackwell (SM 10.0) GPUs
     if has_flashinfer() and current_platform.is_device_capability(100):
         flashinfer_autotune(worker.model_runner)
 
+    # FlashInfer attention warmup
+    # Only warmup if the model has FlashInfer attention groups
+    # and is not a pooling model
+    def _is_flashinfer_backend(backend):
+        try:
+            return backend.get_name() == "FLASHINFER_VLLM_V1"
+        except NotImplementedError:
+            return False
+
+    if not worker.model_runner.is_pooling_model and all(
+            _is_flashinfer_backend(group.backend)
+            for groups in worker.model_runner.attn_groups for group in groups):
+        logger.info("Warming up FlashInfer attention.")
+        # Warmup with mixed batch containing both prefill and decode tokens
+        # This is to warm up both prefill and decode attention kernels
+        worker.model_runner._dummy_run(
+            num_tokens=16,
+            skip_eplb=True,
+            is_profile=True,
+            force_attention=True,
+            create_mixed_batch=True,
+        )
+
 
 def flashinfer_autotune(runner: "GPUModelRunner") -> None:
     """
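For readers skimming the diff, here is a minimal, self-contained sketch of the gating pattern introduced above. `_StubBackend` and the flat group layout are invented for illustration (they are not vLLM classes); the point is that the warmup only runs when every attention group reports the FlashInfer backend name, and a backend that does not implement get_name() simply opts out.

# Illustrative sketch only: _StubBackend and the group layout are made up to
# mimic the gate above; a "group" here is just the backend object itself.
class _StubBackend:
    def __init__(self, name=None):
        self._name = name

    def get_name(self):
        # Mirrors backends that do not implement get_name().
        if self._name is None:
            raise NotImplementedError
        return self._name


def _is_flashinfer_backend(backend):
    # Same try/except pattern as the warmup gate above.
    try:
        return backend.get_name() == "FLASHINFER_VLLM_V1"
    except NotImplementedError:
        return False


# attn_groups is a list of lists of groups.
attn_groups = [[_StubBackend("FLASHINFER_VLLM_V1")],
               [_StubBackend("FLASHINFER_VLLM_V1")]]
assert all(_is_flashinfer_backend(g) for groups in attn_groups for g in groups)

# A single non-FlashInfer (or name-less) backend disables the warmup.
attn_groups.append([_StubBackend(None)])
assert not all(
    _is_flashinfer_backend(g) for groups in attn_groups for g in groups)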
@@ -549,22 +549,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )
         return attn_metadata
 
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata):
-        """
-        This method builds the metadata for full cudagraph capture.
-        Currently, only decode is supported for full cudagraphs with FlashInfer.
-        """
-        m = common_attn_metadata
-
-        assert m.num_reqs == m.num_actual_tokens, \
-            "FlashInfer only supports decode-only full CUDAGraph capture. " \
-            "Make sure all cudagraph capture sizes <= max_num_seq."
-
-        m.max_query_len = 1  # decode-only
-
-        return self.build(0, m)
-
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
             # TODO: The cascade wrapper currently does not support setting
@@ -2578,6 +2578,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         uniform_decode: bool = False,
         skip_eplb: bool = False,
         is_profile: bool = False,
+        create_mixed_batch: bool = False,
         remove_lora: bool = True,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
@@ -2596,6 +2597,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             uniform_decode: If True, the batch is a uniform decode batch.
             skip_eplb: If True, skip EPLB state update.
             is_profile: If True, this is a profile run.
+            create_mixed_batch: If True, create a mixed batch with both decode
+                (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
         assert cudagraph_runtime_mode in {
@@ -2627,7 +2630,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # has num_tokens in total.
         assert num_tokens <= self.scheduler_config.max_num_batched_tokens
         max_num_reqs = self.scheduler_config.max_num_seqs
-        if uniform_decode:
+        if create_mixed_batch:
+            assert not uniform_decode
+            # Create mixed batch:
+            # first half decode tokens, second half one prefill
+            num_decode_tokens = num_tokens // 2
+            num_prefill_tokens = num_tokens - num_decode_tokens
+            num_reqs = num_decode_tokens + 1
+
+            # Create decode requests (1 token each) followed by prefill request
+            num_scheduled_tokens_list = [1] * num_decode_tokens + [
+                num_prefill_tokens
+            ]
+            # Note: Overriding max_query_len to be the prefill tokens
+            max_query_len = num_prefill_tokens
+        elif uniform_decode:
+            assert not create_mixed_batch
             num_reqs = cdiv(num_tokens, max_query_len)
             assert num_reqs <= max_num_reqs, \
                 "Do not capture num_reqs > max_num_reqs for uniform batch"
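The mixed-batch split, worked through for the num_tokens=16 that the warmup call above passes (plain Python mirroring the arithmetic in the new branch):

# Mixed-batch split for the warmup's num_tokens=16 (mirrors the branch above).
num_tokens = 16
num_decode_tokens = num_tokens // 2                  # 8 decode requests, 1 token each
num_prefill_tokens = num_tokens - num_decode_tokens  # 8 tokens in one prefill
num_reqs = num_decode_tokens + 1                     # 9 requests in total

num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens]
assert num_scheduled_tokens_list == [1, 1, 1, 1, 1, 1, 1, 1, 8]
assert sum(num_scheduled_tokens_list) == num_tokens
max_query_len = num_prefill_tokens                   # 8, the longest request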
@@ -2652,8 +2670,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
             attn_metadata = {}
 
-            # Make sure max_model_len is used at the graph capture time.
-            self.seq_lens.np[:num_reqs] = self.max_model_len
+            if create_mixed_batch:
+                # In the mixed batch mode (used for FI warmup), we use
+                # shorter sequence lengths to run faster.
+                # TODO(luka) better system for describing dummy batches
+                seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
+            else:
+                # Make sure max_model_len is used at the graph capture time.
+                seq_lens = self.max_model_len
+            self.seq_lens.np[:num_reqs] = seq_lens
             self.seq_lens.np[num_reqs:] = 0
             self.seq_lens.copy_to_gpu()
 
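Continuing the same num_tokens=16 example, a sketch of the sequence lengths the mixed-batch branch writes. A plain NumPy array stands in for the runner's pinned seq_lens buffer, and max_num_reqs=16 is an arbitrary value chosen only for the sketch:

import numpy as np

# Stand-in for self.seq_lens.np; the real buffer is sized to max_num_reqs.
max_num_reqs = 16
seq_lens_np = np.zeros(max_num_reqs, dtype=np.int32)

num_decode_tokens, num_prefill_tokens = 8, 8
num_reqs = num_decode_tokens + 1

# Mixed-batch branch: 1-token decodes, and a prefill whose seq_len is
# num_prefill_tokens + 1 = 9, per the new code above.
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
seq_lens_np[:num_reqs] = seq_lens
seq_lens_np[num_reqs:] = 0
print(seq_lens_np[:num_reqs])  # [1 1 1 1 1 1 1 1 9]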