diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 891bc75fcdee0..b0038a28ed89d 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -12,20 +12,20 @@ import torch import torch.nn.functional as F from vllm.config import LoRAConfig -from vllm.lora.fully_sharded_layers import ( - ColumnParallelLinearWithShardedLoRA, - MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, - RowParallelLinearWithShardedLoRA) # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + ColumnParallelLinearWithShardedLoRA, LogitsProcessorWithLoRA, LoRAMapping, MergedColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithLoRA, + QKVParallelLinearWithShardedLoRA, ReplicatedLinearWithLoRA, RowParallelLinearWithLoRA, + RowParallelLinearWithShardedLoRA, VocabParallelEmbeddingWithLoRA) # yapf: enable from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py deleted file mode 100644 index 7fc4cfe026aee..0000000000000 --- a/vllm/lora/fully_sharded_layers.py +++ /dev/null @@ -1,355 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# pylint: disable=unused-argument -from typing import TYPE_CHECKING, Optional, Union, cast - -import torch -import torch.nn as nn -from transformers import PretrainedConfig - -from vllm.config import LoRAConfig -from vllm.distributed.communication_op import ( - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank -from vllm.lora.layers import (ColumnParallelLinearWithLoRA, - MergedColumnParallelLinearWithLoRA, - MergedQKVParallelLinearWithLoRA, - QKVParallelLinearWithLoRA, - RowParallelLinearWithLoRA) -from vllm.platforms import current_platform - -if TYPE_CHECKING: - pass - - -def _fully_sharded_can_replace(can_replace): - """ - decorator which adds the condition of fully sharded loras - intended to wrap can_replace_layer() - """ - - def dec(*args, **kwargs): - return (can_replace(*args, **kwargs) - and kwargs["lora_config"].fully_sharded_loras) - - return dec - - -def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA): - """ - For `ColumnParallelLinearWithLoRA` or classes that inherit from - `ColumnParallelLinearWithLoRA`, they share the same `apply` logic. - """ - assert (layer.n_slices == len(layer.lora_a_stacked) == len( - layer.lora_b_stacked) == len(layer.output_slices)) - if layer.lora_bias_stacked is not None: - assert layer.n_slices == len(layer.lora_bias_stacked) - - output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) - - x = x.view(-1, x.shape[-1]) - output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape - - # Since communication is needed, the buffer is directly initialized as a - # tensor rather than a tuple of tensor. - buffers = torch.zeros( - (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]), - dtype=torch.float32, - device=x.device, - ) - - shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink( - buffers, x, layer.lora_a_stacked, 1.0) - - if not current_platform.can_update_inplace(): - buffers = shrunk_buffers - - buffers = tensor_model_parallel_all_gather(buffers) - - lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand( - output, - buffers, - layer.lora_b_stacked, - layer.lora_bias_stacked, - layer.output_slices, - offset_start=0, - add_input=True) - - if not current_platform.can_update_inplace(): - output = lora_output - - output = output.view(*out_orig_shape) - # now have column partitioned and packed output - return output - - -# these layers are based on the tensor parallelism strategy given in -# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, -# https://arxiv.org/abs/2311.03285. - - -class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): - """ - Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also. - - Based on S-LoRA, slicing happens along the rank dim. - """ - - # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`, - # their `lora_a` and `lora_b` have different sharding patterns. After - # completing the `lora_a` GEMM , a gather operation is performed. - # Therefore, the sharding of `lora_a` only needs to correspond with the - # gather operation. - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.lora_a_stacked[0].shape[2] - start_idx = tp_rank * shard_size - lora_a = lora_a[:, start_idx:start_idx + shard_size] - return lora_a - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return _mcp_apply(x, bias, self) - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) - - -class MergedColumnParallelLinearWithShardedLoRA( - MergedColumnParallelLinearWithLoRA): - """ - Differs from MergedColumnParallelLinearWithLoRA by slicing the - LoRA A's also. - - Based on S-LoRA, slicing happens along the rank dim. - """ - - def slice_lora_a( - self, lora_a: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - #NOTE: lora_a contains 2 subloras, and each sublora could be None. - output_shard_size = self.lora_a_stacked[0].shape[2] - output_start_idx = self.tp_rank * output_shard_size - lora_a = [ - lora_a[0][:, output_start_idx:output_start_idx + - output_shard_size] if lora_a[0] is not None else None, - lora_a[1][:, output_start_idx:output_start_idx + - output_shard_size] if lora_a[1] is not None else None, - ] - return lora_a - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return _mcp_apply(x, bias, self) - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) - - -class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): - """ - Differs from QKVParallelLinearWithLoRA by slicing the - LoRA A's also. - - Based on S-LoRA, slicing happens along the rank dim. - """ - - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.lora_a_stacked[0].shape[2] - start_idx = tp_rank * shard_size - lora_a = lora_a[:, start_idx:start_idx + shard_size] - return lora_a - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return _mcp_apply(x, bias, self) - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig]) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) - - -class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): - """ - Differs from MergedQKVParallelLinearWithLoRA by slicing the - LoRA A's also. - - Based on S-LoRA, slicing happens along the rank dim. - """ - - def slice_lora_a( - self, lora_a: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - # NOTE: lora_a contains 3 subloras, and each sublora could be None. - shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] - start_idx = [self.tp_rank * shard_size[i] for i in range(3)] - lora_a = [ - lora_a[0][:, start_idx[0]:start_idx[0] + - shard_size[0]] if lora_a[0] is not None else None, - lora_a[1][:, start_idx[1]:start_idx[1] + - shard_size[1]] if lora_a[1] is not None else None, - lora_a[2][:, start_idx[2]:start_idx[2] + - shard_size[2]] if lora_a[2] is not None else None, - ] - return lora_a - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return _mcp_apply(x, bias, self) - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) - - -class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): - """ - Differs from RowParallelLinearWithLoRA by slicing the - LoRA B's also. - - Based on S-LoRA, slicing happens along the output dim. - This yields a combined partial sum from the row parallel base - layer and column partitioned output from the LoRA. - """ - - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - shard_size = self.lora_b_stacked[0].shape[2] - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] - return lora_b - - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - if bias is None: - return bias - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) - shard_size = self.lora_bias_stacked[0].shape[2] - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - bias = bias[start_idx:end_idx] - return bias - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x) - - x = x.view(-1, x.shape[-1]) - output, out_orig_shape = output.view(-1, - output.shape[-1]), output.shape - buffer = torch.zeros( - (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]), - dtype=torch.float32, - device=x.device, - ) - - shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink( - buffer, x, self.lora_a_stacked, 1.0) - if not current_platform.can_update_inplace(): - buffer = shrunk_buffer - - buffer = tensor_model_parallel_all_reduce(buffer) - - # following S-LoRA, allows the fusing of all_gather and all_reduce - # by adding the column partitioned lora output to a slice of output - # tensor, which is a partial sum due to row parallel. All that - # remains is a standard all_reduce. User should be aware though that - # the output is not the same as a normal row_parallel, it should be - # reduced before being used - # NOTE offset are based on the rank. - shard_size = self.lora_b_stacked[0].shape[2] - offset_start = self.tp_rank * shard_size - lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand( - output, - buffer, - self.lora_b_stacked, - self.lora_bias_stacked, - self.output_slices, - offset_start=offset_start, - add_input=True, - ) - - if not current_platform.can_update_inplace(): - output = lora_output - - output = output.view(*out_orig_shape) - return output - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py deleted file mode 100644 index 6e4b69c303254..0000000000000 --- a/vllm/lora/layers.py +++ /dev/null @@ -1,1192 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# pylint: disable=unused-argument -import math -from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Union, cast - -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import PretrainedConfig - -from vllm.adapter_commons.layers import AdapterMapping -from vllm.config import LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) -from vllm.distributed.utils import divide -# yapf: disable -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -# yapf: enable -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) -from vllm.platforms import current_platform - -if TYPE_CHECKING: - from vllm.lora.punica_wrapper import PunicaWrapperBase - - -def _get_lora_device(base_layer: nn.Module) -> torch.device: - # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 - """Returns the device for where to place the LoRA tensors.""" - # unquantizedLinear - if hasattr(base_layer, "weight"): - return base_layer.weight.device - # Compressed Tensor - elif hasattr(base_layer, "weight_packed"): - return base_layer.weight_packed.device - # GPTQ/AWQ - elif hasattr(base_layer, "qweight"): - return base_layer.qweight.device - # HQQ marlin - elif hasattr(base_layer, "W_q"): - return base_layer.W_q.device - else: - raise ValueError(f"Unsupported base layer: {base_layer}") - - -def _not_fully_sharded_can_replace(can_replace): - """ - decorator which adds the condition of not using fully sharded loras - intended to wrap can_replace_layer() - """ - - def dec(*args, **kwargs): - decorate = kwargs.pop("decorate") if "decorate" in kwargs else True - condition = (not kwargs["lora_config"].fully_sharded_loras - if decorate else True) - return can_replace(*args, **kwargs) and condition - - return dec - - -@dataclass -class LoRAMapping(AdapterMapping): - is_prefill: bool = False - - -class BaseLayerWithLoRA(nn.Module): - - def slice_lora_a( - self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]] - ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: - """Slice lora a if splitting for tensor parallelism.""" - ... - - def slice_lora_b( - self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]] - ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: - """Slice lora b if splitting with tensor parallelism.""" - ... - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - """Initializes lora matrices.""" - ... - - def reset_lora(self, index: int): - """Resets the lora weights at index back to 0.""" - ... - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - """Overwrites lora tensors at index.""" - ... - - def set_mapping( - self, - punica_wrapper, - ): - self.punica_wrapper: PunicaWrapperBase = punica_wrapper - - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - """Returns True if the layer can be replaced by this LoRA layer.""" - raise NotImplementedError - - -class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): - - def __init__(self, base_layer: VocabParallelEmbedding) -> None: - super().__init__() - self.base_layer = base_layer - self.embeddings_slice: Optional[tuple[int, int]] - self.embeddings_weights: Optional[torch.Tensor] - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> None: - - if self.base_layer.num_added_embeddings_per_partition > 0: - # We can start adding lora weights - self.embeddings_weights = self.base_layer.weight.data[ - self.base_layer.num_org_embeddings_per_partition:self. - base_layer.num_org_embeddings_per_partition + - self.base_layer.num_added_embeddings_per_partition] - self.embeddings_slice = ( - self.base_layer.shard_indices.added_vocab_start_index - - self.base_layer.org_vocab_size, - self.base_layer.shard_indices.added_vocab_end_index - - self.base_layer.org_vocab_size) - self.base_layer.weight.data[ - self.base_layer.num_org_embeddings_per_partition:].fill_(0) - else: - self.embeddings_slice = None - self.embeddings_weights = None - - self.embeddings_tensors = torch.zeros( - ( - max_loras, - lora_config.lora_extra_vocab_size, - self.base_layer.embedding_dim, - ), - dtype=self.base_layer.weight.dtype, - device=self.base_layer.weight.device, - ) - self.lora_a_stacked = torch.zeros( - ( - max_loras, - self.base_layer.org_vocab_size + - lora_config.lora_extra_vocab_size, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, - ) - self.lora_b_stacked = torch.zeros( - ( - max_loras, - 1, - self.base_layer.embedding_dim, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, - ) - self.lora_a_stacked_2d = self.lora_a_stacked.view( - self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], - self.lora_a_stacked.shape[2], - ) - - def reset_lora(self, index: int): - self.lora_a_stacked[index] = 0 - self.lora_b_stacked[index] = 0 - self.embeddings_tensors[index] = 0 - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( - lora_a, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if embeddings_tensor is not None: - self.embeddings_tensors[ - index, - :embeddings_tensor.shape[0], - :embeddings_tensor.shape[1], - ].copy_(embeddings_tensor, non_blocking=True) - if self.embeddings_slice is not None: - # TODO(yard1): Optimize this copy, we don't need to copy - # everything, just the modified part - embeddings = self.embeddings_tensors.view( - self.embeddings_tensors.shape[0] * - self.embeddings_tensors.shape[1], - self.embeddings_tensors.shape[2], - )[self.embeddings_slice[0]:self.embeddings_slice[1]] - assert self.embeddings_weights is not None - self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, - 1, 0) - - # NB: Don't use torch.narrow here. torch.narrow triggers some - # Dynamic Shape specialization in torch.compile - num_tokens = x.shape[0] - indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens] - indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens] - - full_lora_a_embeddings = F.embedding( - x + indices_1, - self.lora_a_stacked_2d, - ) - full_output = self.base_layer.forward(x + - (indices_0 * added_tokens_mask)) - - full_output_org = full_output - if full_output.ndim == 3: - full_output = full_output.view( - full_output.shape[0] * full_output.shape[1], -1) - if full_lora_a_embeddings.ndim == 3: - full_lora_a_embeddings = full_lora_a_embeddings.view( - full_lora_a_embeddings.shape[0] * - full_lora_a_embeddings.shape[1], - -1, - ) - - lora_output: Optional[ - torch.Tensor] = self.punica_wrapper.add_lora_embedding( - full_output, - full_lora_a_embeddings, - self.lora_b_stacked, - add_input=True) - - if not current_platform.can_update_inplace(): - full_output = lora_output - - return full_output.view_as(full_output_org) - - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is VocabParallelEmbedding - - @property - def weight(self): - return self.base_layer.weight - - -class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): - - def __init__(self, base_layer: LinearBase): - super().__init__() - self.base_layer = base_layer - self.input_size = self.base_layer.input_size - self.device = _get_lora_device(self.base_layer) - self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None - - self.output_slices: tuple[int, ...] - self.tp_size: int - self.output_size: int - self.n_slices: int - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - self.lora_config = lora_config - # - if isinstance(self.base_layer, ReplicatedLinear): - lora_a_out_size = lora_config.max_lora_rank - lora_b_out_size = self.output_size - - elif isinstance(self.base_layer, ColumnParallelLinear): - lora_a_out_size = (lora_config.max_lora_rank if - not lora_config.fully_sharded_loras else divide( - lora_config.max_lora_rank, self.tp_size)) - lora_b_out_size = self.output_size - - elif isinstance(self.base_layer, RowParallelLinear): - lora_a_out_size = lora_config.max_lora_rank - lora_b_out_size = (self.output_size if - not lora_config.fully_sharded_loras else divide( - self.output_size, self.tp_size)) - else: - raise NotImplementedError - - self.lora_a_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_a_out_size, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) - self.lora_b_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_b_out_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) - if lora_config.bias_enabled: - lora_bias_out_size = lora_b_out_size - self.lora_bias_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_bias_out_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) - self.output_slices = (self.lora_b_stacked[0].shape[2], ) - - def reset_lora(self, index: int): - for s_index in range(self.n_slices): - self.lora_a_stacked[s_index][index] = 0 - self.lora_b_stacked[s_index][index] = 0 - if self.lora_config.bias_enabled: - # Make mypy happy - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) - self.lora_bias_stacked[s_index][index] = 0 - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - lora_bias: Optional[torch.Tensor] = None, - ): - # Except for QKVParallelLinearWithLoRA and - # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers - # store weights in a tuple of size 1. These two layers will - # override this function. - assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == - self.n_slices == 1) - - self.reset_lora(index) - if self.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - if lora_bias is not None: - lora_bias = self.slice_bias(lora_bias) - - self.lora_a_stacked[0][index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[0][index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if lora_bias is not None: - - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) - assert len(self.lora_bias_stacked) - self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( - lora_bias.T, non_blocking=True) - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - - # In transformers backend, x and output have extra batch dimension like - # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim), - # therefore we need to flatten the batch dimensions. - if x.ndim == 3 and output.ndim == 3: - output = output.flatten(0, 1) - x = x.flatten(0, 1) - - lora_output: Optional[ - torch.Tensor] = self.punica_wrapper.add_lora_linear( - output, x, self.lora_a_stacked, self.lora_b_stacked, - self.lora_bias_stacked, 1.0, self.output_slices) - if not current_platform.can_update_inplace(): - output = lora_output - - return output - - @property - def weight(self) -> torch.Tensor: - - # unquantizedLinear - if hasattr(self.base_layer, "weight"): - return self.base_layer.weight - # Compressed Tensor - elif hasattr(self.base_layer, "weight_packed"): - return self.base_layer.weight_packed - # GPTQ/AWQ - elif hasattr(self.base_layer, "qweight"): - return self.base_layer.qweight - # marlin - elif hasattr(self.base_layer, "B"): - return self.base_layer.B - # HQQ marlin - elif hasattr(self.base_layer, "W_q"): - return self.base_layer.W_q - else: - raise ValueError(f"Unsupported base layer: {self.base_layer}") - - @property - def bias(self) -> Optional[torch.Tensor]: - if hasattr(self.base_layer, "bias"): - return self.base_layer.bias - else: - return None - - -class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): - - def __init__(self, base_layer: ReplicatedLinear) -> None: - super().__init__(base_layer, ) - # To ensure interface compatibility, set to 1 always. - self.tp_size = 1 - self.output_size = self.base_layer.output_size - self.n_slices = 1 - - def forward( - self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: - """Forward of ReplicatedLinearWithLoRA - - Args: - input_: Tensor whose last dimension is `input_size`. - - Returns: - - output - - bias - """ - bias = (self.base_layer.bias - if not self.base_layer.skip_bias_add else None) - - # Matrix multiply. - output = self.apply(input_, bias) - - output_bias = (self.base_layer.bias - if self.base_layer.skip_bias_add else None) - - if not self.base_layer.return_bias: - return output - - return output, output_bias - - # ReplicatedLinear should always be replaced, regardless of the fully - # sharded LoRAs setting, because it is, by definition, copied per GPU. - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is ReplicatedLinear - - -class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): - """ - LoRA on top of ColumnParallelLinear layer. - LoRA B is sliced for tensor parallelism. - There are two types for the `base_layer`: - 1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`. - 2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`. - """ - - def __init__(self, base_layer: ColumnParallelLinear) -> None: - super().__init__(base_layer) - # The base_layer type is ColumnParallelLinear or - # MergedColumnParallelLinear, their weight sharding logic is - # inconsistent when TP is greater than 1. - self.is_merged_col_linear = type( - base_layer) is MergedColumnParallelLinear - self.tp_size = get_tensor_model_parallel_world_size() - self.output_size = self.base_layer.output_size_per_partition - # There is only one LoRA layer - self.n_slices = 1 - - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - return lora_a - - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - # Applicable to cases where the base_layer is - # MergedColumnParallelLinear. - if self.is_merged_col_linear: - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.output_size // 2 - offset = lora_b.shape[-1] // 2 - - left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * - shard_size] - right_weight = lora_b[:, offset + tp_rank * shard_size:offset + - (tp_rank + 1) * shard_size] - lora_b = torch.cat([left_weight, right_weight], dim=1) - # Applicable to cases where the base_layer is - # ColumnParallelLinear. - else: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_size - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] - return lora_b - - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - # TODO: Fix the slicing logic of bias. - if bias is None: - return bias - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_size - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - bias = bias[start_idx:end_idx] - return bias - - def forward( - self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: - """Forward of ColumnParallelLinear - - Args: - input_: Tensor whose last dimension is `input_size`. - - Returns: - - output - - bias - """ - bias = (self.base_layer.bias - if not self.base_layer.skip_bias_add else None) - - # Matrix multiply. - output_parallel = self.apply(input_, bias) - if self.base_layer.gather_output: - # All-gather across the partitions. - output = tensor_model_parallel_all_gather(output_parallel) - else: - output = output_parallel - - if not self.base_layer.return_bias: - return output - - output_bias = (self.base_layer.bias - if self.base_layer.skip_bias_add else None) - return output, output_bias - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is ColumnParallelLinear or ( - type(source_layer) is MergedColumnParallelLinear - and len(packed_modules_list) == 1) - - -class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): - """ColumnParallelLinear layer that is composed of 2 sublayers (slices) - packed together (e.g. gate_proj + up_proj -> gate_up_proj). - - This means we have 2 LoRAs, each applied to one half of the layer. - - Both slices must have the same size. - """ - - def __init__( - self, base_layer: Union[MergedColumnParallelLinear, - QKVParallelLinear]) -> None: - super().__init__(base_layer) - # There are two LoRA layers - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - # the output_sizes in MergedColumnParallelLinear is not sharded by tp - # we need to divide it by the tp_size to get correct slices size - output_sizes = self.base_layer.output_sizes - self.output_slices = tuple( - divide(output_size, self.tp_size) for output_size in output_sizes) - self.n_slices = len(self.output_slices) - self.output_ids = (self.tp_rank, ) * self.n_slices - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - """ - The main reason for overriding this function is to enhance code - maintainability. - """ - self.lora_config = lora_config - - lora_a_output_size_per_partition = ( - lora_config.max_lora_rank if not lora_config.fully_sharded_loras - else divide(lora_config.max_lora_rank, self.tp_size)) - - self.lora_a_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_a_output_size_per_partition, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) - self.lora_b_stacked = tuple( - torch.zeros( - max_loras, - 1, - output_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ) for output_size in self.output_slices) - if lora_config.bias_enabled: - self.lora_bias_stacked = tuple( - torch.zeros( - max_loras, - 1, - output_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for output_size in self.output_slices) - - def slice_lora_a( - self, lora_a: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - return lora_a - - def slice_lora_b( - self, lora_b: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - sliced_lora_b = [None] * self.n_slices - for i, (shard_id, shard_size) in enumerate( - zip(self.output_ids, self.output_slices)): - if (lora_b_i := lora_b[i]) is not None: - sliced_lora_b[i] = lora_b_i[:, - shard_size * shard_id:shard_size * - (shard_id + 1)] - return sliced_lora_b - - def slice_bias( - self, bias: list[Union[torch.Tensor, - None]]) -> list[Union[torch.Tensor, None]]: - for i, (shard_id, shard_size) in enumerate( - zip(self.output_ids, self.output_slices)): - if (bias_i := bias[i]) is not None: - bias[i] = bias_i[shard_size * shard_id:shard_size * - (shard_id + 1)] - return bias - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - lora_bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - - if self.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - if lora_bias is not None: - lora_bias = self.slice_bias(lora_bias) - - for i in range(self.n_slices): - if (lora_a_i := lora_a[i]) is not None: - self.lora_a_stacked[i][ - index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( - lora_a_i.T, non_blocking=True) - if (lora_b_i := lora_b[i]) is not None: - self.lora_b_stacked[i][ - index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( - lora_b_i.T, non_blocking=True) - - if lora_bias is not None: - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) - for i in range(self.n_slices): - if (lora_bias_i := lora_bias[i]) is not None: - self.lora_bias_stacked[i][index, - 0, :lora_bias_i.shape[0]].copy_( - lora_bias_i.T, - non_blocking=True) - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return (type(source_layer) is MergedColumnParallelLinear - and len(packed_modules_list) == 2) - - -class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): - """ - ColumnParallelLinear layer that is specifically designed for - qkv_proj. Certain models, such as chatglm3 and baichuan-7b, - only contains a single LoRA within their qkv_proj layer. - - During inference with Tensor Parallel, the weights of lora_b - must be accurately partitioned according to the respective ranks. - - Q slice may have different shape than K and V slices (which both have - the same shape). - """ - - def __init__(self, base_layer: QKVParallelLinear) -> None: - super().__init__(base_layer) - self.q_proj_total_size = (self.base_layer.total_num_heads * - self.base_layer.head_size) - self.q_proj_shard_size = (self.base_layer.num_heads * - self.base_layer.head_size) - self.kv_proj_shard_size = (self.base_layer.num_kv_heads * - self.base_layer.head_size) - self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * - self.base_layer.head_size) - # There is only one LoRA layer - self.n_slices = 1 - - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() - self.q_shard_id = tp_rank - self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas - lora_b_q = lora_b[:, self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] - k_offset = self.q_proj_total_size - lora_b_k = lora_b[:, k_offset + - self.kv_proj_shard_size * self.kv_shard_id:k_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - v_offset = k_offset + self.kv_proj_total_size - lora_b_v = lora_b[:, v_offset + - self.kv_proj_shard_size * self.kv_shard_id:v_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) - return lora_b - - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - bias_q = bias[self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] - k_offset = self.q_proj_total_size - bias_k = bias[k_offset + - self.kv_proj_shard_size * self.kv_shard_id:k_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - v_offset = k_offset + self.kv_proj_total_size - bias_v = bias[v_offset + - self.kv_proj_shard_size * self.kv_shard_id:v_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - bias = torch.cat([bias_q, bias_k, bias_v], dim=1) - return bias - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig]) -> bool: - return type(source_layer) is QKVParallelLinear and len( - packed_modules_list) == 1 - - -class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): - """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) - packed together in qkv proj fashion - (q_proj + k_proj + v_proj -> qkv_proj). - - This means we have 3 LoRAs, each applied to one slice of the layer. - - Q slice may have different shape than K and V slices (which both have - the same shape). - """ - - def __init__(self, base_layer: QKVParallelLinear) -> None: - super().__init__(base_layer) - # There are three LoRA layer. - self.n_slices = len(self.base_layer.output_sizes) - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - self.q_proj_shard_size = (self.base_layer.num_heads * - self.base_layer.head_size) - self.kv_proj_shard_size = (self.base_layer.num_kv_heads * - self.base_layer.head_size) - self.q_shard_id = self.tp_rank - self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas - - self.output_slices = ( - self.q_proj_shard_size, - self.kv_proj_shard_size, - self.kv_proj_shard_size, - ) - self.output_ids = ( - self.q_shard_id, - self.kv_shard_id, - self.kv_shard_id, - ) - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - """ - The main reason for overloading this function is to handle inconsistent - weight dimensions in qkv lora. - """ - super().create_lora_weights(max_loras, lora_config, model_config) - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return (type(source_layer) is QKVParallelLinear - and len(packed_modules_list) == 3) - - -#TODO: Implement this -class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA): - pass - - -class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): - - def __init__(self, base_layer: RowParallelLinear) -> None: - super().__init__(base_layer) - - self.tp_size = get_tensor_model_parallel_world_size() - # reset input_size - self.input_size = self.base_layer.input_size_per_partition - self.output_size = self.base_layer.output_size - - self.tp_rank = get_tensor_model_parallel_rank() - # There is only one LoRA layer. - self.n_slices = 1 - - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - - shard_size = self.input_size - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - lora_a = lora_a[start_idx:end_idx, :] - return lora_a - - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - return lora_b - - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - return bias - - def forward( - self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: - """Forward of RowParallelLinear - - Args: - input_: tensor whose last dimension is `input_size`. If - `input_is_parallel` is set, then the last dimension - is `input_size // tp_size`. - - Returns: - - output - - bias - """ - # set up backprop all-reduce. - if self.base_layer.input_is_parallel: - input_parallel = input_ - else: - # TODO: simplify code below - splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.base_layer.tp_size) - input_parallel = splitted_input[self.tp_rank].contiguous() - - # Matrix multiply. - output_parallel = self.apply(input_parallel) - if self.base_layer.reduce_results and self.base_layer.tp_size > 1: - output_ = tensor_model_parallel_all_reduce(output_parallel) - else: - output_ = output_parallel - - if not self.base_layer.skip_bias_add: - output = (output_ + self.base_layer.bias - if self.base_layer.bias is not None else output_) - output_bias = None - else: - output = output_ - output_bias = self.base_layer.bias - - if not self.base_layer.return_bias: - return output - - return output, output_bias - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is RowParallelLinear - - -class LogitsProcessorWithLoRA(BaseLayerWithLoRA): - """ - LoRA wrapper for LogitsProcessor, with extra logic to handle the - application of the LoRA adapter and added LoRA vocabulary. - - Args: - base_layer: LogitsProcessor layer - hidden_size: hidden size of the model - dtype: data type of the model - device: device of the model - sharded_to_full_mapping: index mapping from sharded vocab to full vocab - received from base_layer.get_sharded_to_full_mapping(). If None, - no reindexing will be done. - """ - - def __init__(self, base_layer: LogitsProcessor, hidden_size: int, - dtype: torch.dtype, device: torch.device, - sharded_to_full_mapping: Optional[list[int]]) -> None: - super().__init__() - self.base_layer = base_layer - self.hidden_size = hidden_size - self.dtype = dtype - self.device = device - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.sharded_to_full_mapping = sharded_to_full_mapping - - @property - def logits_as_input(self): - return self.base_layer.logits_as_input - - @property - def vocab_size(self): - return self.base_layer.vocab_size - - @property - def scale(self): - return self.base_layer.scale - - @property - def soft_cap(self): - return self.base_layer.soft_cap - - @property - def use_all_gather(self): - return self.base_layer.use_all_gather - - @property - def org_vocab_size(self): - return self.base_layer.org_vocab_size - - @property - def include_gpu_probs_tensor(self): - return self.base_layer.include_gpu_probs_tensor - - @property - def should_modify_greedy_probs_inplace(self): - return self.base_layer.should_modify_greedy_probs_inplace - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - # TODO: Verify if this condition can be further relaxed - if 32000 < self.base_layer.vocab_size > 257024: - raise ValueError("When using LoRA, vocab size must be " - "32000 >= vocab_size <= 257024") - self.lora_a_stacked = torch.zeros( - ( - max_loras, - 1, - lora_config.max_lora_rank, - self.hidden_size, - ), - dtype=lora_config.lora_dtype, - device=self.device, - ) - self.lora_b_stacked = torch.zeros( - ( - max_loras, - 1, - # Pad for kernel compatibility - math.ceil(self.base_layer.vocab_size / - lora_config.lora_vocab_padding_size) * - lora_config.lora_vocab_padding_size, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - device=self.device, - ) - self.embeddings_tensors = torch.full( - (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), - fill_value=float("-inf"), - dtype=self.dtype, - device=self.device, - ) - if self.sharded_to_full_mapping is not None: - self.sharded_to_full_mapping_gpu = torch.tensor( - self.sharded_to_full_mapping, - device=self.device, - dtype=torch.long) - else: - self.sharded_to_full_mapping_gpu = None - - def reset_lora(self, index: int): - self.lora_a_stacked[index] = 0 - self.lora_b_stacked[index] = 0 - self.embeddings_tensors[index] = float("-inf") - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - self.lora_a_stacked[index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if embeddings_tensor is not None: - self.embeddings_tensors[ - index, - :embeddings_tensor.shape[0], - :embeddings_tensor.shape[1], - ] = embeddings_tensor - - def _get_logits( - self, - hidden_states: torch.Tensor, - lm_head: VocabParallelEmbedding, - embedding_bias: Optional[torch.Tensor] = None, - ) -> Optional[torch.Tensor]: - # Get the logits for the next tokens. - logits = lm_head.quant_method.apply(lm_head, hidden_states) - if embedding_bias is not None: - logits += embedding_bias - - # Gather logits for TP - logits = self.base_layer._gather_logits(logits) - - if logits is None: - return None - - if self.sharded_to_full_mapping_gpu is not None: - # Reindex full logits tensor to ensure 1:1 mapping between - # index and token_id - # Example for: - # org_vocab_size = 4 - # added_vocab_size = 2 - # pad_to_size = 8 - # tp_size = 2 - - # indices: [0, 1, 2, 3, 4, 5, 6, 7] - # token_id: [0, 1, 4, -1, 2, 3, 5, -1] - - # Therefore, the mapping is expected to be: - # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, - # we get: - # indices: [0, 1, 2, 3, 4, 5, 6, 7] - # token_id: [0, 1, 2, 3, 4, 5, -1, -1] - logits = logits[:, self.sharded_to_full_mapping_gpu] - - lora_logits = torch.empty( - self.embeddings_tensors.shape[0] + 1, - self.embeddings_tensors.shape[1], - hidden_states.shape[0], - dtype=self.embeddings_tensors.dtype, - device=self.embeddings_tensors.device, - ) - torch.matmul(self.embeddings_tensors, - hidden_states.T, - out=lora_logits[:-1]) - - neg_inf, pos_inf = current_platform.get_infinity_values( - lora_logits.dtype) - - lora_logits[-1] = neg_inf - lora_logits = lora_logits.mT - indices_padded = self.punica_wrapper.sampler_indices_padded - - if current_platform.is_tpu() or current_platform.is_xpu(): - indices_padded = indices_padded[:logits.size(0)] - - lora_logits = (lora_logits.reshape( - lora_logits.shape[0] * lora_logits.shape[1], - lora_logits.shape[2], - ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, - posinf=pos_inf, - neginf=neg_inf)) - - logits[:, - self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + - lora_logits.shape[1]] = lora_logits - - lora_output: Optional[ - torch.Tensor] = self.punica_wrapper.add_lora_logits( - logits, hidden_states, self.lora_a_stacked, - self.lora_b_stacked, 1.0) - - if not current_platform.can_update_inplace(): - logits = lora_output - - # Remove paddings in vocab (if any). - logits = logits[:, :self.base_layer.vocab_size] - return logits - - def forward(self, *args, **kwargs): - return type(self.base_layer).forward(self, *args, **kwargs) - - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # Special handling for the LogitsProcessor. - return False diff --git a/vllm/lora/layers/__init__.py b/vllm/lora/layers/__init__.py new file mode 100644 index 0000000000000..d3bb145dc7bf8 --- /dev/null +++ b/vllm/lora/layers/__init__.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.lora.layers.base import BaseLayerWithLoRA +from vllm.lora.layers.column_parallel_linear import ( + ColumnParallelLinearWithLoRA, ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithLoRA, + QKVParallelLinearWithShardedLoRA) +from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA +from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA +from vllm.lora.layers.row_parallel_linear import ( + RowParallelLinearWithLoRA, RowParallelLinearWithShardedLoRA) +from vllm.lora.layers.utils import LoRAMapping +from vllm.lora.layers.vocal_parallel_embedding import ( + VocabParallelEmbeddingWithLoRA) + +__all__ = [ + "BaseLayerWithLoRA", + "VocabParallelEmbeddingWithLoRA", + "LogitsProcessorWithLoRA", + "ColumnParallelLinearWithLoRA", + "ColumnParallelLinearWithShardedLoRA", + "MergedColumnParallelLinearWithLoRA", + "MergedColumnParallelLinearWithShardedLoRA", + "MergedQKVParallelLinearWithLoRA", + "MergedQKVParallelLinearWithShardedLoRA", + "QKVParallelLinearWithLoRA", + "QKVParallelLinearWithShardedLoRA", + "RowParallelLinearWithLoRA", + "RowParallelLinearWithShardedLoRA", + "ReplicatedLinearWithLoRA", + "LoRAMapping", +] diff --git a/vllm/lora/layers/base.py b/vllm/lora/layers/base.py new file mode 100644 index 0000000000000..0e759d5d5719b --- /dev/null +++ b/vllm/lora/layers/base.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TYPE_CHECKING, Optional, Union + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig + +if TYPE_CHECKING: + from vllm.lora.punica_wrapper import PunicaWrapperBase + + +class BaseLayerWithLoRA(nn.Module): + + def slice_lora_a( + self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: + """Slice lora a if splitting for tensor parallelism.""" + ... + + def slice_lora_b( + self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: + """Slice lora b if splitting with tensor parallelism.""" + ... + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """Initializes lora matrices.""" + ... + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + """Overwrites lora tensors at index.""" + ... + + def set_mapping( + self, + punica_wrapper, + ): + self.punica_wrapper: PunicaWrapperBase = punica_wrapper + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + raise NotImplementedError diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py new file mode 100644 index 0000000000000..4e062971d9188 --- /dev/null +++ b/vllm/lora/layers/base_linear.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, cast + +import torch +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.distributed.utils import divide +# yapf: disable +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, ReplicatedLinear, + RowParallelLinear) +from vllm.platforms import current_platform + +from .base import BaseLayerWithLoRA +from .utils import _get_lora_device + + +class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: LinearBase): + super().__init__() + self.base_layer = base_layer + self.input_size = self.base_layer.input_size + self.device = _get_lora_device(self.base_layer) + self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None + + self.output_slices: tuple[int, ...] + self.tp_size: int + self.output_size: int + self.n_slices: int + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_config = lora_config + # + if isinstance(self.base_layer, ReplicatedLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, ColumnParallelLinear): + lora_a_out_size = (lora_config.max_lora_rank if + not lora_config.fully_sharded_loras else divide( + lora_config.max_lora_rank, self.tp_size)) + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, RowParallelLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = (self.output_size if + not lora_config.fully_sharded_loras else divide( + self.output_size, self.tp_size)) + else: + raise NotImplementedError + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_out_size, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_b_out_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + if lora_config.bias_enabled: + lora_bias_out_size = lora_b_out_size + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_bias_out_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.output_slices = (self.lora_b_stacked[0].shape[2], ) + + def reset_lora(self, index: int): + for s_index in range(self.n_slices): + self.lora_a_stacked[s_index][index] = 0 + self.lora_b_stacked[s_index][index] = 0 + if self.lora_config.bias_enabled: + # Make mypy happy + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], + self.lora_bias_stacked) + self.lora_bias_stacked[s_index][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + # Except for QKVParallelLinearWithLoRA and + # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers + # store weights in a tuple of size 1. These two layers will + # override this function. + assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == + self.n_slices == 1) + + self.reset_lora(index) + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + self.lora_a_stacked[0][index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[0][index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if lora_bias is not None: + + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], + self.lora_bias_stacked) + assert len(self.lora_bias_stacked) + self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( + lora_bias.T, non_blocking=True) + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + + # In transformers backend, x and output have extra batch dimension like + # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim), + # therefore we need to flatten the batch dimensions. + if x.ndim == 3 and output.ndim == 3: + output = output.flatten(0, 1) + x = x.flatten(0, 1) + + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_linear( + output, x, self.lora_a_stacked, self.lora_b_stacked, + self.lora_bias_stacked, 1.0, self.output_slices) + if not current_platform.can_update_inplace(): + output = lora_output + + return output + + @property + def weight(self) -> torch.Tensor: + + # unquantizedLinear + if hasattr(self.base_layer, "weight"): + return self.base_layer.weight + # Compressed Tensor + elif hasattr(self.base_layer, "weight_packed"): + return self.base_layer.weight_packed + # GPTQ/AWQ + elif hasattr(self.base_layer, "qweight"): + return self.base_layer.qweight + # marlin + elif hasattr(self.base_layer, "B"): + return self.base_layer.B + # HQQ marlin + elif hasattr(self.base_layer, "W_q"): + return self.base_layer.W_q + else: + raise ValueError(f"Unsupported base layer: {self.base_layer}") + + @property + def bias(self) -> Optional[torch.Tensor]: + if hasattr(self.base_layer, "bias"): + return self.base_layer.bias + else: + return None diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py new file mode 100644 index 0000000000000..d2f8e05554c84 --- /dev/null +++ b/vllm/lora/layers/column_parallel_linear.py @@ -0,0 +1,622 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union, cast + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) +from vllm.distributed.utils import divide +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear) +from vllm.platforms import current_platform + +from .base_linear import BaseLinearLayerWithLoRA +from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace + + +def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"): + """ + For `ColumnParallelLinearWithLoRA` or classes that inherit from + `ColumnParallelLinearWithLoRA`, they share the same `apply` logic. + """ + assert (layer.n_slices == len(layer.lora_a_stacked) == len( + layer.lora_b_stacked) == len(layer.output_slices)) + if layer.lora_bias_stacked is not None: + assert layer.n_slices == len(layer.lora_bias_stacked) + + output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + + # Since communication is needed, the buffer is directly initialized as a + # tensor rather than a tuple of tensor. + buffers = torch.zeros( + (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink( + buffers, x, layer.lora_a_stacked, 1.0) + + if not current_platform.can_update_inplace(): + buffers = shrunk_buffers + + buffers = tensor_model_parallel_all_gather(buffers) + + lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand( + output, + buffers, + layer.lora_b_stacked, + layer.lora_bias_stacked, + layer.output_slices, + offset_start=0, + add_input=True) + + if not current_platform.can_update_inplace(): + output = lora_output + + output = output.view(*out_orig_shape) + # now have column partitioned and packed output + return output + + +class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + """ + LoRA on top of ColumnParallelLinear layer. + LoRA B is sliced for tensor parallelism. + There are two types for the `base_layer`: + 1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`. + 2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`. + """ + + def __init__(self, base_layer: ColumnParallelLinear) -> None: + super().__init__(base_layer) + # The base_layer type is ColumnParallelLinear or + # MergedColumnParallelLinear, their weight sharding logic is + # inconsistent when TP is greater than 1. + self.is_merged_col_linear = type( + base_layer) is MergedColumnParallelLinear + self.tp_size = get_tensor_model_parallel_world_size() + self.output_size = self.base_layer.output_size_per_partition + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + # Applicable to cases where the base_layer is + # MergedColumnParallelLinear. + if self.is_merged_col_linear: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size // 2 + offset = lora_b.shape[-1] // 2 + + left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * + shard_size] + right_weight = lora_b[:, offset + tp_rank * shard_size:offset + + (tp_rank + 1) * shard_size] + lora_b = torch.cat([left_weight, right_weight], dim=1) + # Applicable to cases where the base_layer is + # ColumnParallelLinear. + else: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + # TODO: Fix the slicing logic of bias. + if bias is None: + return bias + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + bias = bias[start_idx:end_idx] + return bias + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output_parallel = self.apply(input_, bias) + if self.base_layer.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + + if not self.base_layer.return_bias: + return output + + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ColumnParallelLinear or ( + type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 1) + + +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ColumnParallelLinear layer that is composed of 2 sublayers (slices) + packed together (e.g. gate_proj + up_proj -> gate_up_proj). + + This means we have 2 LoRAs, each applied to one half of the layer. + + Both slices must have the same size. + """ + + def __init__( + self, base_layer: Union[MergedColumnParallelLinear, + QKVParallelLinear]) -> None: + super().__init__(base_layer) + # There are two LoRA layers + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + # the output_sizes in MergedColumnParallelLinear is not sharded by tp + # we need to divide it by the tp_size to get correct slices size + output_sizes = self.base_layer.output_sizes + self.output_slices = tuple( + divide(output_size, self.tp_size) for output_size in output_sizes) + self.n_slices = len(self.output_slices) + self.output_ids = (self.tp_rank, ) * self.n_slices + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overriding this function is to enhance code + maintainability. + """ + self.lora_config = lora_config + + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_output_size_per_partition, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for output_size in self.output_slices) + if lora_config.bias_enabled: + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for output_size in self.output_slices) + + def slice_lora_a( + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + return lora_a + + def slice_lora_b( + self, lora_b: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + sliced_lora_b = [None] * self.n_slices + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (lora_b_i := lora_b[i]) is not None: + sliced_lora_b[i] = lora_b_i[:, + shard_size * shard_id:shard_size * + (shard_id + 1)] + return sliced_lora_b + + def slice_bias( + self, bias: list[Union[torch.Tensor, + None]]) -> list[Union[torch.Tensor, None]]: + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (bias_i := bias[i]) is not None: + bias[i] = bias_i[shard_size * shard_id:shard_size * + (shard_id + 1)] + return bias + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + for i in range(self.n_slices): + if (lora_a_i := lora_a[i]) is not None: + self.lora_a_stacked[i][ + index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( + lora_a_i.T, non_blocking=True) + if (lora_b_i := lora_b[i]) is not None: + self.lora_b_stacked[i][ + index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( + lora_b_i.T, non_blocking=True) + + if lora_bias is not None: + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], + self.lora_bias_stacked) + for i in range(self.n_slices): + if (lora_bias_i := lora_bias[i]) is not None: + self.lora_bias_stacked[i][index, + 0, :lora_bias_i.shape[0]].copy_( + lora_bias_i.T, + non_blocking=True) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 2) + + +class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ + ColumnParallelLinear layer that is specifically designed for + qkv_proj. Certain models, such as chatglm3 and baichuan-7b, + only contains a single LoRA within their qkv_proj layer. + + During inference with Tensor Parallel, the weights of lora_b + must be accurately partitioned according to the respective ranks. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + self.q_proj_total_size = (self.base_layer.total_num_heads * + self.base_layer.head_size) + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * + self.base_layer.head_size) + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + lora_b_q = lora_b[:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + lora_b_k = lora_b[:, k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + lora_b_v = lora_b[:, v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + bias_q = bias[self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + bias_k = bias[k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + bias_v = bias[v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + bias = torch.cat([bias_q, bias_k, bias_v], dim=1) + return bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: list, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is QKVParallelLinear and len( + packed_modules_list) == 1 + + +class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): + """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) + packed together in qkv proj fashion + (q_proj + k_proj + v_proj -> qkv_proj). + + This means we have 3 LoRAs, each applied to one slice of the layer. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + # There are three LoRA layer. + self.n_slices = len(self.base_layer.output_sizes) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas + + self.output_slices = ( + self.q_proj_shard_size, + self.kv_proj_shard_size, + self.kv_proj_shard_size, + ) + self.output_ids = ( + self.q_shard_id, + self.kv_shard_id, + self.kv_shard_id, + ) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overloading this function is to handle inconsistent + weight dimensions in qkv lora. + """ + super().create_lora_weights(max_loras, lora_config, model_config) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is QKVParallelLinear + and len(packed_modules_list) == 3) + + +# These following layers are based on the tensor parallelism strategy given in +# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, +# https://arxiv.org/abs/2311.03285. + + +class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): + """ + Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`, + # their `lora_a` and `lora_b` have different sharding patterns. After + # completing the `lora_a` GEMM , a gather operation is performed. + # Therefore, the sharding of `lora_a` only needs to correspond with the + # gather operation. + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked[0].shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedColumnParallelLinearWithShardedLoRA( + MergedColumnParallelLinearWithLoRA): + """ + Differs from MergedColumnParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + #NOTE: lora_a contains 2 subloras, and each sublora could be None. + output_shard_size = self.lora_a_stacked[0].shape[2] + output_start_idx = self.tp_rank * output_shard_size + lora_a = [ + lora_a[0][:, output_start_idx:output_start_idx + + output_shard_size] if lora_a[0] is not None else None, + lora_a[1][:, output_start_idx:output_start_idx + + output_shard_size] if lora_a[1] is not None else None, + ] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): + """ + Differs from QKVParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked[0].shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: list, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): + """ + Differs from MergedQKVParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + # NOTE: lora_a contains 3 subloras, and each sublora could be None. + shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] + start_idx = [self.tp_rank * shard_size[i] for i in range(3)] + lora_a = [ + lora_a[0][:, start_idx[0]:start_idx[0] + + shard_size[0]] if lora_a[0] is not None else None, + lora_a[1][:, start_idx[1]:start_idx[1] + + shard_size[1]] if lora_a[1] is not None else None, + lora_a[2][:, start_idx[2]:start_idx[2] + + shard_size[2]] if lora_a[2] is not None else None, + ] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py new file mode 100644 index 0000000000000..db974147ccca7 --- /dev/null +++ b/vllm/lora/layers/logits_processor.py @@ -0,0 +1,247 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.platforms import current_platform + +from .base import BaseLayerWithLoRA + + +class LogitsProcessorWithLoRA(BaseLayerWithLoRA): + """ + LoRA wrapper for LogitsProcessor, with extra logic to handle the + application of the LoRA adapter and added LoRA vocabulary. + + Args: + base_layer: LogitsProcessor layer + hidden_size: hidden size of the model + dtype: data type of the model + device: device of the model + sharded_to_full_mapping: index mapping from sharded vocab to full vocab + received from base_layer.get_sharded_to_full_mapping(). If None, + no reindexing will be done. + """ + + def __init__(self, base_layer: LogitsProcessor, hidden_size: int, + dtype: torch.dtype, device: torch.device, + sharded_to_full_mapping: Optional[list[int]]) -> None: + super().__init__() + self.base_layer = base_layer + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.sharded_to_full_mapping = sharded_to_full_mapping + + @property + def logits_as_input(self): + return self.base_layer.logits_as_input + + @property + def vocab_size(self): + return self.base_layer.vocab_size + + @property + def scale(self): + return self.base_layer.scale + + @property + def soft_cap(self): + return self.base_layer.soft_cap + + @property + def use_all_gather(self): + return self.base_layer.use_all_gather + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def include_gpu_probs_tensor(self): + return self.base_layer.include_gpu_probs_tensor + + @property + def should_modify_greedy_probs_inplace(self): + return self.base_layer.should_modify_greedy_probs_inplace + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + # TODO: Verify if this condition can be further relaxed + if 32000 < self.base_layer.vocab_size > 257024: + raise ValueError("When using LoRA, vocab size must be " + "32000 >= vocab_size <= 257024") + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + # Pad for kernel compatibility + math.ceil(self.base_layer.vocab_size / + lora_config.lora_vocab_padding_size) * + lora_config.lora_vocab_padding_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.embeddings_tensors = torch.full( + (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), + fill_value=float("-inf"), + dtype=self.dtype, + device=self.device, + ) + if self.sharded_to_full_mapping is not None: + self.sharded_to_full_mapping_gpu = torch.tensor( + self.sharded_to_full_mapping, + device=self.device, + dtype=torch.long) + else: + self.sharded_to_full_mapping_gpu = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = float("-inf") + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, + :embeddings_tensor.shape[0], + :embeddings_tensor.shape[1], + ] = embeddings_tensor + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + # Get the logits for the next tokens. + logits = lm_head.quant_method.apply(lm_head, hidden_states) + if embedding_bias is not None: + logits += embedding_bias + + # Gather logits for TP + logits = self.base_layer._gather_logits(logits) + + if logits is None: + return None + + if self.sharded_to_full_mapping_gpu is not None: + # Reindex full logits tensor to ensure 1:1 mapping between + # index and token_id + # Example for: + # org_vocab_size = 4 + # added_vocab_size = 2 + # pad_to_size = 8 + # tp_size = 2 + + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 4, -1, 2, 3, 5, -1] + + # Therefore, the mapping is expected to be: + # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, + # we get: + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 2, 3, 4, 5, -1, -1] + logits = logits[:, self.sharded_to_full_mapping_gpu] + + lora_logits = torch.empty( + self.embeddings_tensors.shape[0] + 1, + self.embeddings_tensors.shape[1], + hidden_states.shape[0], + dtype=self.embeddings_tensors.dtype, + device=self.embeddings_tensors.device, + ) + torch.matmul(self.embeddings_tensors, + hidden_states.T, + out=lora_logits[:-1]) + + neg_inf, pos_inf = current_platform.get_infinity_values( + lora_logits.dtype) + + lora_logits[-1] = neg_inf + lora_logits = lora_logits.mT + indices_padded = self.punica_wrapper.sampler_indices_padded + + if current_platform.is_tpu() or current_platform.is_xpu(): + indices_padded = indices_padded[:logits.size(0)] + + lora_logits = (lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, + posinf=pos_inf, + neginf=neg_inf)) + + logits[:, + self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + + lora_logits.shape[1]] = lora_logits + + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_logits( + logits, hidden_states, self.lora_a_stacked, + self.lora_b_stacked, 1.0) + + if not current_platform.can_update_inplace(): + logits = lora_output + + # Remove paddings in vocab (if any). + logits = logits[:, :self.base_layer.vocab_size] + return logits + + def forward(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # Special handling for the LogitsProcessor. + return False diff --git a/vllm/lora/layers/qkv_x_parallel_linear.py b/vllm/lora/layers/qkv_x_parallel_linear.py new file mode 100644 index 0000000000000..367482d0ee078 --- /dev/null +++ b/vllm/lora/layers/qkv_x_parallel_linear.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from .base import BaseLayerWithLoRA + + +#TODO: Implement this +class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA): + pass diff --git a/vllm/lora/layers/replicated_linear.py b/vllm/lora/layers/replicated_linear.py new file mode 100644 index 0000000000000..db922a02d40b0 --- /dev/null +++ b/vllm/lora/layers/replicated_linear.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.model_executor.layers.linear import ReplicatedLinear + +from .base_linear import BaseLinearLayerWithLoRA + + +class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): + + def __init__(self, base_layer: ReplicatedLinear) -> None: + super().__init__(base_layer, ) + # To ensure interface compatibility, set to 1 always. + self.tp_size = 1 + self.output_size = self.base_layer.output_size + self.n_slices = 1 + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of ReplicatedLinearWithLoRA + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output = self.apply(input_, bias) + + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + + if not self.base_layer.return_bias: + return output + + return output, output_bias + + # ReplicatedLinear should always be replaced, regardless of the fully + # sharded LoRAs setting, because it is, by definition, copied per GPU. + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ReplicatedLinear diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py new file mode 100644 index 0000000000000..bf1d9ae374f48 --- /dev/null +++ b/vllm/lora/layers/row_parallel_linear.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union, cast + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_reduce) +# yapf: disable +from vllm.model_executor.layers.linear import RowParallelLinear +from vllm.platforms import current_platform + +from .base_linear import BaseLinearLayerWithLoRA +from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace + + +class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + + def __init__(self, base_layer: RowParallelLinear) -> None: + super().__init__(base_layer) + + self.tp_size = get_tensor_model_parallel_world_size() + # reset input_size + self.input_size = self.base_layer.input_size_per_partition + self.output_size = self.base_layer.output_size + + self.tp_rank = get_tensor_model_parallel_rank() + # There is only one LoRA layer. + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + + shard_size = self.input_size + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + return bias + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + + # Matrix multiply. + output_parallel = self.apply(input_parallel) + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = (output_ + self.base_layer.bias + if self.base_layer.bias is not None else output_) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + + if not self.base_layer.return_bias: + return output + + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is RowParallelLinear + + + +# The following layer is based on the tensor parallelism strategy given in +# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, +# https://arxiv.org/abs/2311.03285. + +class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): + """ + Differs from RowParallelLinearWithLoRA by slicing the + LoRA B's also. + + Based on S-LoRA, slicing happens along the output dim. + This yields a combined partial sum from the row parallel base + layer and column partitioned output from the LoRA. + """ + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + shard_size = self.lora_b_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + if bias is None: + return bias + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], + self.lora_bias_stacked) + shard_size = self.lora_bias_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + bias = bias[start_idx:end_idx] + return bias + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros( + (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink( + buffer, x, self.lora_a_stacked, 1.0) + if not current_platform.can_update_inplace(): + buffer = shrunk_buffer + + buffer = tensor_model_parallel_all_reduce(buffer) + + # following S-LoRA, allows the fusing of all_gather and all_reduce + # by adding the column partitioned lora output to a slice of output + # tensor, which is a partial sum due to row parallel. All that + # remains is a standard all_reduce. User should be aware though that + # the output is not the same as a normal row_parallel, it should be + # reduced before being used + # NOTE offset are based on the rank. + shard_size = self.lora_b_stacked[0].shape[2] + offset_start = self.tp_rank * shard_size + lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand( + output, + buffer, + self.lora_b_stacked, + self.lora_bias_stacked, + self.output_slices, + offset_start=offset_start, + add_input=True, + ) + + if not current_platform.can_update_inplace(): + output = lora_output + + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) diff --git a/vllm/lora/layers/utils.py b/vllm/lora/layers/utils.py new file mode 100644 index 0000000000000..27dcd720fbdea --- /dev/null +++ b/vllm/lora/layers/utils.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from vllm.adapter_commons.layers import AdapterMapping + + +@dataclass +class LoRAMapping(AdapterMapping): + is_prefill: bool = False + + +def _get_lora_device(base_layer: nn.Module) -> torch.device: + # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 + """Returns the device for where to place the LoRA tensors.""" + # unquantizedLinear + if hasattr(base_layer, "weight"): + return base_layer.weight.device + # Compressed Tensor + elif hasattr(base_layer, "weight_packed"): + return base_layer.weight_packed.device + # GPTQ/AWQ + elif hasattr(base_layer, "qweight"): + return base_layer.qweight.device + # HQQ marlin + elif hasattr(base_layer, "W_q"): + return base_layer.W_q.device + else: + raise ValueError(f"Unsupported base layer: {base_layer}") + + +def _not_fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of not using fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + decorate = kwargs.pop("decorate") if "decorate" in kwargs else True + condition = (not kwargs["lora_config"].fully_sharded_loras + if decorate else True) + return can_replace(*args, **kwargs) and condition + + return dec + + +def _fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + return (can_replace(*args, **kwargs) + and kwargs["lora_config"].fully_sharded_loras) + + return dec diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py new file mode 100644 index 0000000000000..192e154fe56a6 --- /dev/null +++ b/vllm/lora/layers/vocal_parallel_embedding.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.platforms import current_platform + +from .base import BaseLayerWithLoRA + + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + self.embeddings_slice: Optional[tuple[int, int]] + self.embeddings_weights: Optional[torch.Tensor] + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + + if self.base_layer.num_added_embeddings_per_partition > 0: + # We can start adding lora weights + self.embeddings_weights = self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:self. + base_layer.num_org_embeddings_per_partition + + self.base_layer.num_added_embeddings_per_partition] + self.embeddings_slice = ( + self.base_layer.shard_indices.added_vocab_start_index - + self.base_layer.org_vocab_size, + self.base_layer.shard_indices.added_vocab_end_index - + self.base_layer.org_vocab_size) + self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:].fill_(0) + else: + self.embeddings_slice = None + self.embeddings_weights = None + + self.embeddings_tensors = torch.zeros( + ( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.base_layer.weight.dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.org_vocab_size + + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.embedding_dim, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked_2d = self.lora_a_stacked.view( + self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], + self.lora_a_stacked.shape[2], + ) + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, + :embeddings_tensor.shape[0], + :embeddings_tensor.shape[1], + ].copy_(embeddings_tensor, non_blocking=True) + if self.embeddings_slice is not None: + # TODO(yard1): Optimize this copy, we don't need to copy + # everything, just the modified part + embeddings = self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2], + )[self.embeddings_slice[0]:self.embeddings_slice[1]] + assert self.embeddings_weights is not None + self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, + 1, 0) + + # NB: Don't use torch.narrow here. torch.narrow triggers some + # Dynamic Shape specialization in torch.compile + num_tokens = x.shape[0] + indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens] + indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens] + + full_lora_a_embeddings = F.embedding( + x + indices_1, + self.lora_a_stacked_2d, + ) + full_output = self.base_layer.forward(x + + (indices_0 * added_tokens_mask)) + + full_output_org = full_output + if full_output.ndim == 3: + full_output = full_output.view( + full_output.shape[0] * full_output.shape[1], -1) + if full_lora_a_embeddings.ndim == 3: + full_lora_a_embeddings = full_lora_a_embeddings.view( + full_lora_a_embeddings.shape[0] * + full_lora_a_embeddings.shape[1], + -1, + ) + + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_embedding( + full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) + + if not current_platform.can_update_inplace(): + full_output = lora_output + + return full_output.view_as(full_output_org) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is VocabParallelEmbedding + + @property + def weight(self): + return self.base_layer.weight diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 1fc214c12b5d1..2b05a2cf4d40c 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -13,21 +13,21 @@ from transformers import PretrainedConfig from vllm.config import LoRAConfig from vllm.logger import init_logger -from vllm.lora.fully_sharded_layers import ( - ColumnParallelLinearWithShardedLoRA, - MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, - RowParallelLinearWithShardedLoRA) # being imported for _all_lora_classes below # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + ColumnParallelLinearWithShardedLoRA, LogitsProcessorWithLoRA, MergedColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithLoRA, + QKVParallelLinearWithShardedLoRA, ReplicatedLinearWithLoRA, RowParallelLinearWithLoRA, + RowParallelLinearWithShardedLoRA, VocabParallelEmbeddingWithLoRA) from vllm.model_executor.layers.linear import LinearBase