# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING, Optional, Union

import torch

if TYPE_CHECKING:
    # avoid circular import
    from vllm.lora.layers import LoRAMapping


def compute_meta(
    token_lora_tensor: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
    """
    Get the information required for the sgmv kernel, with two features:
    1. If consecutive requests in the batch use the same LoRA, this function
       combines them into a single request, improving sgmv kernel inference
       performance.
    2. At the beginning of each prefill stage, the metadata is recalculated
       from the input, but only once per step.
    A worked toy example is sketched in the comment block after this function.
    """
    lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
        token_lora_tensor, return_counts=True)
    cum_result = torch.cumsum(seq_length_tensor, dim=0)
    # Exclusive prefix sum: start offset of each LoRA segment in the
    # flattened token stream.
    b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
    b_seq_start_tensor[1:].copy_(cum_result[:-1])
    max_length = seq_length_tensor.max().item()
    token_nums = seq_length_tensor.sum().item()
    batch_size = lora_indices_tensor.size(0)
    no_lora = False
    # -1 means no LoRA should be applied. Use `no_lora` to determine whether
    # the current step requires LoRA. If LoRA is not needed, the prefill stage
    # does not need to launch the triton kernel, which can improve performance.
    if batch_size == 1 and lora_indices_tensor == -1:
        no_lora = True
    return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
            batch_size, max_length, token_nums, no_lora)
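

# Worked toy example for compute_meta (illustrative sketch, not executed):
# for token_lora_tensor == torch.tensor([1, 1, 2, 2, 2, -1]) there are three
# consecutive LoRA segments, so the function returns
#   lora_indices_tensor == tensor([1, 2, -1])
#   seq_length_tensor   == tensor([2, 3, 1])
#   b_seq_start_tensor  == tensor([0, 2, 5])
#   batch_size == 3, max_length == 3, token_nums == 6, no_lora == False
# Only a batch made of a single segment with index -1 (e.g. tensor([-1, -1]))
# sets no_lora to True, letting the caller skip the triton kernel launch.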


# TODO see if this can be vectorized
def convert_mapping(
    mapping: "LoRAMapping",
    lora_index_to_id: list[Optional[int]],
    max_loras: int,
    vocab_size: int,
    extra_vocab_size: int,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[int]]:
"""Converts LoRAMapping to index tensors.
|
|
|
|
Args:
|
|
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
|
|
lora_index_to_id: List mapping LoRA ids to LoRA indices.
|
|
max_loras: Maximum number of LoRAs.
|
|
vocab_size: Model vocab size.
|
|
extra_vocab_size: Extra vocab size each LoRA can have.
|
|
|
|
Returns:
|
|
A tuple of tensors:
|
|
base_indices: Tensor of shape [batch_size] mapping batch rows to
|
|
LoRA indices.
|
|
sampler_indices: Tensor of shape [batch_size] mapping requests to
|
|
LoRA indices for sampler. For generation, this will be the
|
|
same as base_indices. For prefill, this will map requests
|
|
to LoRA indices.
|
|
sampler_indices_padded: Tensor of shape [batch_size] mapping
|
|
requests to LoRA indices for sampler with padding.
|
|
Same as sampler_indices, but -1 is replaced with
|
|
max_loras.
|
|
embeddings_indices: Tensor of shape [2, batch_size] mapping
|
|
requests to embedding indices. First row is for embeddings
|
|
added by the LoRAs, second row is for the LoRA.lora_a
|
|
embeddings.
|
|
indices_len: List of lengths of the above tensors. It contains
|
|
(base_indices, sampler_indices, sampler_indices_padded,
|
|
embeddings_indices).
|
|
"""
|
|
    index_mapping_indices: list[int] = list(mapping.index_mapping).copy()
    embedding_indices = index_mapping_indices.copy()
    lora_indices = index_mapping_indices.copy()

    prompt_mapping: list[int] = [
        lora_index_to_id.index(x) if x > 0 else -1
        for x in mapping.prompt_mapping
    ]
    lora_idx = None
    # Translate each token's LoRA id into its adapter slot index; ids <= 0
    # mean "no LoRA" (-1 in lora_indices, 0 in embedding_indices).
    for i in range(len(index_mapping_indices)):
        # TODO index can be slow. optimize
        lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
                    if index_mapping_indices[i] > 0 else -1)
        embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
        lora_indices[i] = lora_idx

    indices_list: list[Union[list[int], torch.Tensor]] = [
        index_mapping_indices,
        lora_indices,
        embedding_indices,
    ]
    indices = torch.tensor(indices_list, dtype=torch.long, device=device)
    prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                         dtype=torch.long,
                                         device=device)
    # First row: per-adapter offsets for embeddings added by the LoRAs;
    # second row: per-adapter offsets for the LoRA lora_a embeddings.
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size),
    ])
    embeddings_indices = torch.where(embeddings_indices == -1, max_loras - 1,
                                     embeddings_indices)
    base_indices = indices[1]
    sampler_indices = prompt_mapping_tensor
    sampler_indices_padded = sampler_indices.clone()
    sampler_indices_padded = torch.where(sampler_indices_padded == -1,
                                         max_loras - 1, sampler_indices_padded)
    # Convert padded slot indices into flattened offsets of the form
    # slot_index * batch_size + batch_position.
    sampler_indices_padded = torch.arange(
        0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
            sampler_indices_padded * len(sampler_indices_padded))
    # Contains the length of each indices tensor. Used to index into each
    # tensor.
    indices_len = [
        base_indices.shape[-1],
        sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1],
        embeddings_indices.shape[-1],
    ]

    return (
        base_indices,
        sampler_indices,
        sampler_indices_padded,
        embeddings_indices,
        indices_len,
    )
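

# Illustrative usage sketch (a minimal demo, not part of the module's API).
# A SimpleNamespace stands in for LoRAMapping below, since convert_mapping
# only reads the `index_mapping` and `prompt_mapping` attributes; the vocab
# sizes are arbitrary demo values.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_device = torch.device("cpu")

    # Toy batch: tokens 0-1 use LoRA id 1, tokens 2-3 use LoRA id 2, and the
    # last token uses no LoRA (id 0); requests map to ids (1, 2, 0).
    toy_mapping = SimpleNamespace(index_mapping=(1, 1, 2, 2, 0),
                                  prompt_mapping=(1, 2, 0))
    outputs = convert_mapping(
        toy_mapping,  # type: ignore[arg-type]
        # Slot 0 is empty; LoRA ids 1 and 2 occupy slots 1 and 2.
        lora_index_to_id=[None, 1, 2],
        max_loras=3,
        vocab_size=32000,
        extra_vocab_size=256,
        device=demo_device,
    )
    # Expected: base_indices [1, 1, 2, 2, -1], sampler_indices [1, 2, -1],
    # sampler_indices_padded [3, 7, 8], indices_len [5, 3, 3, 5].
    names = ("base_indices", "sampler_indices", "sampler_indices_padded",
             "embeddings_indices", "indices_len")
    for name, value in zip(names, outputs):
        print(name, value)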