[Core] Ensure LoRA linear respect the base_layer's tp_size and tp_rank (#25487)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-09-24 02:19:25 +08:00 committed by GitHub
parent 867ecdd1c8
commit 5abb117901
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 24 additions and 41 deletions

View File

@ -24,11 +24,12 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
super().__init__()
self.base_layer = base_layer
self.input_size = self.base_layer.input_size
# Ensure tp_size and tp_rank consistency with the base_layer.
self.tp_size = self.base_layer.tp_size
self.tp_rank = self.base_layer.tp_rank
self.device = _get_lora_device(self.base_layer)
self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
self.output_slices: tuple[int, ...]
self.tp_size: int
self.output_size: int
self.n_slices: int

View File

@ -8,9 +8,7 @@ import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.utils import divide
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
@ -85,7 +83,6 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
# inconsistent when TP is greater than 1.
self.is_merged_col_linear = type(
base_layer) is MergedColumnParallelLinear
self.tp_size = get_tensor_model_parallel_world_size()
self.output_size = self.base_layer.output_size_per_partition
# There is only one LoRA layer
self.n_slices = 1
@ -97,22 +94,20 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
# Applicable to cases where the base_layer is
# MergedColumnParallelLinear.
if self.is_merged_col_linear:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.output_size // 2
offset = lora_b.shape[0] // 2
left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) *
left_weight = lora_b[self.tp_rank * shard_size:(self.tp_rank + 1) *
shard_size, :]
right_weight = lora_b[offset + tp_rank * shard_size:offset +
(tp_rank + 1) * shard_size, :]
right_weight = lora_b[offset + self.tp_rank * shard_size:offset +
(self.tp_rank + 1) * shard_size, :]
lora_b = torch.cat([left_weight, right_weight], dim=0)
# Applicable to cases where the base_layer is
# ColumnParallelLinear.
else:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_size
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[start_idx:end_idx, :]
return lora_b
@ -120,10 +115,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
# TODO: Fix the slicing logic of bias.
if bias is None:
return bias
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_size
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
bias = bias[start_idx:end_idx]
return bias
@ -144,7 +138,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
# Matrix multiply.
output_parallel = self.apply(input_, bias)
if self.base_layer.gather_output:
if self.base_layer.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
@ -185,8 +179,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
QKVParallelLinear]) -> None:
super().__init__(base_layer)
# There are two LoRA layers
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
# the output_sizes in MergedColumnParallelLinear is not sharded by tp
# we need to divide it by the tp_size to get correct slices size
output_sizes = self.base_layer.output_sizes
@ -341,9 +333,9 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.n_slices = 1
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
self.q_shard_id = self.tp_rank
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1), :]
@ -397,8 +389,6 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
super().__init__(base_layer)
# There are three LoRA layer.
self.n_slices = len(self.base_layer.output_sizes)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size)
@ -461,9 +451,8 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
# Therefore, the sharding of `lora_a` only needs to correspond with the
# gather operation.
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
start_idx = self.tp_rank * shard_size
lora_a = lora_a[start_idx:start_idx + shard_size, :]
return lora_a
@ -547,9 +536,8 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
"""
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
start_idx = self.tp_rank * shard_size
lora_a = lora_a[start_idx:start_idx + shard_size, :]
return lora_a

View File

@ -18,7 +18,6 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
def __init__(self, base_layer: ReplicatedLinear) -> None:
super().__init__(base_layer, )
# To ensure interface compatibility, set to 1 always.
self.tp_size = 1
self.output_size = self.base_layer.output_size
self.n_slices = 1

View File

@ -8,9 +8,7 @@ import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
from vllm.distributed import (split_tensor_along_last_dim,
tensor_model_parallel_all_reduce)
# yapf: disable
from vllm.model_executor.layers.linear import RowParallelLinear
@ -25,12 +23,9 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None:
super().__init__(base_layer)
self.tp_size = get_tensor_model_parallel_world_size()
# reset input_size
self.input_size = self.base_layer.input_size_per_partition
self.output_size = self.base_layer.output_size
self.tp_rank = get_tensor_model_parallel_rank()
# There is only one LoRA layer.
self.n_slices = 1
@ -68,12 +63,12 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
else:
# TODO: simplify code below
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size)
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous()
# Matrix multiply.
output_parallel = self.apply(input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
if self.base_layer.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
@ -154,8 +149,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
buffer, x, self.lora_a_stacked, 1.0)
if not current_platform.can_update_inplace():
buffer = shrunk_buffer
buffer = tensor_model_parallel_all_reduce(buffer)
if self.tp_size>1:
buffer = tensor_model_parallel_all_reduce(buffer)
# following S-LoRA, allows the fusing of all_gather and all_reduce
# by adding the column partitioned lora output to a slice of output

View File

@ -48,11 +48,11 @@ class LoRALayerWeights:
@property
def input_dim(self) -> int:
return self.lora_a.shape[0]
return self.lora_a.shape[1]
@property
def output_dim(self) -> int:
return self.lora_b.shape[1]
return self.lora_b.shape[0]
@property
def is_packed(self) -> bool: