From 5ff5d94e77851f8ca11592bbb9aa414e65f4c353 Mon Sep 17 00:00:00 2001
From: Varun Sundar Rabindranath
Date: Tue, 21 Oct 2025 01:51:14 -0400
Subject: [PATCH] [Bugfix] Fix gpt-oss w4a8 DP/EP on B200 (#26729)

Signed-off-by: Varun Sundar Rabindranath
Co-authored-by: Varun Sundar Rabindranath
Co-authored-by: Michael Goin
---
 tests/quantization/test_blackwell_moe.py       | 20 ++++++++++++++++++
 .../model_executor/layers/fused_moe/config.py  | 20 ++++++++++++++++++
 .../layers/quantization/mxfp4.py               | 18 ++++++++++++++++
 .../layers/quantization/utils/mxfp8_utils.py   |  5 ++++-
 vllm/model_executor/warmup/kernel_warmup.py    | 21 ++++++++++++++++++-
 5 files changed, 82 insertions(+), 2 deletions(-)

diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py
index 3773d1f2afa6..3cae6f46147b 100644
--- a/tests/quantization/test_blackwell_moe.py
+++ b/tests/quantization/test_blackwell_moe.py
@@ -170,3 +170,23 @@ def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatc
 def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
     can_initialize("openai/gpt-oss-20b", hf_overrides=HF_OVERRIDE_TEXT)
+
+
+def test_gptoss_dp2_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
+    monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput")
+    can_initialize(
+        "openai/gpt-oss-20b",
+        extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"],
+        hf_overrides=HF_OVERRIDE_TEXT,
+    )
+
+
+def test_gptoss_dp2_mxfp4bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1")
+    monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput")
+    can_initialize(
+        "openai/gpt-oss-20b",
+        extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"],
+        hf_overrides=HF_OVERRIDE_TEXT,
+    )
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 924736b274f3..9e84f0d00b08 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -517,6 +517,26 @@ def mxfp4_w4a16_moe_quant_config(
     )
 
 
+def mxfp4_mxfp8_moe_quant_config(
+    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
+    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
+    a1_scale: torch.Tensor | None = None,
+    a2_scale: torch.Tensor | None = None,
+    w1_bias: torch.Tensor | None = None,
+    w2_bias: torch.Tensor | None = None,
+    block_shape: list[int] | None = None,
+) -> FusedMoEQuantConfig:
+    """
+    Construct a quant config for mxfp8 activations and mxfp4 weights.
+ """ + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc("mxfp8"), + _a2=FusedMoEQuantDesc("mxfp8"), + _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias), + _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias), + ) + + def ocp_mx_moe_quant_config( quant_dtype: str, w1_scale: Union[torch.Tensor, "PrecisionConfig"], diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 12b0c208dd34..cf983fcf43c9 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, + mxfp4_mxfp8_moe_quant_config, mxfp4_w4a16_moe_quant_config, ocp_mx_moe_quant_config, ) @@ -747,6 +748,23 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): w1_scale=w1_scale, w2_scale=w2_scale, ) + elif self.mxfp4_backend in [ + Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, + Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS, + ]: + return mxfp4_mxfp8_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + ) + elif self.mxfp4_backend in [Mxfp4Backend.SM100_FI_MXFP4_BF16]: + return mxfp4_w4a16_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + ) else: w1_scale = layer.w13_weight_scale w2_scale = layer.w2_weight_scale diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py index 248b2d6c4af2..bed771fd1c4d 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py @@ -18,4 +18,7 @@ def mxfp8_e4m3_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: "`pip install flashinfer`" ) from err - return mxfp8_e4m3_quantize(x, is_sf_swizzled_layout=False) + x_q, x_scales = mxfp8_e4m3_quantize(x, is_sf_swizzled_layout=False) + if x_scales.ndim == 1: + x_scales = x_scales.view(x.size(0), -1) + return x_q, x_scales diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 28792338f036..79d1927d3210 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING import torch import vllm.envs as envs +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup from vllm.platforms import current_platform @@ -24,6 +25,20 @@ if TYPE_CHECKING: logger = init_logger(__name__) +def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool: + """ + Record known issues with vllm + flashinfer autotune here. Return True if + and only if flashinfer autotune will run through without issues. 
+ """ + return not ( + vllm_config.parallel_config.data_parallel_size > 1 + and ( + envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 + or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + ) + ) + + def kernel_warmup(worker: "Worker"): # Deep GEMM warmup do_deep_gemm_warmup = ( @@ -37,7 +52,11 @@ def kernel_warmup(worker: "Worker"): deep_gemm_warmup(model, max_tokens) # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs - if has_flashinfer() and current_platform.has_device_capability(90): + if ( + has_flashinfer() + and current_platform.has_device_capability(90) + and flashinfer_autotune_supported(worker.vllm_config) + ): flashinfer_autotune(worker.model_runner) # FlashInfer attention warmup