[Core] Handle MoE LoRA edge cases (#27335)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 8e4ca4d14e
commit abf3db40ef
@@ -74,7 +74,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
moe_state_dict["apply_router_weight_on_input"] = kwargs[
    "apply_router_weight_on_input"
]
moe_state_dict["max_loras"] = layer.w1_lora_a_stacked.shape[0]
result = func(*args, **kwargs)
return result
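
This hunk sits inside a hook that stashes router state in a shared moe_state_dict so that later wrappers can read it back; the surrounding hunks suggest max_loras now comes straight from the stacked LoRA tensors rather than from this scratchpad. A minimal sketch of the capture pattern, using hypothetical names (make_capture_wrapper, fake_router) that are not the ones in vLLM:

import torch

# Shared scratchpad that the wrapped calls read from and write to.
moe_state_dict: dict[str, object] = {}


def make_capture_wrapper(func, layer):
    """Wrap `func`, recording selected kwargs before delegating to it."""

    def wrapper(*args, **kwargs):
        # Record routing-related state for later hooks to consume.
        moe_state_dict["apply_router_weight_on_input"] = kwargs[
            "apply_router_weight_on_input"
        ]
        result = func(*args, **kwargs)
        return result

    return wrapper


# Usage sketch: wrap a dummy router call.
def fake_router(hidden_states, apply_router_weight_on_input=False):
    return hidden_states


wrapped = make_capture_wrapper(fake_router, layer=None)
out = wrapped(torch.randn(4, 8), apply_router_weight_on_input=True)
assert moe_state_dict["apply_router_weight_on_input"] is True
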
@@ -89,7 +88,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
curr_topk_ids = moe_state_dict["topk_ids"]
global_num_experts = moe_state_dict["global_num_experts"]
expert_map = moe_state_dict["expert_map"]
max_loras = moe_state_dict["max_loras"]

config_dtype = _get_config_dtype_str(
    dtype=hidden_states.dtype,
@@ -110,6 +108,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
    block_shape=layer.quant_method.moe_quant_config.block_shape,
)

max_loras = self.w1_lora_a_stacked.shape[0]
config = get_config_func(M)
(
    sorted_token_ids_lora,
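
Here max_loras is read off the leading dimension of the stacked LoRA-A tensor. A small sketch of that convention, assuming a (max_loras, num_experts, rank, hidden_size) layout; the real layout in vLLM may differ:

import torch

# Assumed layout: one slot per active LoRA adapter, then per-expert weights.
max_loras, num_experts, rank, hidden_size = 4, 8, 16, 64
w1_lora_a_stacked = torch.zeros(max_loras, num_experts, rank, hidden_size)

# The number of LoRA slots is simply the size of the leading dimension,
# so it never has to be threaded through a shared state dict.
assert w1_lora_a_stacked.shape[0] == max_loras
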
@@ -161,7 +160,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
def wrapper(*args, **kwargs):
    hidden_states = moe_state_dict["hidden_states"]
    topk_weights = moe_state_dict["topk_weights"]
    max_loras = moe_state_dict["max_loras"]

    config_dtype = _get_config_dtype_str(
        dtype=hidden_states.dtype,
@@ -189,7 +187,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
num_tokens_post_padded_lora = moe_state_dict[
    "num_tokens_post_padded_lora"
]

max_loras = self.w1_lora_a_stacked.shape[0]
expert_ids_lora = expert_ids_lora.view(max_loras, -1)
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
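
Before the LoRA-aware MoE kernels run, the flat sorted-token and expert-id buffers are reshaped into one row per LoRA slot. A sketch of that reshape with made-up sizes (padded_tokens_per_lora and the block factor are illustrative, not vLLM's values):

import torch

max_loras = 4
padded_tokens_per_lora = 256

# Flat buffers as they might come out of a sorting/alignment kernel
# (values here are dummies).
sorted_token_ids_lora = torch.arange(max_loras * padded_tokens_per_lora)
expert_ids_lora = torch.zeros(
    max_loras * (padded_tokens_per_lora // 16), dtype=torch.int32
)

# One row per LoRA slot; -1 lets PyTorch infer the per-slot length.
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
expert_ids_lora = expert_ids_lora.view(max_loras, -1)

assert sorted_token_ids_lora.shape == (max_loras, padded_tokens_per_lora)
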
@@ -305,12 +303,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
    device=self.device,
)

self.base_layer.w1_lora_a_stacked = self.w1_lora_a_stacked
self.base_layer.w1_lora_b_stacked = self.w1_lora_b_stacked
self.base_layer.w2_lora_a_stacked = self.w2_lora_a_stacked
self.base_layer.w2_lora_b_stacked = self.w2_lora_b_stacked
self.base_layer.w3_lora_a_stacked = self.w3_lora_a_stacked
self.base_layer.w3_lora_b_stacked = self.w3_lora_b_stacked
# They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
# to create a dummy LoRA weights.
self.lora_a_stacked = []
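
Besides the per-projection attributes, the layer keeps flat lora_a_stacked / lora_b_stacked lists, which the comment says LoRALayerWeights.create_dummy_lora_weights uses to build placeholder adapters. A sketch of that bookkeeping; the shapes and the exact list ordering are assumptions:

import torch

max_loras, num_experts, rank, hidden, inter = 2, 4, 8, 32, 64

# Stacked LoRA tensors for the three MoE projections (shapes are illustrative).
w1_lora_a_stacked = torch.zeros(max_loras, num_experts, rank, hidden)
w2_lora_a_stacked = torch.zeros(max_loras, num_experts, rank, inter)
w3_lora_a_stacked = torch.zeros(max_loras, num_experts, rank, hidden)
w1_lora_b_stacked = torch.zeros(max_loras, num_experts, inter, rank)
w2_lora_b_stacked = torch.zeros(max_loras, num_experts, hidden, rank)
w3_lora_b_stacked = torch.zeros(max_loras, num_experts, inter, rank)

# Flat lists so generic helpers (e.g. dummy-weight creation) can iterate over
# "all A factors" / "all B factors" without knowing about w1/w2/w3.
lora_a_stacked = [w1_lora_a_stacked, w2_lora_a_stacked, w3_lora_a_stacked]
lora_b_stacked = [w1_lora_b_stacked, w2_lora_b_stacked, w3_lora_b_stacked]
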
@@ -343,6 +335,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
    embeddings_tensor: torch.Tensor | None,
    bias: torch.Tensor | None = None,
):
    self.reset_lora(index)
    """Overwrites lora tensors at index."""
    for eid in range(len(lora_a) // 3):
        w1_lora_a = lora_a[eid * 3]
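
set_lora receives the per-expert factors as flat lists in which each expert contributes a (w1, w2, w3) triplet, which is why the indices stride by 3. A sketch of that layout with placeholder tensors (shapes are not meaningful here):

import torch

num_experts = 4

# Flat per-expert lists laid out as [w1_e0, w2_e0, w3_e0, w1_e1, w2_e1, ...].
lora_a = [torch.randn(8, 8) for _ in range(num_experts * 3)]
lora_b = [torch.randn(8, 8) for _ in range(num_experts * 3)]

for eid in range(num_experts):
    # One (A, B) pair per projection of expert `eid`.
    w1_lora_a = lora_a[eid * 3]
    w2_lora_a = lora_a[eid * 3 + 1]
    w3_lora_a = lora_a[eid * 3 + 2]
    w1_lora_b = lora_b[eid * 3]
    w2_lora_b = lora_b[eid * 3 + 1]
    w3_lora_b = lora_b[eid * 3 + 2]
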
@@ -352,6 +345,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
w2_lora_b = lora_b[eid * 3 + 1]
w3_lora_b = lora_b[eid * 3 + 2]

# Handle the case of adding LoRA to only a subset of experts
if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
    continue

if self.tp_size > 1:
    shard_size = self.base_layer.intermediate_size_per_partition
    start_idx = self.tp_rank * shard_size
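
This is the edge case from the commit title: an adapter may target only a subset of experts, so the untouched experts carry None in place of their factors and are skipped, and under tensor parallelism each rank keeps only its shard of the intermediate dimension. A sketch of both steps; tp_size, tp_rank, the shapes, and the B-factor slicing are assumptions rather than vLLM's exact code:

import torch

num_experts, rank, hidden, intermediate = 4, 8, 32, 64
tp_size, tp_rank = 2, 0
shard_size = intermediate // tp_size  # intermediate_size_per_partition

# Adapter that only touches experts 0 and 2; the others stay None.
lora_a = [None] * (num_experts * 3)
lora_b = [None] * (num_experts * 3)
for eid in (0, 2):
    for j in range(3):
        lora_a[eid * 3 + j] = torch.randn(rank, hidden)
        lora_b[eid * 3 + j] = torch.randn(intermediate, rank)

for eid in range(num_experts):
    w1_lora_a = lora_a[eid * 3]
    w2_lora_a = lora_a[eid * 3 + 1]
    w3_lora_a = lora_a[eid * 3 + 2]
    # Skip experts this adapter does not cover.
    if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
        continue
    if tp_size > 1:
        start_idx = tp_rank * shard_size
        # Keep only this rank's slice of w1's B factor output dimension.
        w1_lora_b = lora_b[eid * 3][start_idx:start_idx + shard_size, :]
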
@@ -426,7 +426,6 @@ class LoRAModelManager:
for module_name, module in self.modules.items():
    module_lora = self._get_lora_layer_weights(lora_model, module_name)
    if module_lora:
        module_lora.optimize()
        # Note (gnovack) - If MOE lora weights are not split into
        # num_experts chunks, we split them here
        if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
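
The note in this hunk refers to checkpoints that store an MoE projection's LoRA weights as one fused tensor instead of num_experts chunks; such tensors are split per expert before being handed to the MoE layer. A minimal sketch of that split with torch.chunk; the fused layout is an assumption, not necessarily the format vLLM accepts:

import torch

num_experts, rank, hidden = 8, 16, 64

# A monolithic LoRA-A tensor covering all experts at once.
fused_lora_a = torch.randn(num_experts * rank, hidden)

# Split into one (rank, hidden) tensor per expert along dim 0.
per_expert_lora_a = list(torch.chunk(fused_lora_a, num_experts, dim=0))

assert len(per_expert_lora_a) == num_experts
assert per_expert_lora_a[0].shape == (rank, hidden)
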