[LoRA] Optimize 3D MoE logic (#29222)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
c309bb5245
commit
1073ba68b0
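The diff below teaches vLLM's LoRA stack to handle "3D" MoE adapters, i.e. checkpoints that store one fused gate_up (w13) and one down (w2) LoRA tensor per MoE layer covering all experts at once, instead of separate per-expert w1/w3/w2 tensors. As a rough orientation sketch (toy dimensions, not taken from any real model), this is the kind of reshape/permute that the new FusedMoE3DWithLoRA.set_lora performs further down:

import torch

# Toy sizes purely for illustration.
num_experts, rank, hidden, intermediate = 4, 8, 64, 128

# Fused "3D" checkpoint layout: all experts stacked into one tensor, with the
# gate (w1) and up (w3) projections already merged into a single w13 matrix.
w13_lora_a = torch.zeros(num_experts * rank, hidden)
w13_lora_b = torch.zeros(intermediate * 2, num_experts * rank)

# Per-expert views, matching the stacked buffers the LoRA layer allocates:
w13_lora_a_3d = w13_lora_a.reshape(num_experts, rank, hidden)
w13_lora_b_3d = w13_lora_b.reshape(intermediate * 2, num_experts, rank).permute(1, 0, 2)
print(w13_lora_a_3d.shape, w13_lora_b_3d.shape)
# torch.Size([4, 8, 64]) torch.Size([4, 256, 8])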
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

import vllm
from vllm.lora.request import LoRARequest

@@ -84,14 +86,17 @@ def test_gpt_oss_lora(gptoss20b_lora_files):


@multi_gpu_test(num_gpus=2)
def test_gpt_oss_lora_tp2(gptoss20b_lora_files):
@pytest.mark.parametrize("fully_sharded_loras", [False, True])
def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        max_num_seqs=16,
        tensor_parallel_size=2,
        fully_sharded_loras=fully_sharded_loras,
        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
            cudagraph_specialize_lora=False,
        ),

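The test now exercises both the default and the fully sharded LoRA path under TP=2. For context, a hedged end-to-end sketch of how such an engine is queried with an adapter; the model id and adapter path are placeholders, not values from this repository:

import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM(
    "openai/gpt-oss-20b",  # placeholder for MODEL_PATH
    max_model_len=1024,
    enable_lora=True,
    max_loras=2,
    max_lora_rank=8,
    tensor_parallel_size=2,
    fully_sharded_loras=True,
)
outputs = llm.generate(
    ["Give me a short introduction to MoE LoRA."],
    vllm.SamplingParams(temperature=0.0, max_tokens=32),
    lora_request=LoRARequest("gptoss20b_lora", 1, "/path/to/gptoss20b-lora"),
)
print(outputs[0].outputs[0].text)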
@@ -11,7 +11,7 @@ from vllm.lora.layers.column_parallel_linear import (
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.fused_moe import FusedMoEWithLoRA
from vllm.lora.layers.fused_moe import FusedMoE3DWithLoRA, FusedMoEWithLoRA
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.row_parallel_linear import (
@@ -38,4 +38,5 @@ __all__ = [
    "ReplicatedLinearWithLoRA",
    "LoRAMapping",
    "FusedMoEWithLoRA",
    "FusedMoE3DWithLoRA",
]

@@ -42,8 +42,8 @@ class BaseLayerWithLoRA(nn.Module):
    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

@@ -94,13 +94,15 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        # Except for QKVParallelLinearWithLoRA and
        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
        # store weights in a tuple of size 1. These two layers will
        # override this function.
        assert isinstance(lora_a, torch.Tensor)
        assert isinstance(lora_b, torch.Tensor)
        assert (
            len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1
        )

@@ -246,8 +246,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        self.reset_lora(index)

@@ -42,7 +42,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.device = base_layer.w2_weight.device
        self.w13_slices = 2
        self._w13_slices = 2
        self._inject_lora_into_fused_moe()

    def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
@@ -160,7 +160,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
            op_prefix="w13",
            num_loras=self.max_loras,
            rank=max_lora_rank,
            num_slices=self.w13_slices,
            num_slices=self._w13_slices,
            M=M,
            layer=layer,
            top_k=top_k,
@@ -230,7 +230,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
        CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
        num_tokens = hidden_states.size(0)
        M = min(num_tokens, CHUNK_SIZE)
        max_lora_rank = self.w2_lora_a_stacked.shape[-2]
        max_lora_rank = self.w2_lora_a_stacked[0].shape[-2]
        shrink_config, expand_config = self._get_lora_moe_configs(
            op_prefix="w2",
            num_loras=self.max_loras,
@@ -258,8 +258,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
        self.punica_wrapper.add_lora_fused_moe(
            intermediate_cache3,
            intermediate_cache2,
            (self.w2_lora_a_stacked,),
            (self.w2_lora_b_stacked,),
            self.w2_lora_a_stacked,
            self.w2_lora_b_stacked,
            topk_weights,
            sorted_token_ids_lora,
            expert_ids_lora,
@@ -292,22 +292,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
            self.base_layer.quant_method, m_fused_moe_fn
        )

    def create_lora_weights(
    def _create_lora_a_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices."""
        assert self.w13_slices == 2
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self.w13_lora_a_stacked = tuple(
    ):
        self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
@@ -320,10 +310,23 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.w13_slices)
            for _ in range(self._w13_slices)
        )
        self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    lora_config.max_lora_rank,
                    self.base_layer.intermediate_size_per_partition,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

        self.w13_lora_b_stacked = tuple(
    def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
        self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
@@ -334,34 +337,42 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.w13_slices)
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.hidden_size
                    if not self.fully_sharded
                    else divide(self.base_layer.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

        self.w2_lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.local_num_experts,
                lora_config.max_lora_rank,
                self.base_layer.intermediate_size_per_partition,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.w2_lora_b_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.local_num_experts,
                self.base_layer.hidden_size
                if not self.fully_sharded
                else divide(self.base_layer.hidden_size, self.tp_size),
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices."""
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)
        # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
        # to create dummy LoRA weights.
        # TODO Optimize this section
        self.lora_a_stacked = []
        self.lora_b_stacked = []
        for lora_id in range(max_loras):
@@ -370,36 +381,43 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                self.lora_a_stacked.append(
                    self.w13_lora_a_stacked[0][lora_id][experts_id]
                )
                self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
                self.lora_a_stacked.append(
                    self.w13_lora_a_stacked[1][lora_id][experts_id]
                    self.w2_lora_a_stacked[0][lora_id][experts_id]
                )

                self.lora_b_stacked.append(
                    self.w13_lora_b_stacked[0][lora_id][experts_id]
                )
                self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id])
                self.lora_b_stacked.append(
                    self.w2_lora_b_stacked[0][lora_id][experts_id]
                )

                self.lora_a_stacked.append(
                    self.w13_lora_a_stacked[1][lora_id][experts_id]
                )
                self.lora_b_stacked.append(
                    self.w13_lora_b_stacked[1][lora_id][experts_id]
                )

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        for pos in range(self.w13_slices):
        for pos in range(self._w13_slices):
            self.w13_lora_a_stacked[pos][index] = 0
            self.w13_lora_b_stacked[pos][index] = 0

        self.w2_lora_a_stacked[index] = 0
        self.w2_lora_b_stacked[index] = 0
        self.w2_lora_a_stacked[0][index] = 0
        self.w2_lora_b_stacked[0][index] = 0
        self.adapter_enabled[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)
        self.reset_lora(index)
        self.adapter_enabled[index] = 1
        for eid in range(len(lora_a) // 3):
@@ -432,7 +450,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
            w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
            w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]

            w2_shard_size = self.w2_lora_b_stacked[index, eid].shape[0]
            w2_shard_size = self.w2_lora_b_stacked[0][index, eid].shape[0]
            w2_start_idx = self.tp_rank * w2_shard_size
            w2_end_idx = (self.tp_rank + 1) * w2_shard_size
            w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
@@ -454,26 +472,14 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
            ].copy_(w3_lora_b, non_blocking=True)

            self.w2_lora_a_stacked[
            self.w2_lora_a_stacked[0][
                index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
            ].copy_(w2_lora_a, non_blocking=True)

            self.w2_lora_b_stacked[
            self.w2_lora_b_stacked[0][
                index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
            ].copy_(w2_lora_b, non_blocking=True)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        # return type(source_layer) is FusedMoE
        return isinstance(source_layer, FusedMoE)

    def forward(self, *args, **kwargs):
        return self.base_layer.forward(*args, **kwargs)

@@ -491,3 +497,220 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
    @property
    def is_internal_router(self) -> bool:
        return self.base_layer.is_internal_router

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        # return type(source_layer) is FusedMoE

        return type(source_layer) is FusedMoE and len(packed_modules_list) == 2


class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
    def __init__(self, base_layer):
        super().__init__(base_layer)
        self._w13_slices = 1

    def _create_lora_b_weights(self, max_loras, lora_config):
        self.w13_lora_b_stacked: tuple[torch.Tensor] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.intermediate_size_per_partition * 2,
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.hidden_size
                    if not self.fully_sharded
                    else divide(self.base_layer.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices."""
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)

    def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
        if self.tp_size == 1 or not self.fully_sharded:
            return w13_lora_a

        # w13_lora_a shape (num_experts, rank, input_size)
        current_lora_rank = w13_lora_a.shape[1]
        assert current_lora_rank % self.tp_size == 0

        sliced_rank = current_lora_rank // self.tp_size
        start_idx = self.tp_rank * sliced_rank
        end_idx = (self.tp_rank + 1) * sliced_rank
        return w13_lora_a[:, start_idx:end_idx, :]

    def _slice_w13_b(self, w13_lora_b: torch.Tensor, is_interleave: bool = True):
        if self.tp_size == 1:
            return w13_lora_b

        # w13_lora_b shape (num_experts, output_size, rank)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        if is_interleave:
            # For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
            # are stored in interleaved order, so the corresponding LoRA weights
            # need to be processed accordingly.
            w1_lora_b = w13_lora_b[:, ::2, :]
            w3_lora_b = w13_lora_b[:, 1::2, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            return torch.stack([sliced_w1_lora_b, sliced_w3_lora_b], dim=2).flatten(
                1, 2
            )
        else:
            slice_size = w13_lora_b.shape[1] // 2
            w1_lora_b = w13_lora_b[:, :slice_size, :]
            w3_lora_b = w13_lora_b[:, slice_size:, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)

    def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
        if self.tp_size == 1:
            return w2_lora_a
        # w2_lora_a shape (num_experts, rank, input_size)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w2_lora_a[:, :, start_idx:end_idx]

    def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
        if self.tp_size == 1 or not self.fully_sharded:
            return w2_lora_b
        # Based on S-LoRA, we slice W2 B along the hidden_size dim.
        # w2_lora_b shape (num_experts, output_size, rank)
        current_lora_size = w2_lora_b.shape[1]

        sliced_size = current_lora_size // self.tp_size
        start_idx = self.tp_rank * sliced_size
        end_idx = (self.tp_rank + 1) * sliced_size
        return w2_lora_b[:, start_idx:end_idx, :]

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        # Make mypy happy
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)
        assert len(lora_a) == len(lora_b) == 2

        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        num_experts = self.w13_lora_a_stacked[0].shape[1]
        w13_lora_a, w2_lora_a = lora_a
        w13_lora_b, w2_lora_b = lora_b

        # (num_experts, rank, input_size)
        w13_lora_a = w13_lora_a.reshape(num_experts, -1, w13_lora_a.shape[-1])
        w2_lora_a = w2_lora_a.reshape(num_experts, -1, w2_lora_a.shape[-1])
        # (output_size, num_experts, rank)
        w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], num_experts, -1)
        w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], num_experts, -1)
        # (num_experts, output_size, rank)
        w13_lora_b = w13_lora_b.permute(1, 0, 2)
        w2_lora_b = w2_lora_b.permute(1, 0, 2)

        sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
        sliced_w13_lora_b = self._slice_w13_b(w13_lora_b, is_interleave=True)

        sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
        sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

        self.w13_lora_a_stacked[0][
            index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
        ].copy_(sliced_w13_lora_a, non_blocking=True)
        self.w2_lora_a_stacked[0][
            index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
        ].copy_(sliced_w2_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[0][
            index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
        ].copy_(sliced_w13_lora_b, non_blocking=True)
        self.w2_lora_b_stacked[0][
            index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
        ].copy_(sliced_w2_lora_b, non_blocking=True)

    @property
    def w13_input_size(self):
        """
        Full size
        """
        return self.w13_lora_a_stacked[0].shape[-1]

    @property
    def w13_output_size(self):
        """
        Full size
        """
        return self.w13_lora_b_stacked[0].shape[-2] * self.tp_size

    @property
    def w2_input_size(self):
        """
        Full size
        """
        return self.w2_lora_a_stacked[0].shape[-1] * self.tp_size

    @property
    def w2_output_size(self):
        """
        Full size
        """
        return self.w2_lora_a_stacked[0].shape[-2]

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""

        return type(source_layer) is FusedMoE and len(packed_modules_list) == 1

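The subtlest part of the new FusedMoE3DWithLoRA class is _slice_w13_b: for GPT-OSS-style checkpoints the fused gate (w1) and up (w3) rows are interleaved, so under tensor parallelism each rank slices its shard of w1 and w3 separately and then re-interleaves them with a stack-and-flatten. A standalone toy reproduction of just that trick (invented sizes, independent of the classes above):

import torch

num_experts, rank, tp_rank, intermediate_per_rank = 1, 2, 0, 2
# Rows of the fused w13 LoRA B alternate between w1 and w3: w1_0, w3_0, w1_1, ...
w13_lora_b = torch.arange(num_experts * 8 * rank, dtype=torch.float32).reshape(
    num_experts, 8, rank
)

w1_lora_b = w13_lora_b[:, ::2, :]   # gate rows
w3_lora_b = w13_lora_b[:, 1::2, :]  # up rows
start = tp_rank * intermediate_per_rank
end = (tp_rank + 1) * intermediate_per_rank
sliced = torch.stack(
    [w1_lora_b[:, start:end, :], w3_lora_b[:, start:end, :]], dim=2
).flatten(1, 2)
# The stack over dim=2 followed by flatten(1, 2) restores the interleaved
# (w1, w3, w1, w3, ...) row order for this rank's shard of the intermediate dim.
print(sliced.shape)  # torch.Size([1, 4, 2])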
@@ -128,9 +128,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        assert isinstance(lora_a, torch.Tensor)
        assert isinstance(lora_b, torch.Tensor)
        self.reset_lora(index)
        self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
            lora_a, non_blocking=True

@@ -77,12 +77,15 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        assert isinstance(lora_a, torch.Tensor)
        assert isinstance(lora_b, torch.Tensor)
        self.reset_lora(index)
        # NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
        # so we need to transpose here

        self.lora_a_stacked[index, : lora_a.shape[1], : lora_a.shape[0]].copy_(
            lora_a.T, non_blocking=True
        )

@@ -22,11 +22,13 @@ from vllm.lora.utils import (
    from_layer_logits_processor,
    get_supported_lora_modules,
    is_base_embeddding_weights,
    is_moe_model,
    is_regex_target_modules,
    parse_fine_tuned_lora_name,
    process_packed_modules_mapping,
    replace_submodule,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
@@ -356,7 +358,11 @@ class LoRAModelManager:
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._last_mapping: LoRAMapping | None = None
        self._is_3d_moe_model = is_moe_model(self.model) and hasattr(
            self.model, "is_3d_moe_weight"
        )
        self._create_lora_modules()

        self.model.lora_manager = self

    def __len__(self) -> int:
@@ -400,22 +406,36 @@ class LoRAModelManager:
            self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if module_lora:
                # Note (gnovack) - If MOE lora weights are not split into
                # num_experts chunks, we split them here
                if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
                    module_lora.lora_a
                ):
                    # Handle FSDP file format where experts.base_layer is the
                    # gate_up_proj and experts is the down_proj
                    gate_up_proj_lora = self._get_lora_layer_weights(
                        lora_model, module_name + ".base_layer"
                    )

                    assert gate_up_proj_lora is not None
                    assert module_lora is not None

                    down_proj_lora = module_lora
            if not module_lora:
                module.reset_lora(index)
                continue
            # Note (gnovack) - If MOE lora weights are not split into
            # num_experts chunks, we split them here
            if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
                module_lora.lora_a
            ):
                # Handle PEFT file format where experts.base_layer is the
                # gate_up_proj and experts is the down_proj
                gate_up_proj_lora = self._get_lora_layer_weights(
                    lora_model, module_name + ".base_layer"
                )
                down_proj_lora = module_lora
                # FIXME Edge case where LoRA is not added to gate_up_proj
                # or down_proj
                assert gate_up_proj_lora is not None
                assert down_proj_lora is not None
                if self._is_3d_moe_model:
                    module_lora.lora_a = [
                        gate_up_proj_lora.lora_a,
                        down_proj_lora.lora_a,
                    ]
                    module_lora.lora_b = [
                        gate_up_proj_lora.lora_b,
                        down_proj_lora.lora_b,
                    ]
                else:
                    # Some 3D MoE models haven't added the `is_3d_moe_weight`
                    # attribute yet, so fall back here
                    num_experts = module_lora.lora_a.shape[0] // module_lora.rank

                    gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
@@ -444,14 +464,12 @@ class LoRAModelManager:

                    module_lora.lora_a = lora_a
                    module_lora.lora_b = lora_b
                module.set_lora(
                    index,
                    module_lora.lora_a,
                    module_lora.lora_b,
                )

            module.set_lora(
                index,
                module_lora.lora_a,
                module_lora.lora_b,
            )
            else:
                module.reset_lora(index)
        return True

    def _deactivate_adapter(self, lora_id: int):
@@ -512,6 +530,13 @@ class LoRAModelManager:
                continue
            parts = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
            if isinstance(module, FusedMoE):
                # packed_moduled_lst is only used here to determine whether to
                # instantiate FusedMoE3DWithLoRA or FusedMoEWithLoRA; the
                # difference between these two LoRA layers is whether the
                # LoRA weights of w1 and w3 have already been fused on disk.

                packed_moduled_lst = ["w13"] if self._is_3d_moe_model else ["w1", "w3"]
            new_module = replace_submodule(
                self.model,
                module_name,
@@ -560,6 +585,7 @@ class LoRAModelManager:
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)
            pass

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA), (
@@ -605,6 +631,30 @@ class LoRAModelManager:
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
                elif module.__class__.__name__ == "FusedMoE3DWithLoRA":
                    # Case for 3D moe model
                    # w2
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.w2_input_size,
                        module.w2_output_size,
                        rank * module.w2_lora_a_stacked[0].shape[1],  # rank*num_experts
                        module.w2_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
                    # w13
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.w13_input_size,
                        module.w13_output_size,
                        rank
                        * module.w13_lora_a_stacked[0].shape[1],  # rank*num_experts
                        module.w13_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name + ".base_layer"] = lora
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
@@ -614,6 +664,7 @@ class LoRAModelManager:
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
            else:
                parts = module_name.split(".")
                replacements = self.packed_modules_mapping[parts[-1]]
@@ -629,7 +680,7 @@ class LoRAModelManager:
                    )
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
                model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):

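Taken together, the manager-side logic reduces to two decisions: which LoRA wrapper to instantiate for a FusedMoE layer, and how to hand the checkpoint tensors to set_lora. A hypothetical condensation of that dispatch (not actual vLLM code; the names only mirror the diff above):

def pick_moe_lora_cls(is_3d_moe_model: bool) -> str:
    # FusedMoE3DWithLoRA.can_replace_layer accepts len(packed_modules_list) == 1,
    # FusedMoEWithLoRA.can_replace_layer accepts len(packed_modules_list) == 2.
    packed_modules_list = ["w13"] if is_3d_moe_model else ["w1", "w3"]
    return "FusedMoE3DWithLoRA" if len(packed_modules_list) == 1 else "FusedMoEWithLoRA"


def pack_3d_moe_lora(gate_up_proj_lora, down_proj_lora):
    # For 3D MoE checkpoints the two fused tensors are forwarded as two-element
    # lists; the layer's set_lora performs the per-expert reshape and TP slicing.
    lora_a = [gate_up_proj_lora.lora_a, down_proj_lora.lora_a]
    lora_b = [gate_up_proj_lora.lora_b, down_proj_lora.lora_b]
    return lora_a, lora_b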
@@ -23,6 +23,7 @@ from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    FusedMoE3DWithLoRA,
    FusedMoEWithLoRA,
    LogitsProcessorWithLoRA,
    MergedColumnParallelLinearWithLoRA,
@@ -62,6 +63,7 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
    MergedQKVParallelLinearWithShardedLoRA,
    RowParallelLinearWithShardedLoRA,
    FusedMoEWithLoRA,
    FusedMoE3DWithLoRA,
}


@@ -288,10 +290,12 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
        # the expert indices are expanded based on the configured number
        # of routed experts.
        packed_modules_mapping = get_packed_modules_mapping(model)

        packed_modules_mapping["experts"] = [
            weight_name.rstrip(".") for _, weight_name, _, _ in moe_packed_mapping
        ]
        if not hasattr(model, "is_3d_moe_weight"):
            # 3D MoE LoRA does not need `packed_modules_mapping`
            packed_modules_mapping["experts"] = [
                weight_name.rstrip(".")
                for _, weight_name, _, _ in moe_packed_mapping
            ]

        return packed_modules_mapping
    else:

@@ -656,6 +656,7 @@ class GptOssModel(nn.Module):


class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
    is_3d_moe_weight: bool = True
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    hf_to_vllm_mapper = WeightsMapper(
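GPT-OSS is the first model to opt in via the `is_3d_moe_weight` class attribute. Other MoE models whose adapters store fused per-layer expert weights would follow the same pattern; a minimal, purely illustrative sketch (not an actual vLLM model):

import torch.nn as nn

class My3DMoEForCausalLM(nn.Module):  # hypothetical model, for illustration only
    # Tells LoRAModelManager that expert LoRA weights arrive as fused 3D
    # tensors (one gate_up and one down tensor per layer), so it wires up
    # FusedMoE3DWithLoRA and skips the per-expert packed_modules_mapping entry.
    is_3d_moe_weight: bool = True
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}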