From 1489902b531bb649f8110c94572b2d8b753a72cc Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Sat, 22 Nov 2025 12:01:30 +0800
Subject: [PATCH] [LoRA] Cleanup FusedMoEWithLoRA (#29187)

Signed-off-by: Jee Jee Li
---
 vllm/lora/layers/fused_moe.py           | 193 ++++++++++++------------
 vllm/lora/punica_wrapper/punica_base.py |   4 +-
 vllm/lora/punica_wrapper/punica_gpu.py  |   4 +-
 3 files changed, 98 insertions(+), 103 deletions(-)

diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py
index adf30855cafc..5aeaca8de5e5 100644
--- a/vllm/lora/layers/fused_moe.py
+++ b/vllm/lora/layers/fused_moe.py
@@ -42,6 +42,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
         self.device = base_layer.w2_weight.device
+        self.w13_slices = 2
         self._inject_lora_into_fused_moe()
 
     def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
@@ -60,8 +61,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
     def _get_lora_moe_configs(
         self,
         op_prefix: str,
-        lora_a_stacked: torch.Tensor,
-        lora_b_stacked: torch.Tensor,
+        num_loras: int,
+        rank: int,
         num_slices: int,
         M: int,
         layer: FusedMoE,
@@ -69,23 +70,25 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         config_dtype: str,
     ):
         if envs.VLLM_TUNED_CONFIG_FOLDER:
+            hidden_size = layer.hidden_size
+            intermediate_size = layer.intermediate_size_per_partition
             shrink_config = get_lora_op_configs(
                 op_type=f"fused_moe_lora_{op_prefix}_shrink",
-                max_loras=lora_a_stacked.shape[0],
+                max_loras=num_loras,
                 batch=M,
-                hidden_size=lora_a_stacked.shape[-1],
-                rank=lora_a_stacked.shape[-2],
+                hidden_size=hidden_size,
+                rank=rank,
                 num_slices=num_slices,
-                moe_intermediate_size=lora_b_stacked.shape[-2],
+                moe_intermediate_size=intermediate_size,
            )
             expand_config = get_lora_op_configs(
                 op_type=f"fused_moe_lora_{op_prefix}_expand",
-                max_loras=lora_a_stacked.shape[0],
+                max_loras=num_loras,
                 batch=M,
-                hidden_size=lora_a_stacked.shape[-1],
-                rank=lora_a_stacked.shape[-2],
+                hidden_size=hidden_size,  # lora_a_stacked.shape[-1],
+                rank=rank,
                 num_slices=num_slices,
-                moe_intermediate_size=lora_b_stacked.shape[-2],
+                moe_intermediate_size=intermediate_size,  # lora_b_stacked.shape[-2],
             )
         else:  # fall back to the default config
             get_config_func = functools.partial(
@@ -152,12 +155,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
                 num_tokens = hidden_states.size(0)
                 M = min(num_tokens, CHUNK_SIZE)
-
+                max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
                 shrink_config, expand_config = self._get_lora_moe_configs(
                     op_prefix="w13",
-                    lora_a_stacked=self.w1_lora_a_stacked,
-                    lora_b_stacked=self.w1_lora_b_stacked,
-                    num_slices=2,
+                    num_loras=self.max_loras,
+                    rank=max_lora_rank,
+                    num_slices=self.w13_slices,
                     M=M,
                     layer=layer,
                     top_k=top_k,
@@ -165,7 +168,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 )
 
                 # get the block size of m from customized config or default config
-                max_loras = self.w1_lora_a_stacked.shape[0]
                 (
                     sorted_token_ids_lora,
                     expert_ids_lora,
@@ -175,7 +177,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                     num_tokens,
                     shrink_config["BLOCK_SIZE_M"],
                     self.base_layer.local_num_experts,
-                    max_loras,
+                    self.max_loras,
                     self.adapter_enabled,
                     expert_map,
                 )
@@ -186,17 +188,15 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                     num_tokens_post_padded_lora
                 )
 
-                w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
-                w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]
-                max_lora_rank = self.w1_lora_a_stacked.shape[-2]
-                expert_ids_lora = expert_ids_lora.view(max_loras, -1)
-                sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
+                expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
+                sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
+                #
 
                 self.punica_wrapper.add_lora_fused_moe(
                     input.view(-1, top_k, input.shape[-1]),
                     hidden_states,
-                    w13_lora_a_stacked,
-                    w13_lora_b_stacked,
+                    self.w13_lora_a_stacked,
+                    self.w13_lora_b_stacked,
                     topk_weights,
                     sorted_token_ids_lora,
                     expert_ids_lora,
@@ -230,11 +230,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
                 num_tokens = hidden_states.size(0)
                 M = min(num_tokens, CHUNK_SIZE)
-
+                max_lora_rank = self.w2_lora_a_stacked.shape[-2]
                 shrink_config, expand_config = self._get_lora_moe_configs(
                     op_prefix="w2",
-                    lora_a_stacked=self.w2_lora_a_stacked,
-                    lora_b_stacked=self.w2_lora_b_stacked,
+                    num_loras=self.max_loras,
+                    rank=max_lora_rank,
                     num_slices=1,
                     M=M,
                     layer=layer,
@@ -247,20 +247,19 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 num_tokens_post_padded_lora = moe_state_dict[
                     "num_tokens_post_padded_lora"
                 ]
-                max_loras = self.w1_lora_a_stacked.shape[0]
-                expert_ids_lora = expert_ids_lora.view(max_loras, -1)
-                sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
+
+                expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
+                sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
 
                 intermediate_cache2 = moe_state_dict["intermediate_cache2"]
                 intermediate_cache3 = args[0]
-                max_lora_rank = self.w2_lora_a_stacked.shape[-2]
                 shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)
 
                 self.punica_wrapper.add_lora_fused_moe(
                     intermediate_cache3,
                     intermediate_cache2,
-                    [self.w2_lora_a_stacked],
-                    [self.w2_lora_b_stacked],
+                    (self.w2_lora_a_stacked,),
+                    (self.w2_lora_b_stacked,),
                     topk_weights,
                     sorted_token_ids_lora,
                     expert_ids_lora,
@@ -289,7 +288,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         fused_experts.moe_sum = moe_sum_decorator(
             self.base_layer, fused_experts.moe_sum
         )
-
         self.base_layer.quant_method = FusedMoEModularMethod(
             self.base_layer.quant_method, m_fused_moe_fn
         )
@@ -301,33 +299,42 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         model_config: PretrainedConfig | None = None,
     ) -> None:
         """Initializes lora matrices."""
+        assert self.w13_slices == 2
+        self.max_loras = lora_config.max_loras
         self.fully_sharded = lora_config.fully_sharded_loras
         self.adapter_enabled = torch.tensor(
             [0] * (max_loras + 1), dtype=torch.int, device=self.device
         )
 
-        self.w1_lora_a_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                lora_config.max_lora_rank
-                if not self.fully_sharded
-                else divide(lora_config.max_lora_rank, self.tp_size),
-                self.base_layer.hidden_size,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
+        self.w13_lora_a_stacked = tuple(
+            torch.zeros(
+                (
+                    max_loras,
+                    self.base_layer.local_num_experts,
+                    lora_config.max_lora_rank
+                    if not self.fully_sharded
+                    else divide(lora_config.max_lora_rank, self.tp_size),
+                    self.base_layer.hidden_size,
+                ),
+                dtype=lora_config.lora_dtype,
+                device=self.device,
+            )
+            for _ in range(self.w13_slices)
         )
 
-        self.w1_lora_b_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                self.base_layer.intermediate_size_per_partition,
-                lora_config.max_lora_rank,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
+
+        self.w13_lora_b_stacked = tuple(
+            torch.zeros(
+                (
+                    max_loras,
+                    self.base_layer.local_num_experts,
+                    self.base_layer.intermediate_size_per_partition,
+                    lora_config.max_lora_rank,
+                ),
+                dtype=lora_config.lora_dtype,
+                device=self.device,
+            )
+            for _ in range(self.w13_slices)
         )
 
         self.w2_lora_a_stacked = torch.zeros(
@@ -353,29 +360,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             device=self.device,
         )
 
-        self.w3_lora_a_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                lora_config.max_lora_rank
-                if not self.fully_sharded
-                else divide(lora_config.max_lora_rank, self.tp_size),
-                self.base_layer.hidden_size,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
-        self.w3_lora_b_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                self.base_layer.intermediate_size_per_partition,
-                lora_config.max_lora_rank,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
-
         # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
         # to create a dummy LoRA weights.
         self.lora_a_stacked = []
@@ -383,20 +367,28 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         for lora_id in range(max_loras):
             for experts_id in range(self.base_layer.local_num_experts):
                 # gate_proj,down_proj,up_proj
-                self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
+                self.lora_a_stacked.append(
+                    self.w13_lora_a_stacked[0][lora_id][experts_id]
+                )
                 self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
-                self.lora_a_stacked.append(self.w3_lora_a_stacked[lora_id][experts_id])
+                self.lora_a_stacked.append(
+                    self.w13_lora_a_stacked[1][lora_id][experts_id]
+                )
 
-                self.lora_b_stacked.append(self.w1_lora_b_stacked[lora_id][experts_id])
+                self.lora_b_stacked.append(
+                    self.w13_lora_b_stacked[0][lora_id][experts_id]
+                )
                 self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id])
-                self.lora_b_stacked.append(self.w3_lora_b_stacked[lora_id][experts_id])
+                self.lora_b_stacked.append(
+                    self.w13_lora_b_stacked[1][lora_id][experts_id]
+                )
 
     def reset_lora(self, index: int):
         """Resets the lora weights at index back to 0."""
-        self.w1_lora_a_stacked[index] = 0
-        self.w1_lora_b_stacked[index] = 0
-        self.w3_lora_a_stacked[index] = 0
-        self.w3_lora_b_stacked[index] = 0
+        for pos in range(self.w13_slices):
+            self.w13_lora_a_stacked[pos][index] = 0
+            self.w13_lora_b_stacked[pos][index] = 0
+
         self.w2_lora_a_stacked[index] = 0
         self.w2_lora_b_stacked[index] = 0
         self.adapter_enabled[index] = 0
@@ -434,7 +426,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             if self.fully_sharded:
                 # Based on S-LoRA, we slice W1 and W3 A along the rank dim,
                 # and W2 B along the hidden_size dim.
-                w13_shard_size = self.w1_lora_a_stacked[index, eid].shape[0]
+                w13_shard_size = self.w13_lora_a_stacked[0][index, eid].shape[0]
                 w13_start_idx = self.tp_rank * w13_shard_size
                 w13_end_idx = (self.tp_rank + 1) * w13_shard_size
                 w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
@@ -444,29 +436,32 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 w2_start_idx = self.tp_rank * w2_shard_size
                 w2_end_idx = (self.tp_rank + 1) * w2_shard_size
                 w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
-
-            self.w1_lora_a_stacked[
+            # w1 lora_a
+            self.w13_lora_a_stacked[0][
                 index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
             ].copy_(w1_lora_a, non_blocking=True)
-
-            self.w3_lora_a_stacked[
+            # w3 lora_a
+            self.w13_lora_a_stacked[1][
                 index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
             ].copy_(w3_lora_a, non_blocking=True)
+            # w1 lora_b
+            self.w13_lora_b_stacked[0][
+                index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
+            ].copy_(w1_lora_b, non_blocking=True)
+            # w3 lora_b
+            self.w13_lora_b_stacked[1][
+                index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
+            ].copy_(w3_lora_b, non_blocking=True)
+
+            self.w2_lora_a_stacked[
+                index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
+            ].copy_(w2_lora_a, non_blocking=True)
+
 
             self.w2_lora_b_stacked[
                 index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
             ].copy_(w2_lora_b, non_blocking=True)
 
-            self.w1_lora_b_stacked[
-                index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
-            ].copy_(w1_lora_b, non_blocking=True)
-            self.w3_lora_b_stacked[
-                index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
-            ].copy_(w3_lora_b, non_blocking=True)
-            self.w2_lora_a_stacked[
-                index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
-            ].copy_(w2_lora_a, non_blocking=True)
-
     @classmethod
     def can_replace_layer(
         cls,
diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py
index 7c0fc8167711..ce38751e4b6a 100644
--- a/vllm/lora/punica_wrapper/punica_base.py
+++ b/vllm/lora/punica_wrapper/punica_base.py
@@ -470,8 +470,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
         self,
         y: torch.Tensor,
         x: torch.Tensor,
-        lora_a_stacked: list[torch.Tensor],
-        lora_b_stacked: list[torch.Tensor],
+        lora_a_stacked: tuple[torch.Tensor, ...],
+        lora_b_stacked: tuple[torch.Tensor, ...],
         topk_weights: torch.Tensor,
         sorted_token_ids: torch.Tensor,
         expert_ids: torch.Tensor,
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py
index 52138ef0cc3b..ef4b4ab7c349 100644
--- a/vllm/lora/punica_wrapper/punica_gpu.py
+++ b/vllm/lora/punica_wrapper/punica_gpu.py
@@ -360,8 +360,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         self,
         y: torch.Tensor,
         x: torch.Tensor,
-        lora_a_stacked: list[torch.Tensor],
-        lora_b_stacked: list[torch.Tensor],
+        lora_a_stacked: tuple[torch.Tensor, ...],
+        lora_b_stacked: tuple[torch.Tensor, ...],
         topk_weights: torch.Tensor,
         sorted_token_ids: torch.Tensor,
         expert_ids: torch.Tensor,
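For quick reference, the patch above collapses the separate w1/w3 stacked LoRA buffers into two-element tuples indexed by slice (0 for w1/gate_proj, 1 for w3/up_proj) and passes them directly to add_lora_fused_moe, whose parameters now take tuple[torch.Tensor, ...] instead of list[torch.Tensor]. The sketch below is not part of the patch; shapes and variable names are illustrative assumptions only, showing how the tuple layout is allocated and how a per-adapter reset walks every slice, mirroring create_lora_weights and reset_lora above.

import torch

# Illustrative sizes only; real values come from LoRAConfig and the FusedMoE layer.
max_loras, num_experts, rank, hidden_size, intermediate_size = 2, 4, 8, 32, 64
w13_slices = 2  # slice 0 -> w1 (gate_proj), slice 1 -> w3 (up_proj)

# One stacked LoRA-A / LoRA-B buffer per w13 slice, mirroring create_lora_weights.
w13_lora_a_stacked = tuple(
    torch.randn(max_loras, num_experts, rank, hidden_size)
    for _ in range(w13_slices)
)
w13_lora_b_stacked = tuple(
    torch.randn(max_loras, num_experts, intermediate_size, rank)
    for _ in range(w13_slices)
)

# Resetting one adapter slot touches every slice, mirroring the new reset_lora loop.
index = 1
for pos in range(w13_slices):
    w13_lora_a_stacked[pos][index] = 0
    w13_lora_b_stacked[pos][index] = 0

assert all(
    t[index].abs().sum() == 0
    for t in (*w13_lora_a_stacked, *w13_lora_b_stacked)
)

Using tuples rather than lists keeps the stacked buffers immutable in shape and matches the updated add_lora_fused_moe signature, so the per-slice pairing with w2 stays explicit at the call sites.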