diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py
index 24cab79a72443..a114f2fa72e36 100644
--- a/vllm/lora/layers/fused_moe.py
+++ b/vllm/lora/layers/fused_moe.py
@@ -672,20 +672,9 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
         self.reset_lora(index)
         self.adapter_enabled[index] = 1

-        num_experts = self.w13_lora_a_stacked[0].shape[1]
         w13_lora_a, w2_lora_a = lora_a
         w13_lora_b, w2_lora_b = lora_b

-        # (num_experts,rank,input_size)
-        w13_lora_a = w13_lora_a.reshape(num_experts, -1, w13_lora_a.shape[-1])
-        w2_lora_a = w2_lora_a.reshape(num_experts, -1, w2_lora_a.shape[-1])
-        # (output_size,num_experts,rank)
-        w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], num_experts, -1)
-        w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], num_experts, -1)
-        # (num_experts,output_size,rank)
-        w13_lora_b = w13_lora_b.permute(1, 0, 2)
-        w2_lora_b = w2_lora_b.permute(1, 0, 2)
-
         sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
         sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)

diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py
index 44e0448d92de0..31688e7054daf 100644
--- a/vllm/lora/model_manager.py
+++ b/vllm/lora/model_manager.py
@@ -171,6 +171,37 @@ class LoRAModelManager:
             assert gate_up_proj_lora is not None
             assert down_proj_lora is not None
             if self._is_3d_moe_model:
+                num_experts = module.w13_lora_a_stacked[0].shape[1]
+                lora_weight_device = gate_up_proj_lora.lora_b.device
+
+                # (num_experts,rank,input_size)
+                gate_up_proj_lora.lora_a = gate_up_proj_lora.lora_a.reshape(
+                    num_experts, -1, gate_up_proj_lora.lora_a.shape[-1]
+                )
+                down_proj_lora.lora_a = down_proj_lora.lora_a.reshape(
+                    num_experts, -1, down_proj_lora.lora_a.shape[-1]
+                )
+
+                # (output_size,num_experts,rank)
+                gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.reshape(
+                    gate_up_proj_lora.lora_b.shape[0], num_experts, -1
+                )
+                down_proj_lora.lora_b = down_proj_lora.lora_b.reshape(
+                    down_proj_lora.lora_b.shape[0], num_experts, -1
+                )
+
+                # (num_experts,output_size,rank)
+                gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.permute(
+                    1, 0, 2
+                ).contiguous()
+                down_proj_lora.lora_b = down_proj_lora.lora_b.permute(
+                    1, 0, 2
+                ).contiguous()
+
+                if str(lora_weight_device) == "cpu" and is_pin_memory_available():
+                    gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.pin_memory()
+                    down_proj_lora.lora_b = down_proj_lora.lora_b.pin_memory()
+
                 module_lora.lora_a = [
                     gate_up_proj_lora.lora_a,
                     down_proj_lora.lora_a,
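
For context, a minimal standalone sketch (not part of the patch) of the layout change that moves from `FusedMoE3DWithLoRA.set_lora` into `LoRAModelManager`: the flat LoRA A weight is reshaped so the expert dimension leads, and the flat LoRA B weight is reshaped and then permuted to (num_experts, output_size, rank). The dimension values below are placeholders chosen only for illustration.

```python
import torch

# Assumed example dimensions, not taken from any real checkpoint.
num_experts, rank, input_size, output_size = 8, 16, 256, 512

# Flat per-adapter LoRA weights as they arrive before the transformation.
lora_a = torch.randn(num_experts * rank, input_size)
lora_b = torch.randn(output_size, num_experts * rank)

# (num_experts, rank, input_size)
lora_a = lora_a.reshape(num_experts, -1, lora_a.shape[-1])

# (output_size, num_experts, rank) -> (num_experts, output_size, rank)
lora_b = lora_b.reshape(lora_b.shape[0], num_experts, -1)
lora_b = lora_b.permute(1, 0, 2).contiguous()

assert lora_a.shape == (num_experts, rank, input_size)
assert lora_b.shape == (num_experts, output_size, rank)
```

On top of the relocation, the patch pins the permuted B tensors when they still live on CPU and pinned memory is available, which makes the later host-to-device copies cheaper.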