[LoRA] Cleanup FusedMoEWithLoRA (#29187)

Author: Jee Jee Li <pandaleefree@gmail.com> (committed by GitHub)
Date:   2025-11-22 12:01:30 +08:00
Commit: 1489902b53 (parent 933f67ecd8)
GPG Key ID: B5690EEEBB952194

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

3 changed files with 98 additions and 103 deletions
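In short, this cleanup consolidates the separate w1_*/w3_* stacked LoRA buffers of FusedMoEWithLoRA into two-element tuples, w13_lora_a_stacked and w13_lora_b_stacked (slice 0 = w1/gate_proj, slice 1 = w3/up_proj), caches max_loras and the slice count on the layer, has _get_lora_moe_configs take plain num_loras/rank integers instead of the stacked tensors, and tightens the punica wrappers' type hints from list to tuple. Below is a minimal, self-contained sketch of the new buffer layout; the attribute names follow the diff, while the concrete sizes are illustrative only.

# Sketch of the consolidated buffer layout (names follow the diff; the sizes
# below are illustrative, not the model's real dimensions).
import torch

max_loras, local_num_experts = 4, 8
max_lora_rank, hidden_size, intermediate_size = 16, 1024, 2048
w13_slices = 2  # slice 0 -> w1/gate_proj, slice 1 -> w3/up_proj

w13_lora_a_stacked = tuple(
    torch.zeros(max_loras, local_num_experts, max_lora_rank, hidden_size)
    for _ in range(w13_slices)
)
w13_lora_b_stacked = tuple(
    torch.zeros(max_loras, local_num_experts, intermediate_size, max_lora_rank)
    for _ in range(w13_slices)
)

# reset_lora(index) now loops over the slices instead of touching separate
# w1_* and w3_* attributes:
index = 0
for pos in range(w13_slices):
    w13_lora_a_stacked[pos][index] = 0
    w13_lora_b_stacked[pos][index] = 0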


@@ -42,6 +42,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
         self.device = base_layer.w2_weight.device
+        self.w13_slices = 2
         self._inject_lora_into_fused_moe()
 
     def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
@@ -60,8 +61,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
     def _get_lora_moe_configs(
         self,
         op_prefix: str,
-        lora_a_stacked: torch.Tensor,
-        lora_b_stacked: torch.Tensor,
+        num_loras: int,
+        rank: int,
         num_slices: int,
         M: int,
         layer: FusedMoE,
@@ -69,23 +70,25 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         config_dtype: str,
     ):
         if envs.VLLM_TUNED_CONFIG_FOLDER:
+            hidden_size = layer.hidden_size
+            intermediate_size = layer.intermediate_size_per_partition
             shrink_config = get_lora_op_configs(
                 op_type=f"fused_moe_lora_{op_prefix}_shrink",
-                max_loras=lora_a_stacked.shape[0],
+                max_loras=num_loras,
                 batch=M,
-                hidden_size=lora_a_stacked.shape[-1],
-                rank=lora_a_stacked.shape[-2],
+                hidden_size=hidden_size,
+                rank=rank,
                 num_slices=num_slices,
-                moe_intermediate_size=lora_b_stacked.shape[-2],
+                moe_intermediate_size=intermediate_size,
             )
             expand_config = get_lora_op_configs(
                 op_type=f"fused_moe_lora_{op_prefix}_expand",
-                max_loras=lora_a_stacked.shape[0],
+                max_loras=num_loras,
                 batch=M,
-                hidden_size=lora_a_stacked.shape[-1],
-                rank=lora_a_stacked.shape[-2],
+                hidden_size=hidden_size,  # lora_a_stacked.shape[-1],
+                rank=rank,
                 num_slices=num_slices,
-                moe_intermediate_size=lora_b_stacked.shape[-2],
+                moe_intermediate_size=intermediate_size,  # lora_b_stacked.shape[-2],
             )
         else:  # fall back to the default config
             get_config_func = functools.partial(
@@ -152,12 +155,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
             num_tokens = hidden_states.size(0)
             M = min(num_tokens, CHUNK_SIZE)
+            max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
             shrink_config, expand_config = self._get_lora_moe_configs(
                 op_prefix="w13",
-                lora_a_stacked=self.w1_lora_a_stacked,
-                lora_b_stacked=self.w1_lora_b_stacked,
-                num_slices=2,
+                num_loras=self.max_loras,
+                rank=max_lora_rank,
+                num_slices=self.w13_slices,
                 M=M,
                 layer=layer,
                 top_k=top_k,
@@ -165,7 +168,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             )
             # get the block size of m from customized config or default config
-            max_loras = self.w1_lora_a_stacked.shape[0]
             (
                 sorted_token_ids_lora,
                 expert_ids_lora,
@@ -175,7 +177,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 num_tokens,
                 shrink_config["BLOCK_SIZE_M"],
                 self.base_layer.local_num_experts,
-                max_loras,
+                self.max_loras,
                 self.adapter_enabled,
                 expert_map,
             )
@@ -186,17 +188,15 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 num_tokens_post_padded_lora
             )
-            w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
-            w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]
-            max_lora_rank = self.w1_lora_a_stacked.shape[-2]
-            expert_ids_lora = expert_ids_lora.view(max_loras, -1)
-            sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
+            expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
+            sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
             self.punica_wrapper.add_lora_fused_moe(
                 input.view(-1, top_k, input.shape[-1]),
                 hidden_states,
-                w13_lora_a_stacked,
-                w13_lora_b_stacked,
+                self.w13_lora_a_stacked,
+                self.w13_lora_b_stacked,
                 topk_weights,
                 sorted_token_ids_lora,
                 expert_ids_lora,
@@ -230,11 +230,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
             num_tokens = hidden_states.size(0)
             M = min(num_tokens, CHUNK_SIZE)
+            max_lora_rank = self.w2_lora_a_stacked.shape[-2]
             shrink_config, expand_config = self._get_lora_moe_configs(
                 op_prefix="w2",
-                lora_a_stacked=self.w2_lora_a_stacked,
-                lora_b_stacked=self.w2_lora_b_stacked,
+                num_loras=self.max_loras,
+                rank=max_lora_rank,
                 num_slices=1,
                 M=M,
                 layer=layer,
@@ -247,20 +247,19 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             num_tokens_post_padded_lora = moe_state_dict[
                 "num_tokens_post_padded_lora"
             ]
-            max_loras = self.w1_lora_a_stacked.shape[0]
-            expert_ids_lora = expert_ids_lora.view(max_loras, -1)
-            sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
+            expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
+            sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
             intermediate_cache2 = moe_state_dict["intermediate_cache2"]
             intermediate_cache3 = args[0]
-            max_lora_rank = self.w2_lora_a_stacked.shape[-2]
             shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)
             self.punica_wrapper.add_lora_fused_moe(
                 intermediate_cache3,
                 intermediate_cache2,
-                [self.w2_lora_a_stacked],
-                [self.w2_lora_b_stacked],
+                (self.w2_lora_a_stacked,),
+                (self.w2_lora_b_stacked,),
                 topk_weights,
                 sorted_token_ids_lora,
                 expert_ids_lora,
@@ -289,7 +288,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             fused_experts.moe_sum = moe_sum_decorator(
                 self.base_layer, fused_experts.moe_sum
             )
-
         self.base_layer.quant_method = FusedMoEModularMethod(
             self.base_layer.quant_method, m_fused_moe_fn
         )
@@ -301,13 +299,16 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         model_config: PretrainedConfig | None = None,
     ) -> None:
         """Initializes lora matrices."""
+        assert self.w13_slices == 2
+        self.max_loras = lora_config.max_loras
         self.fully_sharded = lora_config.fully_sharded_loras
         self.adapter_enabled = torch.tensor(
             [0] * (max_loras + 1), dtype=torch.int, device=self.device
         )
-        self.w1_lora_a_stacked = torch.zeros(
+        self.w13_lora_a_stacked = tuple(
+            torch.zeros(
                 (
                     max_loras,
                     self.base_layer.local_num_experts,
@@ -319,7 +320,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 dtype=lora_config.lora_dtype,
                 device=self.device,
             )
-        self.w1_lora_b_stacked = torch.zeros(
+            for _ in range(self.w13_slices)
+        )
+        self.w13_lora_b_stacked = tuple(
+            torch.zeros(
                 (
                     max_loras,
                     self.base_layer.local_num_experts,
@@ -329,6 +334,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 dtype=lora_config.lora_dtype,
                 device=self.device,
             )
+            for _ in range(self.w13_slices)
+        )
 
         self.w2_lora_a_stacked = torch.zeros(
             (
@@ -353,29 +360,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             device=self.device,
         )
 
-        self.w3_lora_a_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                lora_config.max_lora_rank
-                if not self.fully_sharded
-                else divide(lora_config.max_lora_rank, self.tp_size),
-                self.base_layer.hidden_size,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
-        self.w3_lora_b_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                self.base_layer.intermediate_size_per_partition,
-                lora_config.max_lora_rank,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
         # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
         # to create a dummy LoRA weights.
         self.lora_a_stacked = []
@@ -383,20 +367,28 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         for lora_id in range(max_loras):
             for experts_id in range(self.base_layer.local_num_experts):
                 # gate_proj,down_proj,up_proj
-                self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
+                self.lora_a_stacked.append(
+                    self.w13_lora_a_stacked[0][lora_id][experts_id]
+                )
                 self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
-                self.lora_a_stacked.append(self.w3_lora_a_stacked[lora_id][experts_id])
-                self.lora_b_stacked.append(self.w1_lora_b_stacked[lora_id][experts_id])
+                self.lora_a_stacked.append(
+                    self.w13_lora_a_stacked[1][lora_id][experts_id]
+                )
+                self.lora_b_stacked.append(
+                    self.w13_lora_b_stacked[0][lora_id][experts_id]
+                )
                 self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id])
-                self.lora_b_stacked.append(self.w3_lora_b_stacked[lora_id][experts_id])
+                self.lora_b_stacked.append(
+                    self.w13_lora_b_stacked[1][lora_id][experts_id]
+                )
 
     def reset_lora(self, index: int):
         """Resets the lora weights at index back to 0."""
-        self.w1_lora_a_stacked[index] = 0
-        self.w1_lora_b_stacked[index] = 0
-        self.w3_lora_a_stacked[index] = 0
-        self.w3_lora_b_stacked[index] = 0
+        for pos in range(self.w13_slices):
+            self.w13_lora_a_stacked[pos][index] = 0
+            self.w13_lora_b_stacked[pos][index] = 0
         self.w2_lora_a_stacked[index] = 0
         self.w2_lora_b_stacked[index] = 0
         self.adapter_enabled[index] = 0
@@ -434,7 +426,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         if self.fully_sharded:
             # Based on S-LoRA, we slice W1 and W3 A along the rank dim,
            # and W2 B along the hidden_size dim.
-            w13_shard_size = self.w1_lora_a_stacked[index, eid].shape[0]
+            w13_shard_size = self.w13_lora_a_stacked[0][index, eid].shape[0]
             w13_start_idx = self.tp_rank * w13_shard_size
             w13_end_idx = (self.tp_rank + 1) * w13_shard_size
             w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
@@ -444,29 +436,32 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             w2_start_idx = self.tp_rank * w2_shard_size
             w2_end_idx = (self.tp_rank + 1) * w2_shard_size
             w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
-        self.w1_lora_a_stacked[
+        # w1 lora_a
+        self.w13_lora_a_stacked[0][
             index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
         ].copy_(w1_lora_a, non_blocking=True)
-        self.w3_lora_a_stacked[
+        # w3 lora_a
+        self.w13_lora_a_stacked[1][
             index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
         ].copy_(w3_lora_a, non_blocking=True)
+        # w1 lora_b
+        self.w13_lora_b_stacked[0][
+            index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
+        ].copy_(w1_lora_b, non_blocking=True)
+        # w3 lora_b
+        self.w13_lora_b_stacked[1][
+            index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
+        ].copy_(w3_lora_b, non_blocking=True)
+        self.w2_lora_a_stacked[
+            index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
+        ].copy_(w2_lora_a, non_blocking=True)
         self.w2_lora_b_stacked[
             index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
         ].copy_(w2_lora_b, non_blocking=True)
-        self.w1_lora_b_stacked[
-            index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
-        ].copy_(w1_lora_b, non_blocking=True)
-        self.w3_lora_b_stacked[
-            index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
-        ].copy_(w3_lora_b, non_blocking=True)
-        self.w2_lora_a_stacked[
-            index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
-        ].copy_(w2_lora_a, non_blocking=True)
 
     @classmethod
     def can_replace_layer(
         cls,
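For reference, the set_lora copy path above reduces to slicing each adapter's A matrices along the rank dimension (per the S-LoRA comment) and writing w1 into slice 0 and w3 into slice 1 of the tuple. A standalone sketch follows; the attribute names and slicing steps mirror the diff, while the shapes, tp values, and random adapter weights are assumptions.

# Standalone sketch of the fully-sharded w13 copy path; names mirror the diff,
# while tp_rank/tp_size, shapes and the random adapter weights are made up.
import torch

tp_rank, tp_size = 0, 2
max_loras, local_num_experts = 2, 4
max_lora_rank, hidden_size = 16, 64
index, eid = 0, 0  # adapter slot and local expert id

# Per-slice A buffers; the rank dim is already divided across tensor-parallel
# ranks in the fully-sharded case.
w13_lora_a_stacked = tuple(
    torch.zeros(max_loras, local_num_experts, max_lora_rank // tp_size, hidden_size)
    for _ in range(2)
)

# Incoming adapter weights for one expert (w1 = gate_proj, w3 = up_proj).
w1_lora_a = torch.randn(max_lora_rank, hidden_size)
w3_lora_a = torch.randn(max_lora_rank, hidden_size)

# Fully sharded: take this rank's shard of A along the rank dimension.
w13_shard_size = w13_lora_a_stacked[0][index, eid].shape[0]
start, end = tp_rank * w13_shard_size, (tp_rank + 1) * w13_shard_size
w1_lora_a = w1_lora_a[start:end, :]
w3_lora_a = w3_lora_a[start:end, :]

# Slice 0 holds w1, slice 1 holds w3.
w13_lora_a_stacked[0][index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]].copy_(w1_lora_a)
w13_lora_a_stacked[1][index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]].copy_(w3_lora_a)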


@@ -470,8 +470,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
         self,
         y: torch.Tensor,
         x: torch.Tensor,
-        lora_a_stacked: list[torch.Tensor],
-        lora_b_stacked: list[torch.Tensor],
+        lora_a_stacked: tuple[torch.Tensor, ...],
+        lora_b_stacked: tuple[torch.Tensor, ...],
         topk_weights: torch.Tensor,
         sorted_token_ids: torch.Tensor,
         expert_ids: torch.Tensor,


@@ -360,8 +360,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         self,
         y: torch.Tensor,
         x: torch.Tensor,
-        lora_a_stacked: list[torch.Tensor],
-        lora_b_stacked: list[torch.Tensor],
+        lora_a_stacked: tuple[torch.Tensor, ...],
+        lora_b_stacked: tuple[torch.Tensor, ...],
        topk_weights: torch.Tensor,
         sorted_token_ids: torch.Tensor,
         expert_ids: torch.Tensor,
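The last two hunks only tighten the add_lora_fused_moe type hints from list to tuple in PunicaWrapperBase and PunicaWrapperGPU, matching the tuple-valued attributes above. A hedged illustration of the intent follows; the function below is a stub with the same parameter list, not the real wrapper implementation.

# Stub mirroring the diff's add_lora_fused_moe signature; the body is
# illustrative only and does not implement the fused MoE LoRA kernels.
import torch


def add_lora_fused_moe(
    y: torch.Tensor,
    x: torch.Tensor,
    lora_a_stacked: tuple[torch.Tensor, ...],
    lora_b_stacked: tuple[torch.Tensor, ...],
    topk_weights: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
) -> None:
    # A fixed-length tuple makes the two call sites explicit: the w13 path
    # passes 2-tuples per side, the w2 path passes 1-tuples.
    assert len(lora_a_stacked) == len(lora_b_stacked)


# w2-style call site after the cleanup: wrap the single tensors in 1-tuples.
t = torch.zeros(1)
add_lora_fused_moe(t, t, (t,), (t,), t, t, t)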