# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from typing import Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.platforms import current_platform

from .base import BaseLayerWithLoRA


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full
            vocab received from base_layer.get_sharded_to_full_mapping().
            If None, no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[list[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

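    # The properties below simply proxy attributes of the wrapped
    # LogitsProcessor so this wrapper can stand in for it transparently.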
    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_all_gather(self):
        return self.base_layer.use_all_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        # TODO: Verify if this condition can be further relaxed
        if not (32000 <= self.base_layer.vocab_size <= 257024):
            raise ValueError("When using LoRA, vocab size must be "
                             "32000 <= vocab_size <= 257024")
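        # One slot per active LoRA: lora_a_stacked[i] is
        # (1, max_lora_rank, hidden_size) and lora_b_stacked[i] is
        # (1, padded_vocab_size, max_lora_rank); the vocab dimension is
        # padded for kernel compatibility.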
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
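        # Extra-vocabulary embeddings for each adapter slot, initialized to
        # -inf; set_lora() copies real embedding rows in and reset_lora()
        # restores the -inf fill.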
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
                                lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
                                lora_b, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index,
                :embeddings_tensor.shape[0],
                :embeddings_tensor.shape[1],
            ] = embeddings_tensor

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
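        # Overview: compute the base LM-head logits, gather/reindex them for
        # tensor parallelism, splice in logits for each adapter's added
        # vocabulary, apply the low-rank logits update, then drop any padding
        # columns beyond the real vocab size.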
        # Get the logits for the next tokens.
        logits = lm_head.quant_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias

        # Gather logits for TP
        logits = self.base_layer._gather_logits(logits)

        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            # org_vocab_size = 4
            # added_vocab_size = 2
            # pad_to_size = 8
            # tp_size = 2

            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Therefore, the mapping is expected to be:
            # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
            # we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

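        # lora_logits holds, per adapter slot, the logits over that adapter's
        # added vocabulary. The extra final row is a sentinel kept at -inf;
        # padded entries of sampler_indices_padded select it and therefore
        # yield -inf logits instead of real adapter logits.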
        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])

        neg_inf, pos_inf = current_platform.get_infinity_values(
            lora_logits.dtype)

        lora_logits[-1] = neg_inf
        lora_logits = lora_logits.mT
        indices_padded = self.punica_wrapper.sampler_indices_padded

        if current_platform.is_tpu() or current_platform.is_xpu():
            indices_padded = indices_padded[:logits.size(0)]

        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf,
                                                      posinf=pos_inf,
                                                      neginf=neg_inf))

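        # Write the per-token extra-vocab logits into the columns reserved
        # for the added LoRA vocabulary, right after the base model's vocab.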
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

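        # Add the low-rank LoRA contribution (lora_b @ lora_a @ hidden) to
        # the logits. On platforms that cannot update in place, the result is
        # returned instead, hence the fallback assignment below.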
        lora_output: Optional[
            torch.Tensor] = self.punica_wrapper.add_lora_logits(
                logits, hidden_states, self.lora_a_stacked,
                self.lora_b_stacked, 1.0)

        if not current_platform.can_update_inplace():
            logits = lora_output

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

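    # Note: forward() invokes the *unbound* LogitsProcessor.forward with this
    # wrapper as `self`, so when the base implementation calls _get_logits it
    # dispatches to the LoRA-aware override defined above.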
    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Special handling for the LogitsProcessor.
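        # This wrapper is wired in explicitly rather than picked up by the
        # generic layer-replacement pass, so it always declines here.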
        return False