[Feat] Support non-gated activations in NVFP4 modelopt path (#29004)

This commit is contained in:
Omer Ullman Argov 2025-11-30 18:02:40 +02:00 committed by GitHub
parent cd719de5cb
commit 39d28108f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 98 additions and 22 deletions

View File

@ -16,11 +16,11 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
is_valid_flashinfer_cutlass_fused_moe,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
create_flashinfer_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
@ -48,9 +48,10 @@ MNK_FACTORS = [
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
@torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph(
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, activation: str
):
current_platform.seed_everything(7)
with set_current_vllm_config(
@ -59,6 +60,7 @@ def test_flashinfer_fp4_moe_no_graph(
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16
is_gated_act = activation == "silu_and_mul"
w1_q, w2_q, quant_config = make_test_quant_config(
e,
@ -68,6 +70,7 @@ def test_flashinfer_fp4_moe_no_graph(
quant_dtype="nvfp4",
block_shape=None,
per_act_token_quant=False,
make_gate=is_gated_act,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
@ -76,16 +79,19 @@ def test_flashinfer_fp4_moe_no_graph(
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
flashinfer_experts = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True),
FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
)
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
flashinfer_output = flashinfer_experts(
hidden_states=a,
w1=w1_q,
w2=w2_q,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=fi_activation,
)
# Reference check:
@ -103,7 +109,9 @@ def test_flashinfer_fp4_moe_no_graph(
block_size=quant_blocksize,
)
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
w1_d = torch.empty(
(e, (2 if is_gated_act else 1) * n, k), device="cuda", dtype=dtype
)
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
@ -124,7 +132,9 @@ def test_flashinfer_fp4_moe_no_graph(
block_size=quant_blocksize,
)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
torch_output = torch_moe(
a_in_dtype, w1_d, w2_d, score, topk, activation=activation
)
torch.testing.assert_close(
torch_output, flashinfer_output, atol=1e-1, rtol=1e-1

View File

@ -264,13 +264,20 @@ def make_test_weights(
quant_dtype: torch.dtype | str | None = None,
block_shape: list[int] | None = None,
per_out_ch_quant: bool = False,
make_gate: bool = True,
) -> tuple[
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
]:
return (
make_test_weight(
e, 2 * n, k, in_dtype, quant_dtype, block_shape, per_out_ch_quant
e,
(2 if make_gate else 1) * n,
k,
in_dtype,
quant_dtype,
block_shape,
per_out_ch_quant,
),
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant),
)
@ -297,6 +304,7 @@ def make_test_quant_config(
quant_dtype: torch.dtype | str | None = None,
per_act_token_quant: bool = False,
block_shape: list[int] | None = None,
make_gate: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
(_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
e,
@ -306,6 +314,7 @@ def make_test_quant_config(
quant_dtype,
per_out_ch_quant=per_act_token_quant,
block_shape=block_shape,
make_gate=make_gate,
)
# Hacky/trivial scales for nvfp4.

View File

@ -14,6 +14,7 @@ from torch._prims_common import TensorLikeType
from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.attention.backends.abstract import AttentionType
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.torch_utils import make_tensor_with_pad
@ -839,6 +840,7 @@ def torch_experts(
per_act_token_quant=False,
block_shape: list[int] | None = None,
apply_router_weights_on_input: bool = False,
activation: str = "silu_and_mul",
) -> torch.Tensor:
assert (
global_num_experts == -1
@ -881,6 +883,8 @@ def torch_experts(
f32 = torch.float32
act = CustomOp.op_registry[activation]
for i in range(num_experts):
mask = topk_ids == i
if mask.sum():
@ -888,7 +892,7 @@ def torch_experts(
tmp1 = a[mask] @ w1[i].transpose(0, 1)
if b_bias1 is not None:
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
tmp2 = SiluAndMul()(tmp1)
tmp2 = act()(tmp1)
out[mask] = tmp2 @ w2[i].transpose(0, 1)
if b_bias2 is not None:
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype)
@ -969,6 +973,7 @@ def torch_moe(
b_bias2: torch.Tensor | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
activation: str = "silu_and_mul",
) -> torch.Tensor:
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
@ -982,6 +987,7 @@ def torch_moe(
b_bias1,
b_bias2,
expert_map,
activation=activation,
)

View File

@ -600,14 +600,20 @@ class FusedMoE(CustomOp):
# Avoid circular import
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptFp8MoEMethod,
ModelOptNvFp4FusedMoE,
)
if not isinstance(
self.quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
self.quant_method,
(
UnquantizedFusedMoEMethod,
ModelOptFp8MoEMethod,
ModelOptNvFp4FusedMoE,
),
):
raise NotImplementedError(
"is_act_and_mul=False is supported only for unquantized "
"and ModelOpt FP8 moe for now"
", ModelOpt FP8, and ModelOpt NvFp4 checkpoints"
)
if not current_platform.is_cuda():
raise NotImplementedError(
@ -1277,7 +1283,7 @@ class FusedMoE(CustomOp):
self._load_combined_w13_weight_scale(
shard_dim=shard_dim,
loaded_weight=loaded_weight,
param=param,
param=expert_data,
tp_rank=self.tp_rank,
)
return True if return_success else None

View File

@ -1216,7 +1216,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w13_weight = ModelWeightParameter(
data=torch.empty(
num_experts,
2 * intermediate_size_per_partition,
(2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // 2,
dtype=weight_dtype,
@ -1245,7 +1245,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w13_weight_scale = ModelWeightParameter(
data=torch.empty(
num_experts,
2 * intermediate_size_per_partition,
(2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // self.quant_config.group_size,
dtype=weight_scale_dtype,
@ -1275,7 +1275,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)
w13_weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(num_experts, 2, dtype=torch.float32),
data=torch.empty(
num_experts, 2 if self.moe.is_act_and_mul else 1, dtype=torch.float32
),
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
@ -1296,7 +1298,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_scale_num_experts = global_num_experts if use_global_sf else num_experts
w13_input_scale = PerTensorScaleParameter(
data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32),
data=torch.empty(
global_scale_num_experts,
2 if self.moe.is_act_and_mul else 1,
dtype=torch.float32,
),
weight_loader=weight_loader,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
@ -1312,9 +1318,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
gemm1_weight = layer.w13_weight.data
gemm1_weight_scale = layer.w13_weight_scale.data
if self.allow_flashinfer and (
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
if (
self.allow_flashinfer
and (
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
)
and self.moe.is_act_and_mul
):
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
gemm1_weight, gemm1_weight_scale, dim=-2
@ -1324,7 +1334,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)
# Common processing for w13_weight_scale_2
if not torch.allclose(
if self.moe.is_act_and_mul and not torch.allclose(
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
):
logger.warning_once(
@ -1437,11 +1447,39 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w13_blockscale_swizzled, requires_grad=False
)
w13_weight = layer.w13_weight
intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1)
if intermediate_size_pad:
# padding gated activations will require to split w1 and w3
# and pad them individually
assert not self.moe.is_act_and_mul, (
"The intermediate size required padding, "
"but padding is not implemented for gated activations"
)
layer.w13_weight = Parameter(
torch.nn.functional.pad(
w13_weight, (0, 0, 0, intermediate_size_pad)
),
requires_grad=False,
)
layer.w2_weight = Parameter(
torch.nn.functional.pad(
layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)
),
requires_grad=False,
)
layer.w2_weight_scale = Parameter(
torch.nn.functional.pad(
layer.w2_weight_scale, (0, intermediate_size_pad // 16)
),
requires_grad=False,
)
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
layer.w2_weight_scale = Parameter(
w2_blockscale_swizzled, requires_grad=False
)
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@ -1484,7 +1522,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", "Only SiLU activation is supported."
if not self.moe.is_act_and_mul:
assert (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
), (
"Non-gated activations are only supported by the"
" flashinfer CUTLASS backend for modelopt checkpoints"
)
if (
self.allow_flashinfer