diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py
index 1720566840bb1..ffdc32b7339af 100644
--- a/vllm/lora/fully_sharded_layers.py
+++ b/vllm/lora/fully_sharded_layers.py
@@ -1,5 +1,5 @@
 # pylint: disable=unused-argument
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -51,10 +51,9 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
         lora_a = lora_a[:, start_idx:start_idx + shard_size]
         return lora_a
 
-    def apply_weights(self, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-        output = self.base_layer.linear_method.apply_weights(
-            self.base_layer, x, bias)
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
 
         x = x.view(-1, x.shape[-1])
         output, out_orig_shape = output.view(-1,
@@ -88,7 +87,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
     )
 
 
-def _mcp_apply_weights(x, bias, layer):
+def _mcp_apply(x, bias, layer):
     """
     MergedColumnParallelLinearWithShardedLoRA and
     QKVParallelLinearWithShardedLora share the same
@@ -100,8 +99,7 @@ def _mcp_apply_weights(x, bias, layer):
     """
     # expecting 2 for column parallel and 3 for qkv
    n = len(layer.lora_a_stacked)
-    output = layer.base_layer.linear_method.apply_weights(
-        layer.base_layer, x, bias)
+    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
 
     x = x.view(-1, x.shape[-1])
     output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
@@ -136,18 +134,23 @@ class MergedColumnParallelLinearWithShardedLoRA(
     Based on S-LoRA, slicing happens along the rank dim.
     """
 
-    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_a(
+        self, lora_a: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
+        if lora_a[0] is None or lora_a[1] is None:
+            return lora_a
         output_shard_size = self.lora_a_stacked[0].shape[2]
         output_start_idx = self.tp_rank * output_shard_size
         lora_a = [
-            lora_a[i][:, output_start_idx:output_start_idx + output_shard_size]
-            for i in range(2)
+            lora_a[0][:,
+                      output_start_idx:output_start_idx + output_shard_size],
+            lora_a[1][:, output_start_idx:output_start_idx + output_shard_size]
         ]
         return lora_a
 
-    def apply_weights(self, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-        return _mcp_apply_weights(x, bias, self)
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        return _mcp_apply(x, bias, self)
 
     @classmethod
     @_fully_sharded_can_replace
@@ -172,19 +175,23 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
     Based on S-LoRA, slicing happens along the rank dim.
     """
 
-    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_a(
+        self, lora_a: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
+        if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None:
+            return lora_a
         shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
         start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
         lora_a = [
-            lora_a[i][:, start_idx[i]:start_idx[i] +
-                      shard_size[i]] if lora_a[i] is not None else None
-            for i in range(3)
+            lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]],
+            lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]],
+            lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]]
         ]
         return lora_a
 
-    def apply_weights(self, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-        return _mcp_apply_weights(x, bias, self)
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        return _mcp_apply(x, bias, self)
 
     @classmethod
     @_fully_sharded_can_replace
@@ -218,9 +225,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
         lora_b = lora_b[:, start_idx:end_idx]
         return lora_b
 
-    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
-        output = self.base_layer.linear_method.apply_weights(
-            self.base_layer, x)
+    def apply(self, x: torch.Tensor) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x)
 
         x = x.view(-1, x.shape[-1])
         output, out_orig_shape = output.view(-1,
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index b3609666b2ec7..90f63c34fb2d3 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -1,7 +1,7 @@
 # pylint: disable=unused-argument
 import math
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -145,11 +145,15 @@ class LoRAMapping:
 
 class BaseLayerWithLoRA(nn.Module):
 
-    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
+    def slice_lora_a(
+        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
+    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
         """Slice lora a if splitting for tensor parallelism."""
         ...
 
-    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
+    def slice_lora_b(
+        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
+    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
         """Slice lora b if splitting with tensor parallelism."""
         ...
 
@@ -539,10 +543,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.lora_b_stacked[0][index] = 0
         self.lora_b_stacked[1][index] = 0
 
-    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_a(
+        self, lora_a: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
         return lora_a
 
-    def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_b(
+        self, lora_b: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
+        if lora_b[0] is None or lora_b[1] is None:
+            return lora_b
         shard_size = self.output_dim
         start_idx = self.tp_rank * shard_size
         end_idx = (self.tp_rank + 1) * shard_size
@@ -767,10 +777,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         self.lora_a_stacked[2][index] = 0
         self.lora_b_stacked[2][index] = 0
 
-    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_a(
+        self, lora_a: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
         return lora_a
 
-    def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_b(
+        self, lora_b: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
+        lora_b_q, lora_b_k, lora_b_v = None, None, None
         if lora_b[0] is not None:
             lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                  self.q_shard_id:self.q_proj_shard_size *
@@ -992,7 +1007,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
 
     @property
     def weight(self):
-
         return self.base_layer.weight if hasattr(
             self.base_layer, "weight") else self.base_layer.qweight
 
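
Note on the sharding logic touched above: the `slice_lora_a` overrides in `fully_sharded_layers.py` split each LoRA A matrix along its rank dimension so that every tensor-parallel rank keeps only its own shard (the S-LoRA approach mentioned in the class docstrings). A minimal standalone sketch of that slicing, with hypothetical sizes (`hidden_size`, `rank`, `tp_size`) chosen purely for illustration and not taken from vLLM:

# Sketch only -- illustrative shapes, not vLLM code.
import torch

hidden_size, rank, tp_size = 16, 8, 4    # hypothetical sizes
lora_a = torch.randn(hidden_size, rank)  # LoRA A weight, sliced along rank

shard_size = rank // tp_size             # rank columns owned by each TP rank
for tp_rank in range(tp_size):
    start_idx = tp_rank * shard_size
    shard = lora_a[:, start_idx:start_idx + shard_size]
    assert shard.shape == (hidden_size, shard_size)

The `None` checks added to the merged variants simply return the list untouched when a packed sub-module has no LoRA tensor to slice.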