[Kernel] Enable FusedMoEModularKernel support bias (#27754)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-11-01 10:05:12 +08:00 committed by GitHub
parent 0cdbe7b744
commit bc4486d609
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 30 deletions

View File

@ -15,9 +15,7 @@ from vllm.distributed.parallel_state import (
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
_get_config_dtype_str,
mxfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
modular_marlin_fused_moe,
@ -26,13 +24,16 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
modular_triton_fused_moe,
try_get_optimal_moe_config,
)
from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config
class FusedMoEWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: FusedMoE) -> None:
super().__init__()
self.base_layer = base_layer
assert not self.base_layer.use_ep, (
"EP support for Fused MoE LoRA is not implemented yet."
)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.device = base_layer.w2_weight.device
@ -42,17 +43,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
moe_state_dict = {}
top_k = self.base_layer.top_k
if self.base_layer.quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
elif not isinstance(self.base_layer.quant_config, Mxfp4Config):
quant_config = self.base_layer.quant_config
else:
quant_config = mxfp4_w4a16_moe_quant_config(
w1_bias=self.base_layer.w13_bias,
w2_bias=self.base_layer.w2_bias,
w1_scale=self.base_layer.w13_weight_scale,
w2_scale=self.base_layer.w2_weight_scale,
)
self.base_layer.ensure_moe_quant_config_init()
quant_config = self.base_layer.quant_method.moe_quant_config
m_fused_moe_fn = (
modular_triton_fused_moe(
@ -69,7 +61,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
moe_state_dict["hidden_states"] = kwargs["hidden_states"]
moe_state_dict["topk_ids"] = kwargs["topk_ids"]
moe_state_dict["topk_weights"] = kwargs["topk_weights"]
moe_state_dict["global_num_experts"] = kwargs["global_num_experts"]
moe_state_dict["expert_map"] = kwargs["expert_map"]
moe_state_dict["apply_router_weight_on_input"] = kwargs[
"apply_router_weight_on_input"
@ -86,7 +77,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
hidden_states = moe_state_dict["hidden_states"]
topk_weights = moe_state_dict["topk_weights"]
curr_topk_ids = moe_state_dict["topk_ids"]
global_num_experts = moe_state_dict["global_num_experts"]
expert_map = moe_state_dict["expert_map"]
config_dtype = _get_config_dtype_str(
@ -118,7 +109,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
curr_topk_ids,
num_tokens,
config["BLOCK_SIZE_M"],
global_num_experts,
self.base_layer.local_num_experts,
max_loras,
expert_map,
)
@ -236,14 +227,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
) -> None:
"""Initializes lora matrices."""
assert not self.base_layer.use_ep, (
"EP support for Fused MoE LoRA is not implemented yet."
)
self.w1_lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
self.base_layer.hidden_size,
),
@ -253,7 +240,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w1_lora_b_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.local_num_experts,
self.base_layer.intermediate_size_per_partition,
lora_config.max_lora_rank,
),
@ -264,7 +251,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w2_lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
self.base_layer.intermediate_size_per_partition,
),
@ -274,7 +261,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w2_lora_b_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.local_num_experts,
self.base_layer.hidden_size,
lora_config.max_lora_rank,
),
@ -285,7 +272,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w3_lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
self.base_layer.hidden_size,
),
@ -295,7 +282,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w3_lora_b_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.local_num_experts,
self.base_layer.intermediate_size_per_partition,
lora_config.max_lora_rank,
),
@ -308,7 +295,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.lora_a_stacked = []
self.lora_b_stacked = []
for lora_id in range(max_loras):
for experts_id in range(self.base_layer.global_num_experts):
for experts_id in range(self.base_layer.local_num_experts):
# gate_proj,down_proj,up_proj
self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])

View File

@ -672,8 +672,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif self.fused_experts is not None:
if self.moe.has_bias:
raise ValueError("FusedMoEModularKernel does not support bias.")
result = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,