[Bugfix] Fix gpt-oss w4a8 DP/EP on B200 (#26729)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>

parent f95da13c3d
commit 5ff5d94e77
@@ -170,3 +170,23 @@ def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
 def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
     can_initialize("openai/gpt-oss-20b", hf_overrides=HF_OVERRIDE_TEXT)
+
+
+def test_gptoss_dp2_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
+    monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput")
+    can_initialize(
+        "openai/gpt-oss-20b",
+        extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"],
+        hf_overrides=HF_OVERRIDE_TEXT,
+    )
+
+
+def test_gptoss_dp2_mxfp4bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1")
+    monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput")
+    can_initialize(
+        "openai/gpt-oss-20b",
+        extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"],
+        hf_overrides=HF_OVERRIDE_TEXT,
+    )
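Note: the DP=2 paths these new tests cover correspond, roughly, to the offline launch sketched below. This is not part of the commit; the LLM keyword arguments are assumed to mirror the --data-parallel-size and --enable-expert-parallel CLI flags passed via extra_args above.

import os

# Same environment toggles the tests set via monkeypatch.
os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"] = "1"
os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_high_throughput"

from vllm import LLM

# Assumed kwargs mirroring the CLI flags exercised by the test.
llm = LLM(
    model="openai/gpt-oss-20b",
    data_parallel_size=2,
    enable_expert_parallel=True,
)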
@@ -517,6 +517,26 @@ def mxfp4_w4a16_moe_quant_config(
     )
 
 
+def mxfp4_mxfp8_moe_quant_config(
+    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
+    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
+    a1_scale: torch.Tensor | None = None,
+    a2_scale: torch.Tensor | None = None,
+    w1_bias: torch.Tensor | None = None,
+    w2_bias: torch.Tensor | None = None,
+    block_shape: list[int] | None = None,
+) -> FusedMoEQuantConfig:
+    """
+    Construct a quant config for mxfp8 activations and mxfp4 weights.
+    """
+    return FusedMoEQuantConfig(
+        _a1=FusedMoEQuantDesc("mxfp8"),
+        _a2=FusedMoEQuantDesc("mxfp8"),
+        _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
+        _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
+    )
+
+
 def ocp_mx_moe_quant_config(
     quant_dtype: str,
     w1_scale: Union[torch.Tensor, "PrecisionConfig"],
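For orientation, a minimal sketch of how this helper gets wired up, mirroring the Mxfp4MoEMethod change further down in this diff. The shapes, dtypes, and dimension names below are illustrative placeholders rather than real gpt-oss checkpoint layouts, so treat this as a call-signature illustration, not a validated configuration.

import torch

from vllm.model_executor.layers.fused_moe.config import (
    mxfp4_mxfp8_moe_quant_config,
)

# Hypothetical per-expert scale/bias tensors; dimensions are made up for the sketch.
num_experts, intermediate, hidden, block = 4, 256, 128, 32
w13_weight_scale = torch.zeros(num_experts, 2 * intermediate, hidden // block, dtype=torch.uint8)
w2_weight_scale = torch.zeros(num_experts, hidden, intermediate // block, dtype=torch.uint8)
w13_bias = torch.zeros(num_experts, 2 * intermediate, dtype=torch.bfloat16)
w2_bias = torch.zeros(num_experts, hidden, dtype=torch.bfloat16)

quant_config = mxfp4_mxfp8_moe_quant_config(
    w1_scale=w13_weight_scale,
    w2_scale=w2_weight_scale,
    w1_bias=w13_bias,
    w2_bias=w2_bias,
)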
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe import (
 from vllm.model_executor.layers.fused_moe import modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
+    mxfp4_mxfp8_moe_quant_config,
     mxfp4_w4a16_moe_quant_config,
     ocp_mx_moe_quant_config,
 )
@@ -747,6 +748,23 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 w1_scale=w1_scale,
                 w2_scale=w2_scale,
             )
+        elif self.mxfp4_backend in [
+            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
+            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
+        ]:
+            return mxfp4_mxfp8_moe_quant_config(
+                w1_bias=layer.w13_bias,
+                w2_bias=layer.w2_bias,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+            )
+        elif self.mxfp4_backend in [Mxfp4Backend.SM100_FI_MXFP4_BF16]:
+            return mxfp4_w4a16_moe_quant_config(
+                w1_bias=layer.w13_bias,
+                w2_bias=layer.w2_bias,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+            )
         else:
             w1_scale = layer.w13_weight_scale
             w2_scale = layer.w2_weight_scale
@@ -18,4 +18,7 @@ def mxfp8_e4m3_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
             "`pip install flashinfer`"
         ) from err
 
-    return mxfp8_e4m3_quantize(x, is_sf_swizzled_layout=False)
+    x_q, x_scales = mxfp8_e4m3_quantize(x, is_sf_swizzled_layout=False)
+    if x_scales.ndim == 1:
+        x_scales = x_scales.view(x.size(0), -1)
+    return x_q, x_scales
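The fix above normalizes the block-scale tensor to one row per token. A minimal self-contained sketch of that normalization, assuming a quantizer that may return scales either flat or already 2-D (the helper name below is hypothetical):

import torch


def normalize_scales(x: torch.Tensor, x_scales: torch.Tensor) -> torch.Tensor:
    # Downstream DP/EP MoE paths expect one row of block scales per token.
    if x_scales.ndim == 1:
        x_scales = x_scales.view(x.size(0), -1)
    return x_scales


x = torch.randn(4, 64)
flat_scales = torch.ones(8)  # e.g. two 32-element blocks per token, flattened
print(normalize_scales(x, flat_scales).shape)  # torch.Size([4, 2])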
@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING
 import torch
 
+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
 from vllm.platforms import current_platform
@@ -24,6 +25,20 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)
 
 
+def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
+    """
+    Record known issues with vllm + flashinfer autotune here. Return True if
+    and only if flashinfer autotune will run through without issues.
+    """
+    return not (
+        vllm_config.parallel_config.data_parallel_size > 1
+        and (
+            envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
+            or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+        )
+    )
+
+
 def kernel_warmup(worker: "Worker"):
     # Deep GEMM warmup
     do_deep_gemm_warmup = (
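The predicate above simply disables autotune for the known-bad combination (DP > 1 together with either FlashInfer MXFP4 MoE path). A standalone sketch of the same truth table with plain arguments in place of VllmConfig/envs (names here are illustrative):

def autotune_supported(
    data_parallel_size: int,
    use_flashinfer_mxfp4_bf16: bool,
    use_flashinfer_mxfp4_mxfp8: bool,
) -> bool:
    # Skip autotune only when DP > 1 is combined with one of the MXFP4 MoE paths.
    return not (
        data_parallel_size > 1
        and (use_flashinfer_mxfp4_bf16 or use_flashinfer_mxfp4_mxfp8)
    )


assert autotune_supported(1, True, False)        # single DP rank: autotune runs
assert not autotune_supported(2, False, True)    # DP=2 + MXFP4/MXFP8: skipped
assert autotune_supported(2, False, False)       # DP=2 without those paths: runs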
@@ -37,7 +52,11 @@ def kernel_warmup(worker: "Worker"):
         deep_gemm_warmup(model, max_tokens)
 
     # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
-    if has_flashinfer() and current_platform.has_device_capability(90):
+    if (
+        has_flashinfer()
+        and current_platform.has_device_capability(90)
+        and flashinfer_autotune_supported(worker.vllm_config)
+    ):
         flashinfer_autotune(worker.model_runner)
 
     # FlashInfer attention warmup