diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 473e4bedf3d60..3e9c2ceb83eac 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -542,10 +542,20 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): Both slices must have the same size. """ - def __init__(self, base_layer: MergedColumnParallelLinear) -> None: + def __init__( + self, base_layer: Union[MergedColumnParallelLinear, + QKVParallelLinear]) -> None: super().__init__(base_layer) # There are two LoRA layers - self.n_slices = len(self.base_layer.output_sizes) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + # the output_sizes in MergedColumnParallelLinear is not sharded by tp + # we need to divide it by the tp_size to get correct slices size + output_sizes = self.base_layer.output_sizes + self.output_slices = tuple( + divide(output_size, self.tp_size) for output_size in output_sizes) + self.n_slices = len(self.output_slices) + self.output_ids = (self.tp_rank, ) * self.n_slices def create_lora_weights( self, @@ -559,15 +569,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): """ self.lora_config = lora_config - if not (len(self.base_layer.output_sizes) == self.n_slices == 2 - and self.base_layer.output_sizes[0] - == self.base_layer.output_sizes[1]): - raise ValueError( - "LoRAColumnParallelLinear2Slice requires 2 slices with " - "the same size.") - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - lora_a_output_size_per_partition = ( lora_config.max_lora_rank if not lora_config.fully_sharded_loras else divide(lora_config.max_lora_rank, self.tp_size)) @@ -585,22 +586,20 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): torch.zeros( max_loras, 1, - self.output_size // 2, + output_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, device=self.device, - ) for _ in range(self.n_slices)) + ) for output_size in self.output_slices) if lora_config.bias_enabled: self.lora_bias_stacked = tuple( torch.zeros( max_loras, 1, - self.output_size // 2, + output_size, dtype=lora_config.lora_dtype, device=self.device, - ) for _ in range(self.n_slices)) - self.output_dim = self.lora_b_stacked[0].shape[2] - self.output_slices = (self.output_dim, self.output_dim) + ) for output_size in self.output_slices) def slice_lora_a( self, lora_a: List[Union[torch.Tensor, None]] @@ -610,27 +609,21 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): def slice_lora_b( self, lora_b: List[Union[torch.Tensor, None]] ) -> List[Union[torch.Tensor, None]]: - #NOTE: lora_b contains 2 subloras, and each sublora could be None. - shard_size = self.output_dim - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - lora_b = [ - lora_b[0][:, start_idx:end_idx] if lora_b[0] is not None else None, - lora_b[1][:, start_idx:end_idx] if lora_b[1] is not None else None, - ] + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (lora_b_i := lora_b[i]) is not None: + lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size * + (shard_id + 1)] return lora_b def slice_bias( self, bias: List[Union[torch.Tensor, None]]) -> List[Union[torch.Tensor, None]]: - # NOTE : each bias could be None. - shard_size = self.output_dim - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - bias = [ - bias[0][start_idx:end_idx] if bias[0] is not None else None, - bias[1][start_idx:end_idx] if bias[1] is not None else None - ] + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (bias_i := bias[i]) is not None: + bias[i] = bias_i[shard_size * shard_id:shard_size * + (shard_id + 1)] return bias def set_lora( @@ -649,30 +642,25 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): if lora_bias is not None: lora_bias = self.slice_bias(lora_bias) - if lora_a[0] is not None: - self.lora_a_stacked[0][ - index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( - lora_a[0].T, non_blocking=True) - self.lora_b_stacked[0][ - index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( - lora_b[0].T, non_blocking=True) - if lora_bias is not None and lora_bias[0] is not None: + for i in range(self.n_slices): + if (lora_a_i := lora_a[i]) is not None: + self.lora_a_stacked[i][ + index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( + lora_a_i.T, non_blocking=True) + if (lora_b_i := lora_b[i]) is not None: + self.lora_b_stacked[i][ + index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( + lora_b_i.T, non_blocking=True) + + if lora_bias is not None: self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], self.lora_bias_stacked) - self.lora_bias_stacked[0][index, 0, :lora_bias[0].shape[0]].copy_( - lora_bias[0].T, non_blocking=True) - if lora_a[1] is not None: - self.lora_a_stacked[1][ - index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( - lora_a[1].T, non_blocking=True) - self.lora_b_stacked[1][ - index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( - lora_b[1].T, non_blocking=True) - if lora_bias is not None and lora_bias[1] is not None: - self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], - self.lora_bias_stacked) - self.lora_bias_stacked[1][index, 0, :lora_bias[1].shape[0]].copy_( - lora_bias[1].T, non_blocking=True) + for i in range(self.n_slices): + if (lora_bias_i := lora_bias[i]) is not None: + self.lora_bias_stacked[i][index, + 0, :lora_bias_i.shape[0]].copy_( + lora_bias_i.T, + non_blocking=True) @classmethod @_not_fully_sharded_can_replace @@ -755,8 +743,8 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): packed_modules_list) == 1 -class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): - """ColumnParallelLinear layer that is composed of 3 sublayers (slices) +class MergedQKVParallelLinearWithLora(MergedColumnParallelLinearWithLoRA): + """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) packed together in qkv proj fashion (q_proj + k_proj + v_proj -> qkv_proj). @@ -773,6 +761,24 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas + + self.output_slices = ( + self.q_proj_shard_size, + self.kv_proj_shard_size, + self.kv_proj_shard_size, + ) + self.output_ids = ( + self.q_shard_id, + self.kv_shard_id, + self.kv_shard_id, + ) + def create_lora_weights( self, max_loras: int, @@ -783,216 +789,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): The main reason for overloading this function is to handle inconsistent weight dimensions in qkv lora. """ - self.lora_config = lora_config - - if not (len(self.base_layer.output_sizes) == self.n_slices == 3): - raise ValueError( - "LoRAColumnParallelLinear3Slice requires 3 slices.") - - self.q_proj_shard_size = (self.base_layer.num_heads * - self.base_layer.head_size) - self.kv_proj_shard_size = (self.base_layer.num_kv_heads * - self.base_layer.head_size) - self.q_shard_id = self.tp_rank - self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas - - lora_a_output_size_per_partition = ( - lora_config.max_lora_rank if not lora_config.fully_sharded_loras - else divide(lora_config.max_lora_rank, self.tp_size)) - # q, k, v - self.lora_a_stacked = ( - torch.zeros( - max_loras, - 1, - lora_a_output_size_per_partition, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - lora_a_output_size_per_partition, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - lora_a_output_size_per_partition, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - ) - self.lora_b_stacked = ( - torch.zeros( - max_loras, - 1, - self.q_proj_shard_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - self.kv_proj_shard_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - self.kv_proj_shard_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ), - ) - if lora_config.bias_enabled: - self.lora_bias_stacked = ( - torch.zeros( - max_loras, - 1, - self.q_proj_shard_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - self.kv_proj_shard_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - self.kv_proj_shard_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - ) - self.output_slices = ( - self.q_proj_shard_size, - self.kv_proj_shard_size, - self.kv_proj_shard_size, - ) - self.packed_indices: Optional[torch.Tensor] = None - self.standard_indices: Optional[torch.Tensor] = None - # lazily initialized. - self.indices: torch.Tensor - self.indices_len: List[int] - - def slice_lora_a( - self, lora_a: List[Union[torch.Tensor, None]] - ) -> List[Union[torch.Tensor, None]]: - return lora_a - - def slice_lora_b( - self, lora_b: List[Union[torch.Tensor, None]] - ) -> List[Union[torch.Tensor, None]]: - lora_b_q, lora_b_k, lora_b_v = None, None, None - if lora_b[0] is not None: - lora_b_q = lora_b[0][:, self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1), ] - if lora_b[1] is not None: - lora_b_k = lora_b[1][:, self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1), ] - if lora_b[2] is not None: - lora_b_v = lora_b[2][:, self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1), ] - lora_b = [lora_b_q, lora_b_k, lora_b_v] - return lora_b - - def slice_bias( - self, bias: List[Union[torch.Tensor, - None]]) -> List[Union[torch.Tensor, None]]: - bias_q, bias_k, bias_v = bias - if bias_q is not None: - bias_q = bias_q[self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] - if bias_k is not None: - bias_k = bias_k[self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1)] - if bias_v is not None: - bias_v = bias_v[self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1)] - bias = [bias_q, bias_k, bias_v] - return bias - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - lora_bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - - if self.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - if lora_bias is not None: - lora_bias = self.slice_bias(lora_bias) - - if lora_b[0] is not None: - lora_b_q = lora_b[0] - self.lora_b_stacked[0][ - index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( - lora_b_q.T, non_blocking=True) - if lora_b[1] is not None: - lora_b_k = lora_b[1] - self.lora_b_stacked[1][ - index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( - lora_b_k.T, non_blocking=True) - if lora_b[2] is not None: - lora_b_v = lora_b[2] - self.lora_b_stacked[2][ - index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_( - lora_b_v.T, non_blocking=True) - - if lora_a[0] is not None: - self.lora_a_stacked[0][ - index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( - lora_a[0].T, non_blocking=True) - if lora_a[1] is not None: - self.lora_a_stacked[1][ - index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( - lora_a[1].T, non_blocking=True) - if lora_a[2] is not None: - self.lora_a_stacked[2][ - index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( - lora_a[2].T, non_blocking=True) - - if lora_bias is not None: - self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], - self.lora_bias_stacked) - if lora_bias[0] is not None: - self.lora_bias_stacked[0][index, - 0, :lora_bias[0].shape[0]].copy_( - lora_bias[0].T, - non_blocking=True) - if lora_bias[1] is not None: - self.lora_bias_stacked[1][index, - 0, :lora_bias[1].shape[0]].copy_( - lora_bias[1].T, - non_blocking=True) - if lora_bias[2] is not None: - self.lora_bias_stacked[2][index, - 0, :lora_bias[2].shape[0]].copy_( - lora_bias[2].T, - non_blocking=True) + super().create_lora_weights(max_loras, lora_config, model_config) @classmethod @_not_fully_sharded_can_replace