vllm/vllm/model_executor/warmup/kernel_warmup.py
Varun Sundar Rabindranath e5e076cad7
[BugFix] Stopgap - Flashinfer Autotuner + GPT-OSS + DP/TP (#27762)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
2025-10-30 08:24:31 -07:00

116 lines
4.1 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Warmup kernels used during model execution.
This is useful specifically for JIT'ed kernels as we don't want JIT'ing to
happen during model execution.
"""
from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.gpu_worker import Worker
# Module-level logger, named after this module via init_logger(__name__).
logger = init_logger(__name__)
def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
    """
    Tell whether FlashInfer autotune is expected to run cleanly.

    Known-bad combinations of vllm + flashinfer autotune are recorded here.
    The only combination currently recorded is all three of the following
    holding at once:

    * data-parallel or tensor-parallel size greater than 1,
    * a FlashInfer MXFP4 MoE backend being in use, and
    * cudagraph mode being ``CUDAGraphMode.NONE`` (eager execution).

    Returns:
        False for the combination above, True otherwise.
    """
    parallel = vllm_config.parallel_config
    multi_rank = (
        parallel.data_parallel_size > 1 or parallel.tensor_parallel_size > 1
    )

    # Any of these environment switches selects a FlashInfer MXFP4 backend.
    env_selects_fi_mxfp4 = (
        envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
    )
    # On >=sm100 the default mxfp4 backend is flashinfer.
    # NOTE(review): the check below is is_device_capability(100); confirm
    # it covers every ">=sm100" device this comment refers to.
    uses_fi_mxfp4 = env_selects_fi_mxfp4 or (
        current_platform.is_cuda() and current_platform.is_device_capability(100)
    )

    eager_mode = (
        vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
    )

    # Supported unless every one of the three conditions holds.
    return (not multi_rank) or (not uses_fi_mxfp4) or (not eager_mode)


def kernel_warmup(worker: "Worker"):
    """
    Run the kernel warmups for ``worker`` ahead of real model execution.

    Three independent stages run in order, each only when its gate passes:

    1. Deep GEMM warmup, when ``VLLM_USE_DEEP_GEMM`` is set, Deep GEMM is
       supported and ``VLLM_DEEP_GEMM_WARMUP`` is not ``"skip"``.
    2. FlashInfer autotune, when FlashInfer is available, the platform
       passes ``has_device_capability(90)`` and
       ``flashinfer_autotune_supported`` accepts the worker's config.
    3. FlashInfer attention warmup, when the model is not a pooling model
       and has at least one attention group, all of which use the
       FlashInfer backend.

    Args:
        worker: The worker whose model and model runner are warmed up.
    """
    # Deep GEMM warmup
    do_deep_gemm_warmup = (
        envs.VLLM_USE_DEEP_GEMM
        and is_deep_gemm_supported()
        and envs.VLLM_DEEP_GEMM_WARMUP != "skip"
    )
    if do_deep_gemm_warmup:
        model = worker.get_model()
        max_tokens = worker.scheduler_config.max_num_batched_tokens
        deep_gemm_warmup(model, max_tokens)

    # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
    if (
        has_flashinfer()
        and current_platform.has_device_capability(90)
        and flashinfer_autotune_supported(worker.vllm_config)
    ):
        flashinfer_autotune(worker.model_runner)

    # FlashInfer attention warmup
    # Only warmup if the model has FlashInfer attention groups
    # and is not a pooling model
    if _should_warmup_flashinfer_attention(worker.model_runner):
        logger.info("Warming up FlashInfer attention.")
        # Warmup with mixed batch containing both prefill and decode tokens
        # This is to warm up both prefill and decode attention kernels
        worker.model_runner._dummy_run(
            num_tokens=16,
            skip_eplb=True,
            is_profile=True,
            force_attention=True,
            create_mixed_batch=True,
        )


def _should_warmup_flashinfer_attention(runner: "GPUModelRunner") -> bool:
    """
    Decide whether the FlashInfer attention warmup should run for ``runner``.

    Returns True only when the model is not a pooling model, it has at least
    one attention group, and every attention group uses the FlashInfer
    backend.
    """

    def _is_flashinfer_backend(backend) -> bool:
        # A backend whose get_name() raises NotImplementedError is treated
        # as a non-FlashInfer backend.
        try:
            return backend.get_name() == "FLASHINFER"
        except NotImplementedError:
            return False

    if runner.is_pooling_model:
        return False

    backends = [group.backend for groups in runner.attn_groups for group in groups]
    # all() over an empty sequence is True, so require at least one
    # backend explicitly; otherwise a model with no attention groups would
    # trigger an attention warmup it has no kernels for.
    return bool(backends) and all(
        _is_flashinfer_backend(backend) for backend in backends
    )


def flashinfer_autotune(runner: "GPUModelRunner") -> None:
    """
    Autotune FlashInfer operations with a single dummy run.

    FlashInfer has many implementations for the same operation.
    Autotuning benchmarks each implementation and stores the results;
    the results are cached transparently, and future calls to FlashInfer
    will use the best implementation. Without autotuning, FlashInfer
    relies on heuristics, which may be significantly slower.

    Args:
        runner: The model runner whose ``_dummy_run`` drives the autotune
            pass.
    """
    from vllm.utils.flashinfer import autotune

    # When autotuning with number of tokens m, flashinfer will autotune
    # operations for all number of tokens up to m, so a single run at the
    # max number of tokens is enough.
    token_budget = runner.scheduler_config.max_num_batched_tokens

    with torch.inference_mode():
        with autotune():
            # EPLB is skipped since we don't want to record dummy metrics.
            runner._dummy_run(
                token_budget,
                skip_eplb=True,
                is_profile=True,
            )