mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-04 12:37:07 +08:00
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
116 lines
4.1 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""
|
|
Warmup kernels used during model execution.
|
|
This is useful specifically for JIT'ed kernels as we don't want JIT'ing to
|
|
happen during model execution.
|
|
"""
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
import torch
|
|
|
|
import vllm.envs as envs
|
|
from vllm.config import CUDAGraphMode, VllmConfig
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils.deep_gemm import is_deep_gemm_supported
|
|
from vllm.utils.flashinfer import has_flashinfer
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
|
from vllm.v1.worker.gpu_worker import Worker
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
    """
    Record known issues with vllm + flashinfer autotune here. Return True if
    and only if flashinfer autotune will run through without issues.
    """
    parallel = vllm_config.parallel_config
    multi_rank = parallel.data_parallel_size > 1 or parallel.tensor_parallel_size > 1

    # Explicit env-var opt-ins for the flashinfer mxfp4 MoE backends.
    mxfp4_env_flags = (
        envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8,
        envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16,
        envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS,
    )
    # On >=sm100 CUDA devices the default mxfp4 backend is flashinfer,
    # even without any env flag set.
    sm100_default = current_platform.is_cuda() and current_platform.is_device_capability(
        100
    )
    uses_fi_mxfp4 = any(mxfp4_env_flags) or sm100_default

    eager_mode = vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE

    # Known-bad combination: TP/DP + flashinfer mxfp4 backend + eager mode.
    return not (multi_rank and uses_fi_mxfp4 and eager_mode)
|
|
|
|
|
|
def kernel_warmup(worker: "Worker"):
    """Warm up JIT'ed kernels before model execution begins.

    Runs, in order: DeepGEMM warmup, FlashInfer autotuning, and a
    FlashInfer attention dummy pass. Each phase is gated on its own
    availability checks, so unsupported configurations are skipped.
    """
    # --- DeepGEMM warmup ---
    if (
        envs.VLLM_USE_DEEP_GEMM
        and is_deep_gemm_supported()
        and envs.VLLM_DEEP_GEMM_WARMUP != "skip"
    ):
        deep_gemm_warmup(
            worker.get_model(),
            worker.scheduler_config.max_num_batched_tokens,
        )

    # --- FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) ---
    autotune_ok = (
        has_flashinfer()
        and current_platform.has_device_capability(90)
        and flashinfer_autotune_supported(worker.vllm_config)
    )
    if autotune_ok:
        flashinfer_autotune(worker.model_runner)

    # --- FlashInfer attention warmup ---
    # Only warm up when the model is not a pooling model and every
    # attention group reports the FLASHINFER backend.
    def _backend_is_flashinfer(backend) -> bool:
        try:
            return backend.get_name() == "FLASHINFER"
        except NotImplementedError:
            return False

    runner = worker.model_runner
    if not runner.is_pooling_model and all(
        _backend_is_flashinfer(group.backend)
        for group_list in runner.attn_groups
        for group in group_list
    ):
        logger.info("Warming up FlashInfer attention.")
        # A mixed batch (prefill + decode tokens) warms up both the
        # prefill and decode attention kernels in one pass.
        runner._dummy_run(
            num_tokens=16,
            skip_eplb=True,
            is_profile=True,
            force_attention=True,
            create_mixed_batch=True,
        )
|
|
|
|
|
|
def flashinfer_autotune(runner: "GPUModelRunner") -> None:
    """
    Autotune FlashInfer operations.

    FlashInfer ships multiple implementations of the same operation;
    autotuning benchmarks each one and transparently caches the results,
    so later FlashInfer calls pick the best implementation. Without it,
    FlashInfer falls back to heuristics, which may be significantly
    slower.
    """
    from vllm.utils.flashinfer import autotune

    # Autotuning at m tokens also covers every token count below m, so a
    # single pass at the max batched-token count is sufficient. EPLB is
    # skipped so the dummy run records no spurious metrics.
    with torch.inference_mode(), autotune():
        runner._dummy_run(
            runner.scheduler_config.max_num_batched_tokens,
            skip_eplb=True,
            is_profile=True,
        )
|