[Core] Ensure LoRA linear layers respect the base_layer's tp_size and tp_rank (#25487)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Authored by Jee Jee Li on 2025-09-24 02:19:25 +08:00, committed by GitHub
parent 867ecdd1c8
commit 5abb117901
5 changed files with 24 additions and 41 deletions
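
In outline, the change makes every LoRA linear wrapper take tp_size and tp_rank from the layer it wraps instead of calling get_tensor_model_parallel_world_size() / get_tensor_model_parallel_rank() on its own. A minimal sketch of that pattern, using stand-in classes rather than vLLM's actual ones:

# Minimal sketch (stand-in classes, not vLLM's actual ones) of the pattern
# this commit moves to: the LoRA wrapper copies tp_size/tp_rank from the
# wrapped base layer instead of querying the distributed state itself, so
# the two layers always agree on how the weights are sharded.
class FakeBaseLinear:
    def __init__(self, input_size: int, tp_size: int, tp_rank: int):
        self.input_size = input_size
        self.tp_size = tp_size
        self.tp_rank = tp_rank


class FakeLinearWithLoRA:
    def __init__(self, base_layer: FakeBaseLinear):
        self.base_layer = base_layer
        self.input_size = base_layer.input_size
        # Ensure tp_size and tp_rank consistency with the base_layer.
        self.tp_size = base_layer.tp_size
        self.tp_rank = base_layer.tp_rank


layer = FakeLinearWithLoRA(FakeBaseLinear(input_size=4096, tp_size=2, tp_rank=1))
assert (layer.tp_size, layer.tp_rank) == (2, 1)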

View File

@@ -24,11 +24,12 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
         super().__init__()
         self.base_layer = base_layer
         self.input_size = self.base_layer.input_size
+        # Ensure tp_size and tp_rank consistency with the base_layer.
+        self.tp_size = self.base_layer.tp_size
+        self.tp_rank = self.base_layer.tp_rank
         self.device = _get_lora_device(self.base_layer)
         self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None

         self.output_slices: tuple[int, ...]
-        self.tp_size: int
         self.output_size: int
         self.n_slices: int

View File

@@ -8,9 +8,7 @@ import torch.nn as nn
 from transformers import PretrainedConfig

 from vllm.config.lora import LoRAConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_gather)
+from vllm.distributed import tensor_model_parallel_all_gather
 from vllm.distributed.utils import divide
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -85,7 +83,6 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
         # inconsistent when TP is greater than 1.
         self.is_merged_col_linear = type(
             base_layer) is MergedColumnParallelLinear
-        self.tp_size = get_tensor_model_parallel_world_size()
         self.output_size = self.base_layer.output_size_per_partition
         # There is only one LoRA layer
         self.n_slices = 1
@@ -97,22 +94,20 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
         # Applicable to cases where the base_layer is
         # MergedColumnParallelLinear.
         if self.is_merged_col_linear:
-            tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.output_size // 2
             offset = lora_b.shape[0] // 2
-            left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) *
-                                 shard_size, :]
-            right_weight = lora_b[offset + tp_rank * shard_size:offset +
-                                  (tp_rank + 1) * shard_size, :]
+            left_weight = lora_b[self.tp_rank * shard_size:(self.tp_rank + 1) *
+                                 shard_size, :]
+            right_weight = lora_b[offset + self.tp_rank * shard_size:offset +
+                                  (self.tp_rank + 1) * shard_size, :]
             lora_b = torch.cat([left_weight, right_weight], dim=0)
         # Applicable to cases where the base_layer is
         # ColumnParallelLinear.
         else:
-            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
             shard_size = self.output_size
-            start_idx = tensor_model_parallel_rank * shard_size
-            end_idx = (tensor_model_parallel_rank + 1) * shard_size
+            start_idx = self.tp_rank * shard_size
+            end_idx = (self.tp_rank + 1) * shard_size
             lora_b = lora_b[start_idx:end_idx, :]
         return lora_b
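
For the merged-column branch above, a toy example with made-up sizes may help: with tp_size=2 and a fused gate+up output of 8 rows, rank 1 keeps rows 2-3 of the gate half and rows 6-7 of the up half of lora_b.

# Toy numbers only: tp_size=2, a fused gate+up output of 8 rows (4 + 4),
# so each rank's partition is 4 rows and shard_size is 2 rows per half.
import torch

tp_size, tp_rank = 2, 1
output_size = 8 // tp_size                 # per-partition fused output -> 4
lora_b = torch.arange(8.0).unsqueeze(-1)   # (8, lora_rank=1), rows 0..7

shard_size = output_size // 2              # 2
offset = lora_b.shape[0] // 2              # 4, where the "up" half starts
left = lora_b[tp_rank * shard_size:(tp_rank + 1) * shard_size, :]
right = lora_b[offset + tp_rank * shard_size:offset +
               (tp_rank + 1) * shard_size, :]
print(torch.cat([left, right], dim=0).flatten())   # tensor([2., 3., 6., 7.])
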
@@ -120,10 +115,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
         # TODO: Fix the slicing logic of bias.
         if bias is None:
             return bias
-        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         shard_size = self.output_size
-        start_idx = tensor_model_parallel_rank * shard_size
-        end_idx = (tensor_model_parallel_rank + 1) * shard_size
+        start_idx = self.tp_rank * shard_size
+        end_idx = (self.tp_rank + 1) * shard_size
         bias = bias[start_idx:end_idx]
         return bias
@@ -144,7 +138,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
         # Matrix multiply.
         output_parallel = self.apply(input_, bias)
-        if self.base_layer.gather_output:
+        if self.base_layer.gather_output and self.tp_size > 1:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)
         else:
@ -185,8 +179,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
QKVParallelLinear]) -> None: QKVParallelLinear]) -> None:
super().__init__(base_layer) super().__init__(base_layer)
# There are two LoRA layers # There are two LoRA layers
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
# the output_sizes in MergedColumnParallelLinear is not sharded by tp # the output_sizes in MergedColumnParallelLinear is not sharded by tp
# we need to divide it by the tp_size to get correct slices size # we need to divide it by the tp_size to get correct slices size
output_sizes = self.base_layer.output_sizes output_sizes = self.base_layer.output_sizes
@@ -341,9 +333,9 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.n_slices = 1

     def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
-        tp_rank = get_tensor_model_parallel_rank()
-        self.q_shard_id = tp_rank
-        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
+        self.q_shard_id = self.tp_rank
+        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
         lora_b_q = lora_b[self.q_proj_shard_size *
                           self.q_shard_id:self.q_proj_shard_size *
                           (self.q_shard_id + 1), :]
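
The kv_shard_id expression reflects GQA-style KV replication: when there are fewer KV heads than tensor-parallel ranks, several ranks map to the same KV shard. A sketch with illustrative numbers that are not taken from the diff:

# Illustrative GQA numbers, not from the diff: 8 ranks but only 2 KV heads,
# so each KV head is replicated across 4 ranks and kv_shard_id collapses
# groups of ranks onto the same KV slice of lora_b.
tp_size = 8
num_kv_heads = 2
num_kv_head_replicas = tp_size // num_kv_heads     # 4

for tp_rank in range(tp_size):
    q_shard_id = tp_rank
    kv_shard_id = tp_rank // num_kv_head_replicas
    print(f"rank {tp_rank}: q shard {q_shard_id}, kv shard {kv_shard_id}")
# ranks 0-3 -> kv shard 0, ranks 4-7 -> kv shard 1
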
@@ -397,8 +389,6 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
         super().__init__(base_layer)
         # There are three LoRA layer.
         self.n_slices = len(self.base_layer.output_sizes)
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()

         self.q_proj_shard_size = (self.base_layer.num_heads *
                                   self.base_layer.head_size)
@ -461,9 +451,8 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
# Therefore, the sharding of `lora_a` only needs to correspond with the # Therefore, the sharding of `lora_a` only needs to correspond with the
# gather operation. # gather operation.
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2] shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size start_idx = self.tp_rank * shard_size
lora_a = lora_a[start_idx:start_idx + shard_size, :] lora_a = lora_a[start_idx:start_idx + shard_size, :]
return lora_a return lora_a
@@ -547,9 +536,8 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
     """

     def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
-        tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.lora_a_stacked[0].shape[2]
-        start_idx = tp_rank * shard_size
+        start_idx = self.tp_rank * shard_size
         lora_a = lora_a[start_idx:start_idx + shard_size, :]
         return lora_a

View File

@@ -18,7 +18,6 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):

     def __init__(self, base_layer: ReplicatedLinear) -> None:
         super().__init__(base_layer, )
         # To ensure interface compatibility, set to 1 always.
-        self.tp_size = 1
         self.output_size = self.base_layer.output_size
         self.n_slices = 1

View File

@@ -8,9 +8,7 @@ import torch.nn as nn
 from transformers import PretrainedConfig

 from vllm.config.lora import LoRAConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              split_tensor_along_last_dim,
+from vllm.distributed import (split_tensor_along_last_dim,
                               tensor_model_parallel_all_reduce)
 # yapf: disable
 from vllm.model_executor.layers.linear import RowParallelLinear
@@ -25,12 +23,9 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):

     def __init__(self, base_layer: RowParallelLinear) -> None:
         super().__init__(base_layer)
-        self.tp_size = get_tensor_model_parallel_world_size()
         # reset input_size
         self.input_size = self.base_layer.input_size_per_partition
         self.output_size = self.base_layer.output_size
-        self.tp_rank = get_tensor_model_parallel_rank()
         # There is only one LoRA layer.
         self.n_slices = 1
@@ -68,12 +63,12 @@
         else:
             # TODO: simplify code below
             splitted_input = split_tensor_along_last_dim(
-                input_, num_partitions=self.base_layer.tp_size)
+                input_, num_partitions=self.tp_size)
             input_parallel = splitted_input[self.tp_rank].contiguous()

         # Matrix multiply.
         output_parallel = self.apply(input_parallel)
-        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
+        if self.base_layer.reduce_results and self.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output_ = output_parallel
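
For context on the forward path touched above: the input's last dimension is split across ranks, each rank multiplies by its own weight shard, and the partial outputs are summed with an all-reduce when reduce_results is set and tp_size > 1. A single-process sketch with plain torch ops standing in for vLLM's distributed helpers:

# Single-process sketch of the row-parallel path, with plain torch ops
# standing in for vLLM's distributed helpers; shapes are made up.
import torch

tp_size, tp_rank = 2, 0
input_size, output_size = 8, 4
x = torch.randn(3, input_size)                              # full input
w_shard = torch.randn(output_size, input_size // tp_size)   # this rank's shard

# split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
input_parallel = torch.chunk(x, tp_size, dim=-1)[tp_rank].contiguous()
output_parallel = input_parallel @ w_shard.t()              # self.apply(...)

# With reduce_results and tp_size > 1, the partial results of all ranks
# would be summed via tensor_model_parallel_all_reduce; a single process
# has nothing to reduce, so the partial output is the result.
output_ = output_parallel
print(output_.shape)                                        # torch.Size([3, 4])
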
@@ -154,8 +149,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
             buffer, x, self.lora_a_stacked, 1.0)
         if not current_platform.can_update_inplace():
             buffer = shrunk_buffer
-        buffer = tensor_model_parallel_all_reduce(buffer)
+        if self.tp_size > 1:
+            buffer = tensor_model_parallel_all_reduce(buffer)

         # following S-LoRA, allows the fusing of all_gather and all_reduce
         # by adding the column partitioned lora output to a slice of output

View File

@@ -48,11 +48,11 @@ class LoRALayerWeights:

     @property
     def input_dim(self) -> int:
-        return self.lora_a.shape[0]
+        return self.lora_a.shape[1]

     @property
     def output_dim(self) -> int:
-        return self.lora_b.shape[1]
+        return self.lora_b.shape[0]

     @property
     def is_packed(self) -> bool:
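
The flipped shape indices imply the weight layout these properties now assume: lora_a stored as (rank, input_dim) and lora_b as (output_dim, rank), which matches slice_lora_b slicing lora_b along dim 0 earlier in this diff. A small sketch under that assumption, with illustrative shapes:

# Sketch under the layout the new properties imply: lora_a is
# (rank, input_dim), lora_b is (output_dim, rank); shapes are illustrative.
import torch

lora_rank, input_dim, output_dim = 8, 16, 32
lora_a = torch.randn(lora_rank, input_dim)
lora_b = torch.randn(output_dim, lora_rank)

assert lora_a.shape[1] == input_dim    # input_dim property
assert lora_b.shape[0] == output_dim   # output_dim property

x = torch.randn(4, input_dim)
delta = x @ lora_a.t() @ lora_b.t()    # LoRA update, shape (4, output_dim)
print(delta.shape)
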