mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 04:34:59 +08:00
Feature: Support Relu2 in FusedMoE fp8 cutlass path (#27261)
This commit is contained in:
parent
5a87076d6e
commit
03ee48111d
@ -77,10 +77,14 @@ class TestData:
|
||||
|
||||
@staticmethod
|
||||
def make_moe_tensors_8bit(
|
||||
m: int, k: int, n: int, e: int, reorder: bool
|
||||
m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu"
|
||||
) -> "TestData":
|
||||
is_gated = activation != "relu2_no_mul"
|
||||
|
||||
hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||
w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
|
||||
w13 = torch.randn(
|
||||
(e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
|
||||
)
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Scale to fp8
|
||||
@ -190,18 +194,22 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"])
|
||||
def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
activation: str,
|
||||
monkeypatch,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)
|
||||
td = TestData.make_moe_tensors_8bit(
|
||||
m, k, n, e, reorder=False, activation=activation
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
@ -233,7 +241,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
activation=activation,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=True,
|
||||
@ -253,7 +261,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
td.layer,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation="silu",
|
||||
activation=activation,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=True,
|
||||
|
||||
@ -148,8 +148,14 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool | None,
|
||||
):
|
||||
assert activation == "silu", (
|
||||
"Only activation silu is supported in FlashInferExperts"
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
activation_str_to_value_map = {
|
||||
"silu": ActivationType.Swiglu, # This is the default
|
||||
"relu2_no_mul": ActivationType.Relu2,
|
||||
}
|
||||
assert activation in activation_str_to_value_map, (
|
||||
f"{activation=} missing from {activation_str_to_value_map.keys()=}"
|
||||
)
|
||||
|
||||
# Select quantization metadata based on FP8 format/path
|
||||
@ -215,6 +221,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
ep_size=self.ep_size,
|
||||
ep_rank=self.ep_rank,
|
||||
output=output,
|
||||
activation_type=activation_str_to_value_map[activation],
|
||||
# Informs FlashInfer to use the block-scale decoding path when True
|
||||
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
|
||||
)
|
||||
|
||||
@ -354,12 +354,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
|
||||
if (
|
||||
envs.VLLM_USE_FLASHINFER_MOE_FP8
|
||||
and has_flashinfer_moe()
|
||||
and self.moe.is_act_and_mul
|
||||
):
|
||||
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
|
||||
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
||||
if (
|
||||
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
||||
and not self.moe.is_act_and_mul
|
||||
):
|
||||
logger.info_once(
|
||||
"Non-gated MoE is not supported for min-latency mode,"
|
||||
"falling back to high-throughput mode"
|
||||
)
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
|
||||
|
||||
logger.info_once(
|
||||
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
|
||||
)
|
||||
@ -557,10 +563,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
if self.flashinfer_moe_backend is not None:
|
||||
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
|
||||
register_moe_scaling_factors(layer)
|
||||
if self.moe.is_act_and_mul:
|
||||
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
|
||||
register_moe_scaling_factors(layer)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
@ -570,13 +577,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
|
||||
g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(),
|
||||
g2_alphas=layer.output2_scales_scalar.squeeze(),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a1_gscale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
a2_gscale=1.0 / layer.w2_input_scale,
|
||||
a2_gscale=layer.w2_input_scale_inv,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
@ -642,9 +649,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
assert not renormalize
|
||||
assert activation == "silu", (
|
||||
f"Expected 'silu' activation but got {activation}"
|
||||
assert activation in ("silu", "relu2_no_mul"), (
|
||||
"Expected activation to be in ('silu', 'relu2_no_mul'),"
|
||||
f"but got {activation}"
|
||||
)
|
||||
return flashinfer_cutlass_moe_fp8(
|
||||
x,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user