mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-04 12:27:05 +08:00
pin lora_b moe weights on cpu
Signed-off-by: gnovack <gnovack@amazon.com>
This commit is contained in:
parent
c016c95b45
commit
627a534f91
@ -672,20 +672,9 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
|
||||
self.reset_lora(index)
|
||||
self.adapter_enabled[index] = 1
|
||||
|
||||
num_experts = self.w13_lora_a_stacked[0].shape[1]
|
||||
w13_lora_a, w2_lora_a = lora_a
|
||||
w13_lora_b, w2_lora_b = lora_b
|
||||
|
||||
# (num_experts,rank,input_size)
|
||||
w13_lora_a = w13_lora_a.reshape(num_experts, -1, w13_lora_a.shape[-1])
|
||||
w2_lora_a = w2_lora_a.reshape(num_experts, -1, w2_lora_a.shape[-1])
|
||||
# (output_size,num_experts,rank)
|
||||
w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], num_experts, -1)
|
||||
w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], num_experts, -1)
|
||||
# (num_experts,output_size,rank)
|
||||
w13_lora_b = w13_lora_b.permute(1, 0, 2)
|
||||
w2_lora_b = w2_lora_b.permute(1, 0, 2)
|
||||
|
||||
sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
|
||||
sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)
|
||||
|
||||
|
||||
@ -171,6 +171,37 @@ class LoRAModelManager:
|
||||
assert gate_up_proj_lora is not None
|
||||
assert down_proj_lora is not None
|
||||
if self._is_3d_moe_model:
|
||||
num_experts = module.w13_lora_a_stacked[0].shape[1]
|
||||
lora_weight_device = gate_up_proj_lora.lora_b.device
|
||||
|
||||
# (num_experts,rank,input_size)
|
||||
gate_up_proj_lora.lora_a = gate_up_proj_lora.lora_a.reshape(
|
||||
num_experts, -1, gate_up_proj_lora.lora_a.shape[-1]
|
||||
)
|
||||
down_proj_lora.lora_a = down_proj_lora.lora_a.reshape(
|
||||
num_experts, -1, down_proj_lora.lora_a.shape[-1]
|
||||
)
|
||||
|
||||
# (output_size,num_experts,rank)
|
||||
gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.reshape(
|
||||
gate_up_proj_lora.lora_b.shape[0], num_experts, -1
|
||||
)
|
||||
down_proj_lora.lora_b = down_proj_lora.lora_b.reshape(
|
||||
down_proj_lora.lora_b.shape[0], num_experts, -1
|
||||
)
|
||||
|
||||
# (num_experts,output_size,rank)
|
||||
gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.permute(
|
||||
1, 0, 2
|
||||
).contiguous()
|
||||
down_proj_lora.lora_b = down_proj_lora.lora_b.permute(
|
||||
1, 0, 2
|
||||
).contiguous()
|
||||
|
||||
if str(lora_weight_device) == "cpu" and is_pin_memory_available():
|
||||
gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.pin_memory()
|
||||
down_proj_lora.lora_b = down_proj_lora.lora_b.pin_memory()
|
||||
|
||||
module_lora.lora_a = [
|
||||
gate_up_proj_lora.lora_a,
|
||||
down_proj_lora.lora_a,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user