mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-05 09:51:22 +08:00
[Kernel] Enable FusedMoEModularKernel support bias (#27754)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
0cdbe7b744
commit
bc4486d609
@ -15,9 +15,7 @@ from vllm.distributed.parallel_state import (
|
|||||||
from vllm.lora.layers.base import BaseLayerWithLoRA
|
from vllm.lora.layers.base import BaseLayerWithLoRA
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
|
||||||
_get_config_dtype_str,
|
_get_config_dtype_str,
|
||||||
mxfp4_w4a16_moe_quant_config,
|
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||||
modular_marlin_fused_moe,
|
modular_marlin_fused_moe,
|
||||||
@ -26,13 +24,16 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
|
|||||||
modular_triton_fused_moe,
|
modular_triton_fused_moe,
|
||||||
try_get_optimal_moe_config,
|
try_get_optimal_moe_config,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config
|
|
||||||
|
|
||||||
|
|
||||||
class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||||
def __init__(self, base_layer: FusedMoE) -> None:
|
def __init__(self, base_layer: FusedMoE) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.base_layer = base_layer
|
self.base_layer = base_layer
|
||||||
|
|
||||||
|
assert not self.base_layer.use_ep, (
|
||||||
|
"EP support for Fused MoE LoRA is not implemented yet."
|
||||||
|
)
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
self.device = base_layer.w2_weight.device
|
self.device = base_layer.w2_weight.device
|
||||||
@ -42,17 +43,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
moe_state_dict = {}
|
moe_state_dict = {}
|
||||||
top_k = self.base_layer.top_k
|
top_k = self.base_layer.top_k
|
||||||
|
|
||||||
if self.base_layer.quant_config is None:
|
self.base_layer.ensure_moe_quant_config_init()
|
||||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
quant_config = self.base_layer.quant_method.moe_quant_config
|
||||||
elif not isinstance(self.base_layer.quant_config, Mxfp4Config):
|
|
||||||
quant_config = self.base_layer.quant_config
|
|
||||||
else:
|
|
||||||
quant_config = mxfp4_w4a16_moe_quant_config(
|
|
||||||
w1_bias=self.base_layer.w13_bias,
|
|
||||||
w2_bias=self.base_layer.w2_bias,
|
|
||||||
w1_scale=self.base_layer.w13_weight_scale,
|
|
||||||
w2_scale=self.base_layer.w2_weight_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
m_fused_moe_fn = (
|
m_fused_moe_fn = (
|
||||||
modular_triton_fused_moe(
|
modular_triton_fused_moe(
|
||||||
@ -69,7 +61,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
moe_state_dict["hidden_states"] = kwargs["hidden_states"]
|
moe_state_dict["hidden_states"] = kwargs["hidden_states"]
|
||||||
moe_state_dict["topk_ids"] = kwargs["topk_ids"]
|
moe_state_dict["topk_ids"] = kwargs["topk_ids"]
|
||||||
moe_state_dict["topk_weights"] = kwargs["topk_weights"]
|
moe_state_dict["topk_weights"] = kwargs["topk_weights"]
|
||||||
moe_state_dict["global_num_experts"] = kwargs["global_num_experts"]
|
|
||||||
moe_state_dict["expert_map"] = kwargs["expert_map"]
|
moe_state_dict["expert_map"] = kwargs["expert_map"]
|
||||||
moe_state_dict["apply_router_weight_on_input"] = kwargs[
|
moe_state_dict["apply_router_weight_on_input"] = kwargs[
|
||||||
"apply_router_weight_on_input"
|
"apply_router_weight_on_input"
|
||||||
@ -86,7 +77,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
hidden_states = moe_state_dict["hidden_states"]
|
hidden_states = moe_state_dict["hidden_states"]
|
||||||
topk_weights = moe_state_dict["topk_weights"]
|
topk_weights = moe_state_dict["topk_weights"]
|
||||||
curr_topk_ids = moe_state_dict["topk_ids"]
|
curr_topk_ids = moe_state_dict["topk_ids"]
|
||||||
global_num_experts = moe_state_dict["global_num_experts"]
|
|
||||||
expert_map = moe_state_dict["expert_map"]
|
expert_map = moe_state_dict["expert_map"]
|
||||||
|
|
||||||
config_dtype = _get_config_dtype_str(
|
config_dtype = _get_config_dtype_str(
|
||||||
@ -118,7 +109,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
curr_topk_ids,
|
curr_topk_ids,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
config["BLOCK_SIZE_M"],
|
config["BLOCK_SIZE_M"],
|
||||||
global_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
max_loras,
|
max_loras,
|
||||||
expert_map,
|
expert_map,
|
||||||
)
|
)
|
||||||
@ -236,14 +227,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Initializes lora matrices."""
|
"""Initializes lora matrices."""
|
||||||
|
|
||||||
assert not self.base_layer.use_ep, (
|
|
||||||
"EP support for Fused MoE LoRA is not implemented yet."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.w1_lora_a_stacked = torch.zeros(
|
self.w1_lora_a_stacked = torch.zeros(
|
||||||
(
|
(
|
||||||
max_loras,
|
max_loras,
|
||||||
self.base_layer.global_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
lora_config.max_lora_rank,
|
lora_config.max_lora_rank,
|
||||||
self.base_layer.hidden_size,
|
self.base_layer.hidden_size,
|
||||||
),
|
),
|
||||||
@ -253,7 +240,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.w1_lora_b_stacked = torch.zeros(
|
self.w1_lora_b_stacked = torch.zeros(
|
||||||
(
|
(
|
||||||
max_loras,
|
max_loras,
|
||||||
self.base_layer.global_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
self.base_layer.intermediate_size_per_partition,
|
self.base_layer.intermediate_size_per_partition,
|
||||||
lora_config.max_lora_rank,
|
lora_config.max_lora_rank,
|
||||||
),
|
),
|
||||||
@ -264,7 +251,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.w2_lora_a_stacked = torch.zeros(
|
self.w2_lora_a_stacked = torch.zeros(
|
||||||
(
|
(
|
||||||
max_loras,
|
max_loras,
|
||||||
self.base_layer.global_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
lora_config.max_lora_rank,
|
lora_config.max_lora_rank,
|
||||||
self.base_layer.intermediate_size_per_partition,
|
self.base_layer.intermediate_size_per_partition,
|
||||||
),
|
),
|
||||||
@ -274,7 +261,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.w2_lora_b_stacked = torch.zeros(
|
self.w2_lora_b_stacked = torch.zeros(
|
||||||
(
|
(
|
||||||
max_loras,
|
max_loras,
|
||||||
self.base_layer.global_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
self.base_layer.hidden_size,
|
self.base_layer.hidden_size,
|
||||||
lora_config.max_lora_rank,
|
lora_config.max_lora_rank,
|
||||||
),
|
),
|
||||||
@ -285,7 +272,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.w3_lora_a_stacked = torch.zeros(
|
self.w3_lora_a_stacked = torch.zeros(
|
||||||
(
|
(
|
||||||
max_loras,
|
max_loras,
|
||||||
self.base_layer.global_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
lora_config.max_lora_rank,
|
lora_config.max_lora_rank,
|
||||||
self.base_layer.hidden_size,
|
self.base_layer.hidden_size,
|
||||||
),
|
),
|
||||||
@ -295,7 +282,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.w3_lora_b_stacked = torch.zeros(
|
self.w3_lora_b_stacked = torch.zeros(
|
||||||
(
|
(
|
||||||
max_loras,
|
max_loras,
|
||||||
self.base_layer.global_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
self.base_layer.intermediate_size_per_partition,
|
self.base_layer.intermediate_size_per_partition,
|
||||||
lora_config.max_lora_rank,
|
lora_config.max_lora_rank,
|
||||||
),
|
),
|
||||||
@ -308,7 +295,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.lora_a_stacked = []
|
self.lora_a_stacked = []
|
||||||
self.lora_b_stacked = []
|
self.lora_b_stacked = []
|
||||||
for lora_id in range(max_loras):
|
for lora_id in range(max_loras):
|
||||||
for experts_id in range(self.base_layer.global_num_experts):
|
for experts_id in range(self.base_layer.local_num_experts):
|
||||||
# gate_proj,down_proj,up_proj
|
# gate_proj,down_proj,up_proj
|
||||||
self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
|
self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
|
||||||
self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
|
self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
|
||||||
|
|||||||
@ -672,8 +672,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
elif self.fused_experts is not None:
|
elif self.fused_experts is not None:
|
||||||
if self.moe.has_bias:
|
|
||||||
raise ValueError("FusedMoEModularKernel does not support bias.")
|
|
||||||
result = self.fused_experts(
|
result = self.fused_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user