[Core] Handle MoE LoRA edge cases (#27335)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li, 2025-10-22 21:14:33 +08:00 (committed by GitHub)
parent 8e4ca4d14e
commit abf3db40ef
2 changed files with 7 additions and 11 deletions

File 1 of 2

@@ -74,7 +74,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
moe_state_dict["apply_router_weight_on_input"] = kwargs[
"apply_router_weight_on_input"
]
moe_state_dict["max_loras"] = layer.w1_lora_a_stacked.shape[0]
result = func(*args, **kwargs)
return result
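
For orientation, a minimal, hypothetical sketch of the interception pattern these hunks modify: a wrapper closure stashes selected kwargs in a shared moe_state_dict before delegating to the wrapped function, so later stages of the fused-MoE path can read them back. The function name and signature below are illustrative, not vLLM's actual API.

import functools

moe_state_dict: dict = {}

def capture_moe_state(func):
    # Record the arguments that later MoE stages need, then delegate.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        moe_state_dict["apply_router_weight_on_input"] = kwargs[
            "apply_router_weight_on_input"
        ]
        result = func(*args, **kwargs)
        return result
    return wrapper

@capture_moe_state
def select_experts(hidden_states, *, apply_router_weight_on_input=False):
    return hidden_states  # stand-in for the real routing logic

select_experts("tokens", apply_router_weight_on_input=True)  # the wrapper records the flag
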
@@ -89,7 +88,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
curr_topk_ids = moe_state_dict["topk_ids"]
global_num_experts = moe_state_dict["global_num_experts"]
expert_map = moe_state_dict["expert_map"]
- max_loras = moe_state_dict["max_loras"]
config_dtype = _get_config_dtype_str(
dtype=hidden_states.dtype,
@@ -110,6 +108,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
block_shape=layer.quant_method.moe_quant_config.block_shape,
)
+ max_loras = self.w1_lora_a_stacked.shape[0]
config = get_config_func(M)
(
sorted_token_ids_lora,
@@ -161,7 +160,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
def wrapper(*args, **kwargs):
hidden_states = moe_state_dict["hidden_states"]
topk_weights = moe_state_dict["topk_weights"]
- max_loras = moe_state_dict["max_loras"]
config_dtype = _get_config_dtype_str(
dtype=hidden_states.dtype,
@@ -189,7 +187,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
num_tokens_post_padded_lora = moe_state_dict[
"num_tokens_post_padded_lora"
]
+ max_loras = self.w1_lora_a_stacked.shape[0]
expert_ids_lora = expert_ids_lora.view(max_loras, -1)
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
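
To make the intent of these max_loras changes concrete, here is a small sketch, assuming the stacked LoRA buffer keeps the adapter-slot dimension first (the exact layout is not shown in this diff): the wrappers now read the value directly from self.w1_lora_a_stacked instead of threading it through moe_state_dict, which also makes the base_layer mirroring removed further down unnecessary.

import torch

# Hypothetical stacked LoRA-A buffer for the w1 projection.
# Assumed layout: (max_loras, num_experts, max_rank, hidden_size).
w1_lora_a_stacked = torch.zeros(4, 8, 16, 256)

# Before: the value was stored once and passed around via a shared dict.
moe_state_dict = {"max_loras": w1_lora_a_stacked.shape[0]}

# After: each wrapper derives it locally from the buffer it already holds.
max_loras = w1_lora_a_stacked.shape[0]
assert max_loras == moe_state_dict["max_loras"] == 4
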
@@ -305,12 +303,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
device=self.device,
)
- self.base_layer.w1_lora_a_stacked = self.w1_lora_a_stacked
- self.base_layer.w1_lora_b_stacked = self.w1_lora_b_stacked
- self.base_layer.w2_lora_a_stacked = self.w2_lora_a_stacked
- self.base_layer.w2_lora_b_stacked = self.w2_lora_b_stacked
- self.base_layer.w3_lora_a_stacked = self.w3_lora_a_stacked
- self.base_layer.w3_lora_b_stacked = self.w3_lora_b_stacked
# They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
# to create dummy LoRA weights.
self.lora_a_stacked = []
@@ -343,6 +335,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
embeddings_tensor: torch.Tensor | None,
bias: torch.Tensor | None = None,
):
+ self.reset_lora(index)
"""Overwrites lora tensors at index."""
for eid in range(len(lora_a) // 3):
w1_lora_a = lora_a[eid * 3]
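
The loop above assumes the adapter arrives as flat lists with three tensors per expert (w1, w2, w3 in that order). A tiny sketch of that indexing, using placeholder strings instead of real tensors:

# Flat list layout assumed by the loop: [e0.w1, e0.w2, e0.w3, e1.w1, ...]
lora_a = [f"expert{e}.{proj}.lora_a" for e in range(4) for proj in ("w1", "w2", "w3")]

for eid in range(len(lora_a) // 3):
    w1_lora_a = lora_a[eid * 3]
    w2_lora_a = lora_a[eid * 3 + 1]
    w3_lora_a = lora_a[eid * 3 + 2]
    print(eid, w1_lora_a, w2_lora_a, w3_lora_a)
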
@@ -352,6 +345,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
w2_lora_b = lora_b[eid * 3 + 1]
w3_lora_b = lora_b[eid * 3 + 2]
+ # Handle the case of adding LoRA to only a subset of experts
+ if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
+ continue
if self.tp_size > 1:
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
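
The two additions in this method work together on the edge case named in the comment: reset_lora(index) clears the slot first, and experts whose LoRA tensors are None are simply skipped, so they end up with zeros rather than weights left over from whatever adapter previously occupied the slot. A self-contained sketch under an assumed (max_loras, num_experts, rank, hidden) layout:

import torch

max_loras, num_experts, rank, hidden = 2, 4, 8, 32
w1_lora_a_stacked = torch.randn(max_loras, num_experts, rank, hidden)

def reset_lora(index: int) -> None:
    # Zero every expert's weights in this adapter slot.
    w1_lora_a_stacked[index] = 0

def set_lora(index: int, w1_lora_a_per_expert: list) -> None:
    reset_lora(index)
    for eid, w in enumerate(w1_lora_a_per_expert):
        if w is None:  # adapter only covers a subset of experts
            continue
        w1_lora_a_stacked[index, eid].copy_(w)

# Adapter that only targets experts 0 and 2.
set_lora(0, [torch.randn(rank, hidden), None, torch.randn(rank, hidden), None])
assert torch.all(w1_lora_a_stacked[0, 1] == 0)
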

File 2 of 2

@@ -426,7 +426,6 @@ class LoRAModelManager:
for module_name, module in self.modules.items():
module_lora = self._get_lora_layer_weights(lora_model, module_name)
if module_lora:
- module_lora.optimize()
# Note (gnovack) - If MOE lora weights are not split into
# num_experts chunks, we split them here
if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
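
For this second file: the comment explains that MoE LoRA weights may arrive as a single fused tensor rather than a per-expert list, in which case the manager splits them into num_experts chunks. The splitting code itself is cut off in this hunk; the sketch below only illustrates the general idea with torch.chunk and assumed shapes, not the manager's actual logic.

import torch

num_experts, rank, hidden = 4, 8, 32

# Hypothetical fused adapter tensor: all experts' LoRA-A weights stacked along dim 0.
fused_lora_a = torch.randn(num_experts * rank, hidden)

if torch.is_tensor(fused_lora_a):  # not already a per-expert list
    per_expert_lora_a = list(torch.chunk(fused_lora_a, num_experts, dim=0))

assert len(per_expert_lora_a) == num_experts
assert per_expert_lora_a[0].shape == (rank, hidden)
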