mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 00:17:09 +08:00
Bugfix: Cutlass FP8 FusedMoE bad scaling factors (#27255)
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
parent
b57789b62b
commit
6b7a81185d
@ -6,7 +6,10 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEQuantConfig,
|
||||||
|
fp8_w8a8_moe_quant_config,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||||
@ -22,10 +25,10 @@ from vllm.platforms import current_platform
|
|||||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||||
|
|
||||||
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
|
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
|
||||||
100
|
90
|
||||||
):
|
):
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
"Requires flashinfer_cutlass_fused_moe and nvfp4 support",
|
"Supported for sm >= 90",
|
||||||
allow_module_level=True,
|
allow_module_level=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -131,6 +134,8 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
|||||||
topk: int,
|
topk: int,
|
||||||
monkeypatch,
|
monkeypatch,
|
||||||
):
|
):
|
||||||
|
if not current_platform.has_device_capability(100):
|
||||||
|
pytest.skip("Test is only supported for sm >= 100")
|
||||||
current_platform.seed_everything(7)
|
current_platform.seed_everything(7)
|
||||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
@ -184,9 +189,6 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
|||||||
torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)
|
torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
|
||||||
"Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472"
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||||
@pytest.mark.parametrize("topk", TOP_KS)
|
@pytest.mark.parametrize("topk", TOP_KS)
|
||||||
@ -216,9 +218,13 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
|||||||
|
|
||||||
quant_config = fp8_w8a8_moe_quant_config(
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
w1_scale=td.w13_weight_scale,
|
w1_scale=td.w13_weight_scale,
|
||||||
|
g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
|
||||||
w2_scale=td.w2_weight_scale,
|
w2_scale=td.w2_weight_scale,
|
||||||
|
g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
|
||||||
a1_scale=td.a1_scale,
|
a1_scale=td.a1_scale,
|
||||||
|
a1_gscale=td.a1_scale,
|
||||||
a2_scale=td.a2_scale,
|
a2_scale=td.a2_scale,
|
||||||
|
a2_gscale=1.0 / td.a2_scale,
|
||||||
per_act_token_quant=False,
|
per_act_token_quant=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -238,6 +244,12 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
|||||||
|
|
||||||
td.layer.dp_size = 1
|
td.layer.dp_size = 1
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||||
|
return quant_config
|
||||||
|
|
||||||
|
td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
|
||||||
|
td.layer.quant_method = td.layer
|
||||||
|
|
||||||
flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
|
flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
|
||||||
td.hidden_states,
|
td.hidden_states,
|
||||||
td.layer,
|
td.layer,
|
||||||
|
|||||||
@ -463,6 +463,10 @@ def fp8_w8a8_moe_quant_config(
|
|||||||
per_act_token_quant: bool = False,
|
per_act_token_quant: bool = False,
|
||||||
per_out_ch_quant: bool = False,
|
per_out_ch_quant: bool = False,
|
||||||
block_shape: list[int] | None = None,
|
block_shape: list[int] | None = None,
|
||||||
|
a1_gscale: torch.Tensor | None = None,
|
||||||
|
a2_gscale: torch.Tensor | None = None,
|
||||||
|
g1_alphas: torch.Tensor | None = None,
|
||||||
|
g2_alphas: torch.Tensor | None = None,
|
||||||
) -> FusedMoEQuantConfig:
|
) -> FusedMoEQuantConfig:
|
||||||
"""
|
"""
|
||||||
Construct a quant config for fp8 activations and fp8 weights.
|
Construct a quant config for fp8 activations and fp8 weights.
|
||||||
@ -470,9 +474,13 @@ def fp8_w8a8_moe_quant_config(
|
|||||||
return FusedMoEQuantConfig.make(
|
return FusedMoEQuantConfig.make(
|
||||||
torch.float8_e4m3fn,
|
torch.float8_e4m3fn,
|
||||||
w1_scale=w1_scale,
|
w1_scale=w1_scale,
|
||||||
|
g1_alphas=g1_alphas,
|
||||||
w2_scale=w2_scale,
|
w2_scale=w2_scale,
|
||||||
|
g2_alphas=g2_alphas,
|
||||||
a1_scale=a1_scale,
|
a1_scale=a1_scale,
|
||||||
|
a1_gscale=a1_gscale,
|
||||||
a2_scale=a2_scale,
|
a2_scale=a2_scale,
|
||||||
|
a2_gscale=a2_gscale,
|
||||||
per_act_token_quant=per_act_token_quant,
|
per_act_token_quant=per_act_token_quant,
|
||||||
per_out_ch_quant=per_out_ch_quant,
|
per_out_ch_quant=per_out_ch_quant,
|
||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
|
|||||||
@ -170,7 +170,7 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
|||||||
self._apply_router_weight_on_input(
|
self._apply_router_weight_on_input(
|
||||||
a1, topk_weights, topk_ids, apply_router_weight_on_input
|
a1, topk_weights, topk_ids, apply_router_weight_on_input
|
||||||
)
|
)
|
||||||
if not self.use_dp:
|
if not self.use_dp and quant_config.quant_dtype == "nvfp4":
|
||||||
return a1, None, None, topk_ids, topk_weights
|
return a1, None, None, topk_ids, topk_weights
|
||||||
|
|
||||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||||
@ -181,11 +181,13 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
|||||||
quant_config.block_shape,
|
quant_config.block_shape,
|
||||||
is_fp4_scale_swizzled=not self.use_dp,
|
is_fp4_scale_swizzled=not self.use_dp,
|
||||||
)
|
)
|
||||||
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
|
|
||||||
[topk_weights, topk_ids, a1q, a1q_scale],
|
if self.use_dp:
|
||||||
dim=0,
|
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
|
||||||
sizes=get_local_sizes(),
|
[topk_weights, topk_ids, a1q, a1q_scale],
|
||||||
)
|
dim=0,
|
||||||
|
sizes=get_local_sizes(),
|
||||||
|
)
|
||||||
if quant_config.quant_dtype == "nvfp4":
|
if quant_config.quant_dtype == "nvfp4":
|
||||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||||
|
|
||||||
|
|||||||
@ -567,9 +567,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
return fp8_w8a8_moe_quant_config(
|
return fp8_w8a8_moe_quant_config(
|
||||||
w1_scale=layer.w13_weight_scale,
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
|
||||||
w2_scale=layer.w2_weight_scale,
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(),
|
||||||
a1_scale=layer.w13_input_scale,
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a1_gscale=layer.w13_input_scale,
|
||||||
a2_scale=layer.w2_input_scale,
|
a2_scale=layer.w2_input_scale,
|
||||||
|
a2_gscale=1.0 / layer.w2_input_scale,
|
||||||
per_act_token_quant=False,
|
per_act_token_quant=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1138,8 +1142,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
moe: FusedMoEConfig,
|
moe: FusedMoEConfig,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
) -> None:
|
) -> None:
|
||||||
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
|
||||||
detect_nvfp4_moe_support,
|
detect_nvfp4_moe_support, # noqa: E501
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(moe)
|
super().__init__(moe)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user