vllm/vllm/lora/punica_wrapper/punica_base.py
Didier Durand 66d3d5422c
[Doc]: fixing typos in diverse files (#29492)
Signed-off-by: Didier Durand <durand.didier@gmail.com>
2025-11-27 07:15:50 -08:00

494 lines
15 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
import torch
from .utils import compute_meta, convert_mapping
if TYPE_CHECKING:
# avoid circular import
from vllm.lora.layers import LoRAMapping
class PunicaWrapperABC(ABC):
"""
PunicaWrapper ABC.
"""
@abstractmethod
def update_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
**kwargs,
) -> None:
"""
Update the lora-related metadata
"""
raise NotImplementedError
@abstractmethod
def add_shrink(
self,
y: tuple[torch.Tensor, ...] | torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
) -> torch.Tensor | None:
"""
Performs GEMM for multiple slices of lora_a.
"""
raise NotImplementedError
@abstractmethod
def add_expand(
self,
y: torch.Tensor,
x: tuple[torch.Tensor, ...] | torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> torch.Tensor | None:
"""
Performs GEMM for multiple slices of lora_b.
"""
raise NotImplementedError
@abstractmethod
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs,
) -> torch.Tensor | None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA,
and this layer only requires the expand operation.
"""
raise NotImplementedError
@abstractmethod
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
scale: float,
output_slices: tuple[int, ...],
*,
buffer: tuple[torch.Tensor, ...] | None = None,
**kwargs,
) -> torch.Tensor | None:
"""
Applicable to linear-related lora.
"""
raise NotImplementedError
@abstractmethod
def add_lora_logits(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
"""
raise NotImplementedError
class PunicaWrapperBase(PunicaWrapperABC):
    """
    PunicaWrapperBase is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the punica.

    The constructor preallocates fixed-capacity index buffers. Each call to
    ``update_metadata`` copies the current batch's indices into a prefix of
    those buffers, and the read-only properties return views of the valid
    prefix. The kernel entry points (``add_shrink``, ``add_expand``, ...)
    remain abstract and are provided by backend-specific subclasses.
    """

    def __init__(
        self,
        max_num_batched_tokens: int,
        max_batches: int,
        device: torch.device | str,
        **kwargs,
    ):
        """
        Preallocate all metadata buffers on ``device``.

        Args:
            max_num_batched_tokens (int): Capacity of the per-token index
                buffers.
            max_batches (int): Capacity of the per-sequence buffers used by
                the sgmv (prefill) kernels.
            device (torch.device | str): Device the buffers are allocated on.
        """
        self._token_lora_indices = torch.empty(
            max_num_batched_tokens, dtype=torch.long, device=device
        )
        self._sampler_indices = torch.empty(
            max_num_batched_tokens, dtype=torch.long, device=device
        )
        self._sampler_indices_padded = torch.empty(
            max_num_batched_tokens, dtype=torch.long, device=device
        )
        self._embeddings_indices = torch.empty(
            2, max_num_batched_tokens, dtype=torch.long, device=device
        )
        # Valid lengths of the 4 index tensors, in order: base_indices,
        # sampler_indices, sampler_indices_padded, embeddings_indices.
        # They stay None until update_metadata runs; slicing with None then
        # yields the whole (uninitialized) buffer.
        self.indices_len: list[int | None] = [None] * 4
        # These attributes are the information required for the sgmv kernel.
        self._seq_start_locs = torch.empty(max_batches, dtype=torch.long, device=device)
        self._seq_lengths = torch.empty(max_batches, dtype=torch.long, device=device)
        self._lora_indices_per_batch = torch.empty(
            max_batches, dtype=torch.long, device=device
        )
        # NOTE(review): `device` is stored as given, so it may be a str even
        # though the attribute is annotated as torch.device.
        self.device: torch.device = device
        self.max_length: int = 0
        self.token_nums: int = 0
        # Only set by _update_prefill_metadata; while it is still -1,
        # prefill_metadata slices with [:-1] and is not meaningful.
        self.batch_size: int = -1
        self.is_prefill = False
        self.no_lora = False

    def _update_base_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: list[int | None],
        max_loras: int,
        vocab_size: int,
    ) -> None:
        """
        Convert ``mapping`` with ``convert_mapping`` and copy the resulting
        index tensors into the prefixes of the preallocated buffers.

        Also records the valid length of each buffer in ``self.indices_len``.
        """
        # NOTE: LoRA extra-vocab support has been removed for now, so
        # extra_vocab_size is always 0; the argument itself will be removed.
        extra_vocab_size = 0
        (
            base_indices,
            sampler_indices,
            sampler_indices_padded,
            embeddings_indices,
            indices_len,
        ) = convert_mapping(
            mapping,
            lora_index_to_id,
            max_loras,
            vocab_size,
            extra_vocab_size,
            self.device,
        )
        # Copy into the existing buffers in place (rather than rebinding the
        # attributes), so the buffer tensors are reused across updates.
        self._token_lora_indices[: base_indices.shape[0]].copy_(base_indices)
        self._sampler_indices[: sampler_indices.shape[0]].copy_(sampler_indices)
        self._sampler_indices_padded[: sampler_indices_padded.shape[0]].copy_(
            sampler_indices_padded
        )
        self._embeddings_indices[
            : embeddings_indices.shape[0], : embeddings_indices.shape[1]
        ].copy_(embeddings_indices)
        self.indices_len[:] = indices_len

    def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None:
        """
        Compute the per-sequence sgmv metadata from the per-token LoRA indices
        (via ``compute_meta``) and store it on the wrapper.

        Args:
            token_lora_tensor (torch.Tensor): LoRA index of each token.
        """
        (
            b_seq_start_tensor,
            seq_length_tensor,
            lora_indices_tensor,
            batch_size,
            max_length,
            token_nums,
            no_lora,
        ) = compute_meta(token_lora_tensor)
        self._seq_start_locs[: b_seq_start_tensor.shape[0]].copy_(b_seq_start_tensor)
        self._seq_lengths[: seq_length_tensor.shape[0]].copy_(seq_length_tensor)
        self._lora_indices_per_batch[: lora_indices_tensor.shape[0]].copy_(
            lora_indices_tensor
        )
        self.batch_size = batch_size
        self.max_length = max_length
        self.token_nums = token_nums
        self.no_lora = no_lora

    @property
    def prefill_metadata(
        self,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
        """
        This property provides a convenient way to access the necessary
        metadata for prefill-related kernel computations.

        1. seq_start_locs: Tensor of sequence start positions.
        2. seq_lengths: Tensor of sequence lengths.
        3. lora_indices_per_batch: Tensor of lora indices, and an index of
            -1 means no lora should be applied.
        4. batch_size: Batch size after clustering identical lora indices.
        5. max_length: The maximum sequence length in the batch.
        6. token_nums: The token numbers in the batch.
        """
        return (
            self._seq_start_locs[: self.batch_size],
            self._seq_lengths[: self.batch_size],
            self._lora_indices_per_batch[: self.batch_size],
            self.batch_size,
            self.max_length,
            self.token_nums,
        )

    @property
    def token_lora_indices(self) -> torch.Tensor:
        """
        This property provides the lora indices corresponding to each token
        in the batch. An index of -1 means no lora should be applied.
        """
        token_lora_len = self.indices_len[0]
        return self._token_lora_indices[:token_lora_len]

    @property
    def sampler_indices(self) -> torch.Tensor:
        """
        This property is used to access the lora indices specifically for
        LogitsProcessorWithLoRA.
        """
        sampler_indices_len = self.indices_len[1]
        return self._sampler_indices[:sampler_indices_len]

    @property
    def sampler_indices_padded(self) -> torch.Tensor:
        """
        This property provides access to padded sampler indices.
        """
        indices_padded_len = self.indices_len[2]
        return self._sampler_indices_padded[:indices_padded_len]

    @property
    def embeddings_indices(self) -> torch.Tensor:
        """
        This property provides access to the indices used for lora embeddings,
        specifically for VocabParallelEmbeddingWithLoRA.
        """
        embeddings_indices_len = self.indices_len[3]
        return self._embeddings_indices[:, :embeddings_indices_len]

    def update_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: list[int | None],
        max_loras: int,
        vocab_size: int,
        **kwargs,
    ) -> None:
        """
        Refresh the base index buffers and, when ``mapping.is_prefill`` is
        set, also the prefill (sgmv) metadata.

        Args:
            mapping (LoRAMapping): LoRA mapping of the current batch.
            lora_index_to_id (list[int | None]): Mapping from LoRA index to
                LoRA id; entries may be None.
            max_loras (int): Maximum number of LoRAs.
            vocab_size (int): Vocabulary size.
        """
        self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)
        if mapping.is_prefill:
            # Update metadata required for prefill-related operators.
            self._update_prefill_metadata(self.token_lora_indices)
            self.is_prefill = True
        else:
            self.is_prefill = False

    @abstractmethod
    def add_shrink(
        self,
        y: tuple[torch.Tensor, ...] | torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        scale: float,
        **kwargs,
    ) -> torch.Tensor | None:
        """
        Performs GEMM for multiple slices of lora_a.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (x @ lora_a_stacked[i]) * scale

        Args:
            y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
            x (torch.Tensor): Input tensor
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
        """
        # TODO: implement it based on torch ops
        raise NotImplementedError

    @abstractmethod
    def add_expand(
        self,
        y: torch.Tensor,
        x: tuple[torch.Tensor, ...] | torch.Tensor,
        lora_b_stacked: tuple[torch.Tensor, ...],
        output_slices: tuple[int, ...],
        offset_start: int = 0,
        add_inputs: bool = True,
        **kwargs,
    ) -> torch.Tensor | None:
        """
        Performs GEMM for multiple slices of lora_b.

        Semantics:
            offset = offset_start
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor.
            x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
            output_slices (tuple[int, ...]): Every slice's size
            offset_start (int): The starting position of y, defaults to 0
            add_inputs (bool): Defaults to True.
        """
        # TODO: implement it based on torch ops
        raise NotImplementedError

    @abstractmethod
    def add_lora_embedding(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_b_stacked: torch.Tensor,
        add_inputs: bool = True,
        **kwargs,
    ) -> torch.Tensor | None:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA,
        and this layer only requires the expand operation.

        Semantics:
            y += x @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Defaults to True.
        """
        # TODO: implement it based on torch ops
        raise NotImplementedError

    @abstractmethod
    def add_lora_linear(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
        scale: float,
        output_slices: tuple[int, ...],
        *,
        buffer: tuple[torch.Tensor, ...] | None = None,
        **kwargs,
    ) -> torch.Tensor | None:
        """
        Applicable to linear-related lora.

        Semantics (conceptually a shrink as in ``add_shrink`` followed by an
        expand as in ``add_expand``; each token uses the LoRA selected by its
        entry in ``token_lora_indices``, where -1 means no LoRA is applied):
            offset = 0
            for i in range(len(lora_a_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += (
                    (x @ lora_a_stacked[i]) * scale
                ) @ lora_b_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
            scale (float): Scaling factor.
            output_slices (tuple[int, ...]): Every slice's size.
            buffer (Optional[tuple[torch.Tensor, ...]]): Optional buffer,
                presumably for the intermediate shrink result. Defaults to
                None.
        """
        # TODO: implement it based on torch ops
        raise NotImplementedError

    @abstractmethod
    def add_lora_logits(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: torch.Tensor,
        lora_b_stacked: torch.Tensor,
        scale: float,
        *,
        buffer: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | None:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.

        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]): Defaults to None.
        """
        # TODO: implement it based on torch ops
        raise NotImplementedError

    def moe_lora_align_block_size(
        self,
        topk_ids: torch.Tensor,
        num_tokens: int,
        block_size: int,
        num_experts: int,
        max_loras: int,
        adapter_enabled: torch.Tensor,
        expert_map: torch.Tensor | None = None,
        pad_sorted_ids: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Aligns tokens and experts into block-sized chunks for LoRA-based
        mixture-of-experts (MoE) execution.

        Not abstract, but not implemented in this base class: it always
        raises ``NotImplementedError`` unless a subclass overrides it.
        """
        # TODO: implement it based on torch ops
        raise NotImplementedError

    def add_lora_fused_moe(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
        topk_weights: torch.Tensor,
        sorted_token_ids: torch.Tensor,
        expert_ids: torch.Tensor,
        num_tokens_post_padded: torch.Tensor,
        max_lora_rank: int,
        top_k_num: int,
        shrink_config,
        expand_config,
        adapter_enabled: torch.Tensor,
        mul_routed_weight: bool = False,
        fully_sharded: bool = False,
        offset: int = 0,
    ):
        """
        Performs a fused forward computation for LoRA of
        Mixture-of-Experts (MoE) layer.

        Not abstract, but not implemented in this base class: it always
        raises ``NotImplementedError`` unless a subclass overrides it.
        """
        # TODO: implement it based on torch ops
        raise NotImplementedError