Feature: Support Relu2 in FusedMoE fp8 cutlass path (#27261)

Author: amirkl94, 2025-11-16 20:39:44 +02:00 (committed by GitHub)
commit 03ee48111d, parent 5a87076d6e
3 changed files with 42 additions and 20 deletions
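
The change teaches the FlashInfer CUTLASS fp8 MoE path to accept the non-gated "relu2_no_mul" activation alongside the default gated SiLU. A minimal sketch of the two epilogues, for orientation only (illustrative code, not the vLLM or FlashInfer implementation):

import torch
import torch.nn.functional as F

def silu_and_mul(y: torch.Tensor) -> torch.Tensor:
    # Gated "silu" epilogue: the first GEMM emits 2*n columns that are split
    # into gate and up halves, so w13 stacks two projections per expert.
    gate, up = y.chunk(2, dim=-1)
    return F.silu(gate) * up

def relu2_no_mul(y: torch.Tensor) -> torch.Tensor:
    # Non-gated "relu2_no_mul" epilogue: squared ReLU on a single n-column
    # projection, so w13 holds only one projection per expert.
    return torch.relu(y).square()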


@@ -77,10 +77,14 @@ class TestData:
     @staticmethod
     def make_moe_tensors_8bit(
-        m: int, k: int, n: int, e: int, reorder: bool
+        m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu"
     ) -> "TestData":
+        is_gated = activation != "relu2_no_mul"
         hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
-        w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
+        w13 = torch.randn(
+            (e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
+        )
         w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
         # Scale to fp8
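
The shape change follows directly from the epilogue: with gating, w13 fuses the gate and up projections, so its second dimension halves when the activation is non-gated. A small sketch of the per-expert shapes, using hypothetical sizes rather than the test's constants:

import torch

e, n, k = 4, 256, 128                 # hypothetical expert count and hidden sizes
w13_gated = torch.randn(e, 2 * n, k)  # "silu": gate and up fused, 2*n rows per expert
w13_relu2 = torch.randn(e, n, k)      # "relu2_no_mul": single projection, n rows per expert
w2 = torch.randn(e, k, n)             # down projection is unchanged in both cases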
@@ -190,18 +194,22 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
 @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"])
 def test_flashinfer_cutlass_moe_fp8_no_graph(
     m: int,
     n: int,
     k: int,
     e: int,
     topk: int,
+    activation: str,
     monkeypatch,
 ):
     current_platform.seed_everything(7)
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
     with set_current_vllm_config(vllm_config):
-        td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)
+        td = TestData.make_moe_tensors_8bit(
+            m, k, n, e, reorder=False, activation=activation
+        )
         score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
         topk_weights, topk_ids, _ = FusedMoE.select_experts(
@@ -233,7 +241,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=False,
-            activation="silu",
+            activation=activation,
             global_num_experts=e,
             expert_map=None,
             apply_router_weight_on_input=True,
@@ -253,7 +261,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
             td.layer,
             topk_weights,
             topk_ids,
-            activation="silu",
+            activation=activation,
             global_num_experts=e,
             expert_map=None,
             apply_router_weight_on_input=True,
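
With the new parametrization, every (m, n, k, e, topk) case now runs once per activation. To run only the new non-gated cases locally, a pytest keyword filter along these lines should work (the path placeholder and the id-matching behavior are assumptions, adjust to your checkout):

pytest <path-to-this-test-file> -k "test_flashinfer_cutlass_moe_fp8_no_graph and relu2_no_mul"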


@@ -148,8 +148,14 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         expert_tokens_meta: mk.ExpertTokensMetadata | None,
         apply_router_weight_on_input: bool | None,
     ):
-        assert activation == "silu", (
-            "Only activation silu is supported in FlashInferExperts"
+        from flashinfer.fused_moe.core import ActivationType
+
+        activation_str_to_value_map = {
+            "silu": ActivationType.Swiglu,  # This is the default
+            "relu2_no_mul": ActivationType.Relu2,
+        }
+        assert activation in activation_str_to_value_map, (
+            f"{activation=} missing from {activation_str_to_value_map.keys()=}"
         )
         # Select quantization metadata based on FP8 format/path
@@ -215,6 +221,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
             ep_size=self.ep_size,
             ep_rank=self.ep_rank,
             output=output,
+            activation_type=activation_str_to_value_map[activation],
             # Informs FlashInfer to use the block-scale decoding path when True
             use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
         )
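
FlashInfer selects its fused epilogue from an ActivationType enum, so the activation string used throughout vLLM has to be translated before the kernel call. A standalone sketch of that lookup; the dict contents come from the hunk above, while the helper function and its name are illustrative:

from flashinfer.fused_moe.core import ActivationType

ACTIVATION_STR_TO_TYPE = {
    "silu": ActivationType.Swiglu,         # gated SiLU, the existing default
    "relu2_no_mul": ActivationType.Relu2,  # non-gated squared ReLU
}

def resolve_activation(name: str) -> ActivationType:
    # Fail loudly on anything the FlashInfer CUTLASS path cannot fuse.
    if name not in ACTIVATION_STR_TO_TYPE:
        raise ValueError(
            f"unsupported MoE activation {name!r}; expected one of {sorted(ACTIVATION_STR_TO_TYPE)}"
        )
    return ACTIVATION_STR_TO_TYPE[name]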


@@ -354,12 +354,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         self.cutlass_fp8_supported = cutlass_fp8_supported()
         self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
-        if (
-            envs.VLLM_USE_FLASHINFER_MOE_FP8
-            and has_flashinfer_moe()
-            and self.moe.is_act_and_mul
-        ):
+        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
             self.flashinfer_moe_backend = get_flashinfer_moe_backend()
+            if (
+                self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+                and not self.moe.is_act_and_mul
+            ):
+                logger.info_once(
+                    "Non-gated MoE is not supported for min-latency mode, "
+                    "falling back to high-throughput mode"
+                )
+                self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
             logger.info_once(
                 f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )
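
The TensorRT-LLM (min-latency) FlashInfer backend still assumes a gated, act-and-mul MoE, so non-gated models are routed to the CUTLASS (high-throughput) backend instead of being rejected outright. A reduced sketch of that decision, with hypothetical standalone names that mirror the logic above rather than reproduce it:

from enum import Enum

class Backend(Enum):  # stand-in for FlashinferMoeBackend
    TENSORRT_LLM = "min-latency"
    CUTLASS = "high-throughput"

def pick_backend(requested: Backend, is_act_and_mul: bool) -> Backend:
    # Min-latency kernels only cover gated MoE; fall back for relu2_no_mul models.
    if requested is Backend.TENSORRT_LLM and not is_act_and_mul:
        return Backend.CUTLASS
    return requested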
@@ -557,10 +563,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             )
         if self.flashinfer_moe_backend is not None:
-            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
+            register_moe_scaling_factors(layer)
+            if self.moe.is_act_and_mul:
+                layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
             if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                 rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
-            register_moe_scaling_factors(layer)
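
swap_w13_to_w31 rearranges the fused gate/up halves of w13 into the order FlashInfer expects; a non-gated weight has no halves to swap, which is why the call is now guarded by is_act_and_mul while register_moe_scaling_factors runs unconditionally. An illustrative stand-in for the swap, assuming it exchanges the two halves along the per-expert intermediate dimension (an assumption inferred from the helper's name, not its source):

import torch

def swap_halves_w13(w13: torch.Tensor) -> torch.Tensor:
    # w13: [num_experts, 2 * intermediate, hidden] for gated MoE.
    # Non-gated weights are [num_experts, intermediate, hidden] and must not be swapped.
    w1, w3 = w13.chunk(2, dim=1)
    return torch.cat([w3, w1], dim=1)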
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@@ -570,13 +577,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         return fp8_w8a8_moe_quant_config(
             w1_scale=layer.w13_weight_scale,
-            g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
+            g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
             w2_scale=layer.w2_weight_scale,
-            g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(),
+            g2_alphas=layer.output2_scales_scalar.squeeze(),
             a1_scale=layer.w13_input_scale,
             a1_gscale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            a2_gscale=1.0 / layer.w2_input_scale,
+            a2_gscale=layer.w2_input_scale_inv,
             per_act_token_quant=False,
         )
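
The quant config now reads precomputed layer buffers instead of recomputing products of weight and input scales on every call. Based only on the expressions they replace in this hunk (an inference, not the source of register_moe_scaling_factors), the buffers are expected to hold:

import torch

def flashinfer_moe_scaling_factors(
    w13_weight_scale: torch.Tensor,
    w13_input_scale: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
) -> dict[str, torch.Tensor]:
    # Relates the new buffers to the old right-hand sides they replace above.
    return {
        "output1_scales_gate_scalar": w13_weight_scale * w13_input_scale,  # old g1_alphas
        "output2_scales_scalar": w2_weight_scale * w2_input_scale,         # old g2_alphas
        "w2_input_scale_inv": 1.0 / w2_input_scale,                        # old a2_gscale
    }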
@@ -642,9 +649,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             )
         if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
             assert not renormalize
-            assert activation == "silu", (
-                f"Expected 'silu' activation but got {activation}"
+            assert activation in ("silu", "relu2_no_mul"), (
+                "Expected activation to be in ('silu', 'relu2_no_mul'), "
+                f"but got {activation}"
             )
             return flashinfer_cutlass_moe_fp8(
                 x,