[EPLB] Support EPLB w/ NVFP4 (#29804)

Signed-off-by: Andrew Briand <abriand@nvidia.com>
Co-authored-by: Andrew Briand <abriand@nvidia.com>
This commit is contained in:
Andrew Briand 2025-12-11 16:59:40 -06:00 committed by GitHub
parent 61249b177d
commit a00d88973d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 376 additions and 5 deletions

View File

@ -0,0 +1,276 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Test that the interaction between EPLB and FusedMoE Layer is okay for DP w/ NVFP4
from dataclasses import dataclass
import pytest
import torch
from tests.kernels.moe.utils import make_test_quant_config
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
get_dp_group,
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config,
ModelOptNvFp4FusedMoE,
)
from .eplb_utils import distributed_run, set_env_vars_and_device
@dataclass
class TestConfig:
num_layers: int
num_experts: int
num_local_experts: int
num_topk: int
hidden_size: int
intermediate_size: int
num_tokens: int
def make_fused_moe_layer(
rank: int,
layer_idx: int,
test_config: TestConfig,
) -> FusedMoE:
quant_config = None
device = torch.device(f"cuda:{rank}")
quant_config = ModelOptNvFp4Config(
is_checkpoint_nvfp4_serialized=True,
kv_cache_quant_algo=None,
exclude_modules=[],
)
fml = FusedMoE(
num_experts=test_config.num_experts,
top_k=test_config.num_topk,
hidden_size=test_config.hidden_size,
intermediate_size=test_config.intermediate_size,
prefix=f"dummy_layer_{layer_idx}",
activation="silu",
is_act_and_mul=True,
params_dtype=torch.bfloat16,
quant_config=quant_config,
)
nvfp4_fused_moe = ModelOptNvFp4FusedMoE(quant_config, fml)
nvfp4_fused_moe.create_weights(
fml,
test_config.num_local_experts,
test_config.hidden_size,
test_config.intermediate_size,
params_dtype=torch.uint8,
global_num_experts=test_config.num_experts,
)
fml = fml.to(device)
w1_q, w2_q, quant_config = make_test_quant_config(
test_config.num_local_experts,
test_config.intermediate_size,
test_config.hidden_size,
in_dtype=torch.bfloat16,
quant_dtype="nvfp4",
block_shape=None,
per_act_token_quant=False,
)
fml.w13_weight.data = w1_q
fml.w2_weight.data = w2_q
fml.w2_input_scale.data = torch.randn_like(fml.w2_input_scale.data) / 5
fml.w13_input_scale.data = torch.randn_like(fml.w13_input_scale.data) / 5
fml.w2_weight_scale_2.data = torch.randn_like(fml.w2_weight_scale_2.data) / 5
fml.w13_weight_scale_2.data = torch.randn_like(fml.w13_weight_scale_2.data) / 5
fml.w2_weight_scale.data = (
torch.randn(fml.w2_weight_scale.data.shape, device=device) / 5
).to(fml.w2_weight_scale.data.dtype)
fml.w13_weight_scale.data = (
torch.randn(fml.w13_weight_scale.data.shape, device=device) / 5
).to(fml.w13_weight_scale.data.dtype)
nvfp4_fused_moe.process_weights_after_loading(fml)
fml.maybe_init_modular_kernel()
return fml
def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
set_env_vars_and_device(env)
vllm_config = VllmConfig()
vllm_config.parallel_config.data_parallel_size = world_size
vllm_config.parallel_config.enable_expert_parallel = True
with set_current_vllm_config(vllm_config):
ensure_model_parallel_initialized(
tensor_model_parallel_size=1, pipeline_model_parallel_size=1
)
ep_group = get_dp_group().cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
fml_layers = [
make_fused_moe_layer(ep_rank, layer_idx, test_config).to(device)
for layer_idx in range(test_config.num_layers)
]
rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]
hidden_states = []
router_logits = []
for layer_idx in range(test_config.num_layers):
hidden_states.append(
torch.randn(
(test_config.num_tokens, test_config.hidden_size),
dtype=torch.bfloat16,
device=device,
)
)
router_logits.append(
torch.randn(
(test_config.num_tokens, test_config.num_experts),
dtype=torch.bfloat16,
device=device,
)
)
out_before_shuffle = []
with set_forward_context(
{},
num_tokens=test_config.num_tokens,
num_tokens_across_dp=torch.tensor(
[test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
),
vllm_config=vllm_config,
):
for lidx, fml in enumerate(fml_layers):
out_before_shuffle.append(
fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
)
indices = torch.zeros(
test_config.num_layers, test_config.num_experts, dtype=torch.long
)
for lidx in range(test_config.num_layers):
indices[lidx] = torch.Tensor(range(test_config.num_experts))
shuffled_indices = torch.zeros_like(indices)
for lidx in range(test_config.num_layers):
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
rearrange_expert_weights_inplace(
indices,
shuffled_indices,
rank_expert_weights,
ep_group,
is_profile=False,
)
num_global_experts = test_config.num_experts
logical_to_physical_map_list = []
for lidx, fml in enumerate(fml_layers):
physical_to_logical_map = shuffled_indices[lidx].to(device)
logical_to_physical_map = torch.empty(
(num_global_experts,), dtype=torch.int32, device=device
)
logical_to_physical_map[physical_to_logical_map] = torch.arange(
0, num_global_experts, dtype=torch.int32, device=device
)
logical_to_physical_map_list.append(
logical_to_physical_map.reshape(num_global_experts, 1)
)
logical_to_physical_map = torch.stack(logical_to_physical_map_list)
for lidx, fml in enumerate(fml_layers):
logical_replica_count = torch.ones(
(test_config.num_layers, num_global_experts),
dtype=torch.int32,
device=device,
)
fml.enable_eplb = True
fml.set_eplb_state(
lidx,
torch.zeros(
(test_config.num_layers, num_global_experts),
dtype=torch.int32,
device=device,
),
logical_to_physical_map,
logical_replica_count,
)
out_after_shuffle = []
with set_forward_context(
{},
num_tokens=test_config.num_tokens,
num_tokens_across_dp=torch.tensor(
[test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
),
vllm_config=vllm_config,
):
for lidx, fml in enumerate(fml_layers):
out_after_shuffle.append(
fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
)
for lidx in range(test_config.num_layers):
torch.testing.assert_close(
out_before_shuffle[lidx], out_after_shuffle[lidx], atol=1e-1, rtol=1e-1
)
@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("num_layers", [8])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("intermediate_size", [256])
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("backend", ["latency", "throughput"])
def test_eplb_fml(
world_size: int,
num_layers: int,
num_experts: int,
hidden_size: int,
intermediate_size: int,
num_tokens: int,
backend: str,
monkeypatch,
):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend)
if torch.cuda.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test")
num_local_experts = num_experts // world_size
num_topk = 4
test_config = TestConfig(
num_layers=num_layers,
num_experts=num_experts,
num_local_experts=num_local_experts,
num_topk=num_topk,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_tokens=num_tokens,
)
distributed_run(
_test_eplb_fml,
world_size,
test_config,
)

View File

@ -38,6 +38,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
prepare_static_weights_for_trtllm_fp4_moe,
reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl,
@ -1325,7 +1326,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"Accuracy may be affected."
)
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
# Common processing for input scales and alphas
@ -1482,6 +1483,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
a2_gscale=layer.w2_input_scale_quant,
)
@property
def supports_eplb(self) -> bool:
return True
def apply(
self,
layer: FusedMoE,
@ -1500,11 +1505,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
and not layer.enable_eplb
):
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
)
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
@ -1522,6 +1524,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
router_logits=router_logits,
)
# EPLB path
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
return flashinfer_trtllm_fp4_routed_moe(
layer=layer,
x=x,
topk_ids=topk_ids,
topk_weights=topk_weights,
top_k=layer.top_k,
global_num_experts=layer.global_num_experts,
)
if self.use_marlin:
return fused_marlin_moe(
x,

View File

@ -331,3 +331,82 @@ def flashinfer_trtllm_fp4_moe(
)[0]
return out
def flashinfer_trtllm_fp4_routed_moe(
layer: torch.nn.Module,
x: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
top_k: int,
global_num_experts: int,
) -> torch.Tensor:
"""
Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
input top k expert indices and scores rather than computing
top k expert indices from scores.
Args:
layer: The MoE layer with weights and scales
x: Input tensor
topk_ids: Ids of selected experts
top_k: Number of experts to select per token
global_num_experts: Total number of experts across all ranks
Returns:
Output tensor from the MoE layer
"""
import flashinfer
# Pack top k ids and expert weights into a single int32 tensor, as
# required by TRT-LLM
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16
).view(torch.int16)
# Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
# Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
topk_ids=packed_tensor,
routing_bias=None,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn
).flatten(),
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,
output2_scale_scalar=layer.g2_alphas.data,
num_experts=global_num_experts,
top_k=top_k,
n_group=0,
topk_group=0,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
tile_tokens_dim=None,
routing_method_type=1,
do_finalize=True,
)[0]
return out