mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 19:14:57 +08:00
[BugFix][Performance] Restore flashinfer autotuning for all scenarios (#27904)
This commit is contained in:
parent
53f6e81dfd
commit
4022a9d279
@ -172,21 +172,9 @@ def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch
|
|||||||
can_initialize("openai/gpt-oss-20b", hf_overrides=HF_OVERRIDE_TEXT)
|
can_initialize("openai/gpt-oss-20b", hf_overrides=HF_OVERRIDE_TEXT)
|
||||||
|
|
||||||
|
|
||||||
def test_gptoss_dp2_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
|
def test_gptoss_eager(monkeypatch: pytest.MonkeyPatch):
|
||||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
|
|
||||||
monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput")
|
|
||||||
can_initialize(
|
can_initialize(
|
||||||
"openai/gpt-oss-20b",
|
"openai/gpt-oss-20b",
|
||||||
extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"],
|
|
||||||
hf_overrides=HF_OVERRIDE_TEXT,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_gptoss_dp2_mxfp4bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
|
|
||||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1")
|
|
||||||
monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput")
|
|
||||||
can_initialize(
|
|
||||||
"openai/gpt-oss-20b",
|
|
||||||
extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"],
|
|
||||||
hf_overrides=HF_OVERRIDE_TEXT,
|
hf_overrides=HF_OVERRIDE_TEXT,
|
||||||
|
extra_args=["--enforce-eager"],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -127,10 +127,17 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
"routing_method_type": 1,
|
"routing_method_type": 1,
|
||||||
"do_finalize": True,
|
"do_finalize": True,
|
||||||
"output": output,
|
"output": output,
|
||||||
"tune_max_num_tokens": self.max_capture_size,
|
"tune_max_num_tokens": max(self.max_capture_size, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
from flashinfer import trtllm_fp4_block_scale_routed_moe
|
from flashinfer import trtllm_fp4_block_scale_routed_moe
|
||||||
|
|
||||||
|
from vllm.utils.flashinfer import autotune
|
||||||
|
|
||||||
|
with autotune(False):
|
||||||
|
# Enable autotune when,
|
||||||
|
# https://github.com/flashinfer-ai/flashinfer/issues/2023 is
|
||||||
|
# resolved.
|
||||||
trtllm_fp4_block_scale_routed_moe(**kwargs)
|
trtllm_fp4_block_scale_routed_moe(**kwargs)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@ -1047,7 +1047,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
None,
|
None,
|
||||||
1 if renormalize else 0, # routing_method_type, renormalize
|
1 if renormalize else 0, # routing_method_type, renormalize
|
||||||
True, # do finalize
|
True, # do finalize
|
||||||
tune_max_num_tokens=self.max_capture_size,
|
tune_max_num_tokens=max(self.max_capture_size, 1),
|
||||||
)[0]
|
)[0]
|
||||||
return trtllm_gen_output
|
return trtllm_gen_output
|
||||||
elif (
|
elif (
|
||||||
@ -1122,7 +1122,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
tp_rank=self.moe.tp_rank,
|
tp_rank=self.moe.tp_rank,
|
||||||
ep_size=self.moe.ep_size,
|
ep_size=self.moe.ep_size,
|
||||||
ep_rank=self.moe.ep_rank,
|
ep_rank=self.moe.ep_rank,
|
||||||
tune_max_num_tokens=self.max_capture_size,
|
tune_max_num_tokens=max(self.max_capture_size, 1),
|
||||||
**extra_kwargs,
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -11,7 +11,6 @@ from typing import TYPE_CHECKING
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import CUDAGraphMode, VllmConfig
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
|
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -25,26 +24,6 @@ if TYPE_CHECKING:
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
|
|
||||||
"""
|
|
||||||
Record known issues with vllm + flashinfer autotune here. Return True if
|
|
||||||
and only if flashinfer autotune will run through without issues.
|
|
||||||
"""
|
|
||||||
is_tp_or_dp = (vllm_config.parallel_config.data_parallel_size > 1) or (
|
|
||||||
vllm_config.parallel_config.tensor_parallel_size > 1
|
|
||||||
)
|
|
||||||
is_fi_mxfp4_backend = (
|
|
||||||
envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
|
||||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
|
|
||||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
|
|
||||||
) or (
|
|
||||||
current_platform.is_cuda() and current_platform.is_device_capability(100)
|
|
||||||
) # on >=sm100, default mxfp4 backend is flashinfer
|
|
||||||
is_eager = vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
|
|
||||||
|
|
||||||
return not (is_tp_or_dp and is_fi_mxfp4_backend and is_eager)
|
|
||||||
|
|
||||||
|
|
||||||
def kernel_warmup(worker: "Worker"):
|
def kernel_warmup(worker: "Worker"):
|
||||||
# Deep GEMM warmup
|
# Deep GEMM warmup
|
||||||
do_deep_gemm_warmup = (
|
do_deep_gemm_warmup = (
|
||||||
@ -58,11 +37,7 @@ def kernel_warmup(worker: "Worker"):
|
|||||||
deep_gemm_warmup(model, max_tokens)
|
deep_gemm_warmup(model, max_tokens)
|
||||||
|
|
||||||
# FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
|
# FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
|
||||||
if (
|
if has_flashinfer() and current_platform.has_device_capability(90):
|
||||||
has_flashinfer()
|
|
||||||
and current_platform.has_device_capability(90)
|
|
||||||
and flashinfer_autotune_supported(worker.vllm_config)
|
|
||||||
):
|
|
||||||
flashinfer_autotune(worker.model_runner)
|
flashinfer_autotune(worker.model_runner)
|
||||||
|
|
||||||
# FlashInfer attention warmup
|
# FlashInfer attention warmup
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user