# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.platforms import current_platform

from .base import BaseLayerWithLoRA


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
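    """VocabParallelEmbedding layer augmented with LoRA support.

    The wrapped base layer serves the original vocabulary; per-adapter
    state lives in stacked buffers indexed by LoRA slot: embeddings for
    the extra (added) vocabulary plus the LoRA A and B matrices.
    """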

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.embeddings_slice: Optional[tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
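        """Allocate the stacked per-slot LoRA buffers.

        `embeddings_tensors` holds the extra-vocab embedding rows for each
        adapter, `lora_a_stacked` the A matrices over the full (original +
        extra) vocabulary, and `lora_b_stacked` the B matrices. If this
        partition owns added-vocab rows of the base weight, a view of that
        region is kept in `embeddings_weights` so `set_lora` can refresh it.
        """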
        if self.base_layer.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            self.embeddings_weights = self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:self.
                base_layer.num_org_embeddings_per_partition +
                self.base_layer.num_added_embeddings_per_partition]
            self.embeddings_slice = (
                self.base_layer.shard_indices.added_vocab_start_index -
                self.base_layer.org_vocab_size,
                self.base_layer.shard_indices.added_vocab_end_index -
                self.base_layer.org_vocab_size)
            self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

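        # Stacked buffers, one slot per LoRA adapter:
        #   embeddings_tensors: (max_loras, extra_vocab, embedding_dim)
        #   lora_a_stacked:     (max_loras, org_vocab + extra_vocab, rank)
        #   lora_b_stacked:     (max_loras, 1, embedding_dim, rank)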
        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
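        # Flatten (max_loras, vocab) into a single dimension so forward() can
        # gather LoRA-A rows for all adapters with one F.embedding call.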
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
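        """Copy one adapter's weights into slot `index`.

        The slot is zeroed first; `lora_a` is transposed into the row-major
        stacked buffer, and any extra-vocab embeddings are also mirrored into
        the base layer's added-vocab weight region owned by this partition.
        """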
        self.reset_lora(index)
        # NOTE self.lora_a_stacked is row-major and lora_a is col-major,
        # so we need to transpose here
        self.lora_a_stacked[index, :lora_a.shape[1], :lora_a.shape[0]].copy_(
            lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
                                lora_b, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index,
                :embeddings_tensor.shape[0],
                :embeddings_tensor.shape[1],
            ].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2],
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                assert self.embeddings_weights is not None
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
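        """Embed `x` and add each token's per-adapter LoRA contribution.

        The base layer handles the original vocabulary (added-vocab token ids
        are remapped onto the rows appended to its weight), while the LoRA
        path gathers A-rows from the flattened stacked buffer and lets the
        punica wrapper apply B and accumulate into the output.
        """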
        added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
                                        1, 0)

        # NB: Don't use torch.narrow here. torch.narrow triggers some
        # Dynamic Shape specialization in torch.compile
        num_tokens = x.shape[0]
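        # Per-token index offsets maintained by the punica wrapper: row 1
        # shifts ids into the flattened lora_a_stacked_2d buffer, row 0
        # remaps added-vocab ids onto the base layer's appended rows.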
        indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]
        indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens]

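        # LoRA-A rows are gathered for every token; for the base lookup only
        # added-vocab ids are shifted (added_tokens_mask is 0 elsewhere).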
        full_lora_a_embeddings = F.embedding(
            x + indices_1,
            self.lora_a_stacked_2d,
        )
        full_output = self.base_layer.forward(x +
                                              (indices_0 * added_tokens_mask))

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1],
                -1,
            )

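        # Apply LoRA B and accumulate into full_output (in place on platforms
        # that support it; otherwise use the tensor returned by the wrapper).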
        lora_output: Optional[
            torch.Tensor] = self.punica_wrapper.add_lora_embedding(
                full_output,
                full_lora_a_embeddings,
                self.lora_b_stacked,
                add_input=True)

        if not current_platform.can_update_inplace():
            full_output = lora_output

        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
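        # Exact type match so that subclasses of VocabParallelEmbedding are
        # not claimed by this wrapper.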
        return type(source_layer) is VocabParallelEmbedding

    @property
    def weight(self):
        return self.base_layer.weight