diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py index 6cf5815ef12da..ed294b0aedaf4 100644 --- a/vllm/lora/layers/base_linear.py +++ b/vllm/lora/layers/base_linear.py @@ -24,11 +24,12 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): super().__init__() self.base_layer = base_layer self.input_size = self.base_layer.input_size + # Ensure tp_size and tp_rank consistency with the base_layer. + self.tp_size = self.base_layer.tp_size + self.tp_rank = self.base_layer.tp_rank self.device = _get_lora_device(self.base_layer) self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None - self.output_slices: tuple[int, ...] - self.tp_size: int self.output_size: int self.n_slices: int diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index fa4eb272a69fe..6284576446c8f 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -8,9 +8,7 @@ import torch.nn as nn from transformers import PretrainedConfig from vllm.config.lora import LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.distributed import tensor_model_parallel_all_gather from vllm.distributed.utils import divide from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -85,7 +83,6 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): # inconsistent when TP is greater than 1. self.is_merged_col_linear = type( base_layer) is MergedColumnParallelLinear - self.tp_size = get_tensor_model_parallel_world_size() self.output_size = self.base_layer.output_size_per_partition # There is only one LoRA layer self.n_slices = 1 @@ -97,22 +94,20 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): # Applicable to cases where the base_layer is # MergedColumnParallelLinear. if self.is_merged_col_linear: - tp_rank = get_tensor_model_parallel_rank() shard_size = self.output_size // 2 offset = lora_b.shape[0] // 2 - left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) * + left_weight = lora_b[self.tp_rank * shard_size:(self.tp_rank + 1) * shard_size, :] - right_weight = lora_b[offset + tp_rank * shard_size:offset + - (tp_rank + 1) * shard_size, :] + right_weight = lora_b[offset + self.tp_rank * shard_size:offset + + (self.tp_rank + 1) * shard_size, :] lora_b = torch.cat([left_weight, right_weight], dim=0) # Applicable to cases where the base_layer is # ColumnParallelLinear. else: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() shard_size = self.output_size - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size lora_b = lora_b[start_idx:end_idx, :] return lora_b @@ -120,10 +115,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): # TODO: Fix the slicing logic of bias. if bias is None: return bias - tensor_model_parallel_rank = get_tensor_model_parallel_rank() shard_size = self.output_size - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size bias = bias[start_idx:end_idx] return bias @@ -144,7 +138,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): # Matrix multiply. output_parallel = self.apply(input_, bias) - if self.base_layer.gather_output: + if self.base_layer.gather_output and self.tp_size > 1: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) else: @@ -185,8 +179,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): QKVParallelLinear]) -> None: super().__init__(base_layer) # There are two LoRA layers - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() # the output_sizes in MergedColumnParallelLinear is not sharded by tp # we need to divide it by the tp_size to get correct slices size output_sizes = self.base_layer.output_sizes @@ -341,9 +333,9 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): self.n_slices = 1 def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() - self.q_shard_id = tp_rank - self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas lora_b_q = lora_b[self.q_proj_shard_size * self.q_shard_id:self.q_proj_shard_size * (self.q_shard_id + 1), :] @@ -397,8 +389,6 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): super().__init__(base_layer) # There are three LoRA layer. self.n_slices = len(self.base_layer.output_sizes) - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() self.q_proj_shard_size = (self.base_layer.num_heads * self.base_layer.head_size) @@ -461,9 +451,8 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): # Therefore, the sharding of `lora_a` only needs to correspond with the # gather operation. def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() shard_size = self.lora_a_stacked[0].shape[2] - start_idx = tp_rank * shard_size + start_idx = self.tp_rank * shard_size lora_a = lora_a[start_idx:start_idx + shard_size, :] return lora_a @@ -547,9 +536,8 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): """ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() shard_size = self.lora_a_stacked[0].shape[2] - start_idx = tp_rank * shard_size + start_idx = self.tp_rank * shard_size lora_a = lora_a[start_idx:start_idx + shard_size, :] return lora_a diff --git a/vllm/lora/layers/replicated_linear.py b/vllm/lora/layers/replicated_linear.py index 3356297c1537a..18a8f13ed9427 100644 --- a/vllm/lora/layers/replicated_linear.py +++ b/vllm/lora/layers/replicated_linear.py @@ -18,7 +18,6 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): def __init__(self, base_layer: ReplicatedLinear) -> None: super().__init__(base_layer, ) # To ensure interface compatibility, set to 1 always. - self.tp_size = 1 self.output_size = self.base_layer.output_size self.n_slices = 1 diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py index cac2c92136dca..d468655e629ae 100644 --- a/vllm/lora/layers/row_parallel_linear.py +++ b/vllm/lora/layers/row_parallel_linear.py @@ -8,9 +8,7 @@ import torch.nn as nn from transformers import PretrainedConfig from vllm.config.lora import LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, +from vllm.distributed import (split_tensor_along_last_dim, tensor_model_parallel_all_reduce) # yapf: disable from vllm.model_executor.layers.linear import RowParallelLinear @@ -25,12 +23,9 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): def __init__(self, base_layer: RowParallelLinear) -> None: super().__init__(base_layer) - self.tp_size = get_tensor_model_parallel_world_size() # reset input_size self.input_size = self.base_layer.input_size_per_partition self.output_size = self.base_layer.output_size - - self.tp_rank = get_tensor_model_parallel_rank() # There is only one LoRA layer. self.n_slices = 1 @@ -68,12 +63,12 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): else: # TODO: simplify code below splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.base_layer.tp_size) + input_, num_partitions=self.tp_size) input_parallel = splitted_input[self.tp_rank].contiguous() # Matrix multiply. output_parallel = self.apply(input_parallel) - if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + if self.base_layer.reduce_results and self.tp_size > 1: output_ = tensor_model_parallel_all_reduce(output_parallel) else: output_ = output_parallel @@ -154,8 +149,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): buffer, x, self.lora_a_stacked, 1.0) if not current_platform.can_update_inplace(): buffer = shrunk_buffer - - buffer = tensor_model_parallel_all_reduce(buffer) + if self.tp_size>1: + buffer = tensor_model_parallel_all_reduce(buffer) # following S-LoRA, allows the fusing of all_gather and all_reduce # by adding the column partitioned lora output to a slice of output diff --git a/vllm/lora/lora_weights.py b/vllm/lora/lora_weights.py index e3198fb3d3ae4..90e18217d28be 100644 --- a/vllm/lora/lora_weights.py +++ b/vllm/lora/lora_weights.py @@ -48,11 +48,11 @@ class LoRALayerWeights: @property def input_dim(self) -> int: - return self.lora_a.shape[0] + return self.lora_a.shape[1] @property def output_dim(self) -> int: - return self.lora_b.shape[1] + return self.lora_b.shape[0] @property def is_packed(self) -> bool: