[Bugfix] Fix gpt-oss w4a8 DP/EP on B200 (#26729)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Authored by Varun Sundar Rabindranath on 2025-10-21 01:51:14 -04:00, committed by GitHub
parent f95da13c3d
commit 5ff5d94e77
5 changed files with 82 additions and 2 deletions

View File

@@ -170,3 +170,23 @@ def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
can_initialize("openai/gpt-oss-20b", hf_overrides=HF_OVERRIDE_TEXT)
def test_gptoss_dp2_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput")
can_initialize(
"openai/gpt-oss-20b",
extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"],
hf_overrides=HF_OVERRIDE_TEXT,
)
def test_gptoss_dp2_mxfp4bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1")
monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput")
can_initialize(
"openai/gpt-oss-20b",
extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"],
hf_overrides=HF_OVERRIDE_TEXT,
)
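The two new tests cover the FlashInfer MXFP4xMXFP8 and MXFP4xBF16 MoE paths with data parallelism and expert parallelism enabled, which is the configuration this fix targets. A rough reproduction sketch outside pytest is shown below; it is hypothetical, assumes vllm.LLM forwards data_parallel_size and enable_expert_parallel keyword arguments the same way the --data-parallel-size and --enable-expert-parallel CLI flags do, and omits the test-local HF_OVERRIDE_TEXT override.

```python
# Hypothetical reproduction sketch; not part of this commit.
import os

# Same environment the DP2 MXFP4xMXFP8 test pins.
os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"] = "1"
os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_high_throughput"

from vllm import LLM

# Assumes LLM(...) accepts the engine-arg equivalents of the CLI flags used in
# the test; engine initialization is what can_initialize exercises.
llm = LLM(
    model="openai/gpt-oss-20b",
    data_parallel_size=2,
    enable_expert_parallel=True,
)
```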

View File

@@ -517,6 +517,26 @@ def mxfp4_w4a16_moe_quant_config(
)
def mxfp4_mxfp8_moe_quant_config(
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for mxfp8 activations and mxfp4 weights.
"""
return FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc("mxfp8"),
_a2=FusedMoEQuantDesc("mxfp8"),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
)
def ocp_mx_moe_quant_config(
quant_dtype: str,
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
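For reference, a call-site sketch of the new constructor follows. It is illustrative only: the scale and bias tensor shapes are placeholder assumptions, not the layouts the FlashInfer kernels actually require.

```python
# Illustrative usage sketch; shapes and dtypes are placeholders, not real layouts.
import torch

from vllm.model_executor.layers.fused_moe.config import (
    mxfp4_mxfp8_moe_quant_config,
)

E = 8  # number of experts (example value)
w13_scale = torch.empty(E, 16, 4, dtype=torch.uint8)   # assumed per-group mxfp4 scales
w2_scale = torch.empty(E, 8, 8, dtype=torch.uint8)
w13_bias = torch.empty(E, 16, dtype=torch.bfloat16)
w2_bias = torch.empty(E, 8, dtype=torch.bfloat16)

quant_config = mxfp4_mxfp8_moe_quant_config(
    w1_scale=w13_scale,
    w2_scale=w2_scale,
    w1_bias=w13_bias,
    w2_bias=w2_bias,
)
# Activations are described as "mxfp8" and weights as "mxfp4"; a1_scale/a2_scale
# default to None, i.e. dynamic activation quantization.
```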

View File

@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe import (
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
mxfp4_mxfp8_moe_quant_config,
mxfp4_w4a16_moe_quant_config,
ocp_mx_moe_quant_config,
)
@@ -747,6 +748,23 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w1_scale=w1_scale,
w2_scale=w2_scale,
)
elif self.mxfp4_backend in [
Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
]:
return mxfp4_mxfp8_moe_quant_config(
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
elif self.mxfp4_backend in [Mxfp4Backend.SM100_FI_MXFP4_BF16]:
return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
else:
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
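The effect of this hunk is that the SM100 FlashInfer MXFP4xMXFP8 backends (TRTLLM and CUTLASS) now get an explicit mxfp8-activation quant config instead of falling through to the generic branch below. Purely as an illustration of that dispatch (not code from the patch), the same selection can be read as a lookup table; the names are taken from the diff and assumed to be in scope.

```python
# Sketch only: restates the backend -> quant-config selection from this hunk.
_CONFIG_BUILDERS = {
    Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM: mxfp4_mxfp8_moe_quant_config,
    Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS: mxfp4_mxfp8_moe_quant_config,
    Mxfp4Backend.SM100_FI_MXFP4_BF16: mxfp4_w4a16_moe_quant_config,
}


def _build_quant_config(backend, layer):
    builder = _CONFIG_BUILDERS.get(backend)
    if builder is None:
        return None  # the real code continues into the generic else-branch
    return builder(
        w1_bias=layer.w13_bias,
        w2_bias=layer.w2_bias,
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
    )
```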

View File

@@ -18,4 +18,7 @@ def mxfp8_e4m3_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"`pip install flashinfer`"
) from err
return mxfp8_e4m3_quantize(x, is_sf_swizzled_layout=False)
x_q, x_scales = mxfp8_e4m3_quantize(x, is_sf_swizzled_layout=False)
if x_scales.ndim == 1:
x_scales = x_scales.view(x.size(0), -1)
return x_q, x_scales
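The wrapper now reshapes the scale factors when flashinfer returns them as a flat 1-D tensor for the unswizzled layout, so callers see one row of per-group scales per token. A toy restatement of that reshape, with a mock flat scale tensor and an assumed MX group size of 32, is below.

```python
# Toy sketch of the reshape; the real scales come from flashinfer's
# mxfp8_e4m3_quantize, and the group size of 32 is an assumption here.
import torch

tokens, hidden = 4, 128
x = torch.randn(tokens, hidden)

# Pretend flashinfer handed back the scales flattened to 1-D: (tokens * hidden/32,)
flat_scales = torch.ones(tokens * hidden // 32, dtype=torch.uint8)

scales = flat_scales
if scales.ndim == 1:
    scales = scales.view(x.size(0), -1)

print(scales.shape)  # torch.Size([4, 4]): 4 tokens x 4 scale groups per token
```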

View File

@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
from vllm.platforms import current_platform
@@ -24,6 +25,20 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
"""
Record known issues with vllm + flashinfer autotune here. Return True if
and only if flashinfer autotune will run through without issues.
"""
return not (
vllm_config.parallel_config.data_parallel_size > 1
and (
envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
)
)
def kernel_warmup(worker: "Worker"):
# Deep GEMM warmup
do_deep_gemm_warmup = (
@@ -37,7 +52,11 @@ def kernel_warmup(worker: "Worker"):
deep_gemm_warmup(model, max_tokens)
# FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
if has_flashinfer() and current_platform.has_device_capability(90):
if (
has_flashinfer()
and current_platform.has_device_capability(90)
and flashinfer_autotune_supported(worker.vllm_config)
):
flashinfer_autotune(worker.model_runner)
# FlashInfer attention warmup
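flashinfer_autotune_supported now skips FlashInfer autotuning whenever data parallelism is combined with either FlashInfer MXFP4 MoE path, and kernel_warmup consults it before calling flashinfer_autotune. A self-contained restatement of the predicate, with the config and environment lookups stubbed out as plain arguments, makes the gating easy to check (sketch only).

```python
# Standalone restatement of flashinfer_autotune_supported's predicate; the real
# function reads vllm_config.parallel_config.data_parallel_size and vllm.envs.
def _autotune_supported(dp_size: int, mxfp4_bf16: bool, mxfp4_mxfp8: bool) -> bool:
    return not (dp_size > 1 and (mxfp4_bf16 or mxfp4_mxfp8))


assert _autotune_supported(1, True, True)        # single DP rank: autotune runs
assert _autotune_supported(2, False, False)      # DP>1, no FlashInfer MXFP4 MoE: autotune runs
assert not _autotune_supported(2, False, True)   # DP>1 + MXFP4xMXFP8 MoE: autotune skipped
assert not _autotune_supported(2, True, False)   # DP>1 + MXFP4xBF16 MoE: autotune skipped
```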