import logging
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import torch
from torch import nn

from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
                                          AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                        get_adapter, list_adapters,
                                        remove_adapter, set_adapter_mapping)
from vllm.config import PromptAdapterConfig
from vllm.prompt_adapter.layers import (
    VocabParallelEmbeddingWithPromptAdapter)  # yapf: disable
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.utils import load_peft_weights

logger = logging.getLogger(__name__)

_GLOBAL_PROMPT_ADAPTER_ID = 0


def get_prompt_adapter_id():
    global _GLOBAL_PROMPT_ADAPTER_ID
    _GLOBAL_PROMPT_ADAPTER_ID += 1
    return _GLOBAL_PROMPT_ADAPTER_ID
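

# Note: get_prompt_adapter_id() hands out ids starting at 1; convert_mapping
# below treats non-positive adapter ids in PromptAdapterMapping.index_mapping
# as "no prompt adapter".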


def convert_to_embedding_indices(indices):
    embedding_indices = []
    count = 0

    for value in indices:
        if value == -1:
            count = 0
        else:
            embedding_indices.append([value, count])
            count += 1

    return torch.tensor(embedding_indices)
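

# Worked example for convert_to_embedding_indices (values are illustrative):
# -1 means "no adapter" and resets the position counter, so
#
#     convert_to_embedding_indices([-1, 1, 1, 1, -1, 3, 3])
#
# returns tensor([[1, 0], [1, 1], [1, 2], [3, 0], [3, 1]]), i.e. pairs of
# (adapter slot index, position within that adapter's virtual tokens).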


def convert_mapping(
    mapping: PromptAdapterMapping,
    prompt_adapter_index_to_id: List[Optional[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Converts PromptAdapterMapping to index tensors.

    Args:
        mapping: PromptAdapterMapping mapping rows in a
            batch to PromptAdapter ids.
        prompt_adapter_index_to_id: List mapping PromptAdapter
            slot indices to PromptAdapter ids.

    Returns:
        pa_indices: Tensor of shape [batch_size] mapping batch rows to
            PromptAdapter slot indices (-1 for rows without an adapter).
        pa_embedding_mapping: Tensor of (slot index, virtual token position)
            pairs for the rows that use a PromptAdapter.
    """
    id_to_index = {
        id_: idx
        for idx, id_ in enumerate(prompt_adapter_index_to_id)
        if id_ is not None
    }
    pa_indices = [
        id_to_index.get(id_, -1) if id_ > 0 else -1
        for id_ in mapping.index_mapping
    ]

    pa_embedding_mapping = convert_to_embedding_indices(pa_indices)
    pa_indices = torch.tensor(pa_indices)
    return pa_indices, pa_embedding_mapping
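

# Worked example for convert_mapping (values are illustrative): with
# prompt_adapter_index_to_id == [None, 102, None, 101] (slot -> adapter id)
# and mapping.index_mapping == (0, 102, 102, 0, 101, 101) (adapter id per row,
# 0 meaning "no adapter"), the result is
#
#     pa_indices           == tensor([-1, 1, 1, -1, 3, 3])
#     pa_embedding_mapping == tensor([[1, 0], [1, 1], [3, 0], [3, 1]])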


class PromptAdapterModel(AdapterModel):

    def __init__(self,
                 prompt_adapter_id=None,
                 num_virtual_tokens=None,
                 prompt_embedding=None) -> None:
        self.id = prompt_adapter_id
        self.prompt_embedding = prompt_embedding
        self.num_virtual_tokens = num_virtual_tokens

    @classmethod
    def from_local_checkpoint(
        cls,
        adapter_model_path: str,
        prompt_adapter_id: int,
        num_virtual_tokens: int,
        config: PromptAdapterConfig,
        device: str = "cuda",
    ) -> "PromptAdapterModel":

        if num_virtual_tokens > config.max_prompt_adapter_token:
            raise ValueError(
                f'num_virtual_tokens ({num_virtual_tokens}) should be <= '
                f'max_prompt_adapter_token ({config.max_prompt_adapter_token})')

        adapters_weights = load_peft_weights(adapter_model_path, device)
        prompt_embedding = adapters_weights["prompt_embeddings"].to(
            config.prompt_adapter_dtype)

        return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding)
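

# Example (illustrative sketch; the checkpoint path and numbers are made up,
# and `pa_config` is assumed to be an existing PromptAdapterConfig):
#
#     adapter = PromptAdapterModel.from_local_checkpoint(
#         "/path/to/peft/prompt_adapter",
#         prompt_adapter_id=get_prompt_adapter_id(),
#         num_virtual_tokens=8,
#         config=pa_config,
#     )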


class PromptAdapterModelManager(AdapterModelManager):
    """A manager that manages multiple Prompt Adapter models."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
    ):
        """Create a PromptAdapterModelManager for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences the model can run
                in a single batch.
            max_num_batched_tokens: the maximum number of tokens the model
                can run in a single batch.
            prompt_adapter_config: the PromptAdapter config.
        """
        self.model: nn.Module = model
        # The config must be set before prompt_adapter_slots is read below.
        self.prompt_adapter_config = prompt_adapter_config
        # Maps slot index -> id of the PromptAdapter occupying that slot
        # (None means the slot is free).
        self.prompt_adapter_index_to_id: List[
            Optional[int]] = [None] * self.prompt_adapter_slots
        self.max_num_seqs = max_num_seqs
        # Round up to the nearest multiple of 8.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.model.prompt_adapter_manager = self
        self.adapter_type = 'PromptAdapter'

        self.base_indices = torch.tensor([-1])
        self.base_embedding_indices = torch.tensor([])

        self.modules: Dict[str, nn.Module] = {}
        self._create_prompt_adapter_modules()
        self._last_mapping: Optional[PromptAdapterMapping] = None

    @property
    def prompt_adapter_slots(self) -> int:
        return self.prompt_adapter_config.max_prompt_adapters

    @property
    def adapter_slots(self) -> int:
        return self.prompt_adapter_slots

    @property
    def capacity(self) -> int:
        return self.prompt_adapter_config.max_cpu_prompt_adapters

    def activate_adapter(
        self,
        prompt_adapter_id: int,
    ) -> bool:
        """Move PromptAdapter into a GPU buffer
        to be used in the forward pass."""
        if prompt_adapter_id in self._active_adapters:
            return False
        first_free_slot = next(
            ((i, pa_id) for i, pa_id in enumerate(
                self.prompt_adapter_index_to_id) if pa_id is None), None)
        if first_free_slot is None:
            raise ValueError("No free prompt_adapter slots")
        index, _ = first_free_slot
        self._active_adapters[prompt_adapter_id] = None
        prompt_adapter_model = self._registered_adapters[prompt_adapter_id]
        logger.debug("Activating prompt_adapter. int id: %d, slot index: %d",
                     prompt_adapter_model.id, index)
        self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id
        for v in self.modules.values():
            v.set_prompt_adapter(index, prompt_adapter_model.prompt_embedding)
        return True

    def _deactivate_adapter(self, prompt_adapter_id: int):
        try:
            index = self.prompt_adapter_index_to_id.index(prompt_adapter_id)
            self.prompt_adapter_index_to_id[index] = None
            for v in self.modules.values():
                v.reset_prompt_adapter(index)
        except ValueError:
            pass

    def _add_adapter(self, prompt_adapter: PromptAdapterModel):
        self._registered_adapters[prompt_adapter.id] = prompt_adapter

    def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
        base_indices, base_embedding_indices = convert_mapping(
            mapping, self.prompt_adapter_index_to_id)
        for v in self.modules.values():
            v.set_mapping(base_indices, base_embedding_indices)

    def _create_prompt_adapter_modules(self):
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if "VocabParallel" in module.__class__.__name__:
                new_module = VocabParallelEmbeddingWithPromptAdapter(module)
                new_module.create_prompt_adapter_weights(
                    self.prompt_adapter_config)
                replaced_module = self.replace_submodule(
                    self.model, module_name, new_module)
                self.register_module(module.__class__.__name__,
                                     replaced_module)
                replaced_module.set_mapping(self.base_indices,
                                            self.base_embedding_indices)
                # Only the first matching embedding module is wrapped.
                break

    def replace_submodule(self, model: nn.Module, module_name: str,
                          new_module: nn.Module) -> nn.Module:
        """Replace a submodule in a model with a new module."""
        parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
        target_name = module_name.split(".")[-1]
        setattr(parent, target_name, new_module)
        return new_module

    def register_module(self, module_name: str, module: nn.Module):
        self.modules[module_name] = module

    def pin_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a PromptAdapterModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in PromptAdapterModelManager. "
            "Use LRUCachePromptAdapterModelManager for pinning."
        )  # type: ignore

    def remove_all_adapters(self):
        """Remove all PromptAdapterModels from the manager."""
        self._registered_adapters.clear()
        self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots
        self._active_adapters.clear()

    def deactivate_adapter(self, adapter_id: int) -> bool:
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)

    def add_adapter(self, adapter: PromptAdapterModel) -> bool:
        return add_adapter(adapter, self._registered_adapters, self.capacity,
                           self._add_adapter)

    def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
        self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        return remove_adapter(adapter_id, self._registered_adapters,
                              self.deactivate_adapter)

    def list_adapters(self) -> Dict[int, Any]:
        return list_adapters(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        return get_adapter(adapter_id, self._registered_adapters)


class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]):

    def __init__(self, capacity: int,
                 deactivate_prompt_adapter_fn: Callable[[int], bool]):
        super().__init__(capacity, deactivate_prompt_adapter_fn)


class LRUCachePromptAdapterModelManager(PromptAdapterModelManager):
    """A model manager that manages multiple prompt adapters with an LRU
    cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
    ):
        self.prompt_adapter_config = prompt_adapter_config
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         prompt_adapter_config)
        self._registered_adapters = PromptAdapterLRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters = PromptAdapterLRUCache(
            self.prompt_adapter_slots, self._deactivate_adapter)

    def list_adapters(self) -> Dict[int, PromptAdapterModel]:
        """List all registered PromptAdapterModels."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool:
        """Add a PromptAdapterModel to the manager."""
        if prompt_adapter.id not in self._registered_adapters:
            self._add_adapter(prompt_adapter)
            was_added = True
        else:
            # We always touch to update the LRU cache order.
            self._registered_adapters.touch(prompt_adapter.id)
            was_added = False
        return was_added

    def activate_adapter(
        self,
        prompt_adapter_id: int,
    ) -> bool:
        if prompt_adapter_id not in self._active_adapters and len(
                self._active_adapters) >= self.prompt_adapter_slots:
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(prompt_adapter_id)
        # We always touch to update the LRU cache order.
        self._active_adapters.touch(prompt_adapter_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False
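
    # Eviction sketch (illustrative; assumes prompt_adapter_slots == 1 and
    # that adapters with ids 1 and 2 are already registered via add_adapter):
    #
    #     manager.activate_adapter(1)  # slot 0 holds adapter 1
    #     manager.activate_adapter(2)  # adapter 1 is evicted (LRU); slot 0
    #                                  # now holds adapter 2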

    def pin_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a PromptAdapterModel in the manager cache."""
        self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id)
        self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id)
        return True

    def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int):
        try:
            self._registered_adapters.pin(prompt_adapter_id)
        except ValueError as err:
            raise ValueError(
                "Pinning failed. "
                f"Prompt Adapter {prompt_adapter_id} is not registered."
            ) from err

    def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int):
        if prompt_adapter_id not in self._active_adapters:
            # move adapter to gpu if not already active
            self.activate_adapter(prompt_adapter_id)
        self._active_adapters.pin(prompt_adapter_id)


def create_prompt_adapter_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
        prompt_adapter_manager_cls: Type[
            PromptAdapterModelManager] = PromptAdapterModelManager,
        **kwargs) -> PromptAdapterModelManager:
    """Create a PromptAdapterModelManager for a given model."""
    prompt_adapter_manager = prompt_adapter_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        prompt_adapter_config=prompt_adapter_config,
        **kwargs)
    return prompt_adapter_manager
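

# Example wiring (illustrative sketch; `model`, `pa_config`, and `adapter` are
# assumed to already exist, and the batch limits are arbitrary):
#
#     manager = create_prompt_adapter_manager(
#         model,
#         max_num_seqs=256,
#         max_num_batched_tokens=2048,
#         prompt_adapter_config=pa_config,
#         prompt_adapter_manager_cls=LRUCachePromptAdapterModelManager,
#     )
#     manager.add_adapter(adapter)          # register a PromptAdapterModel
#     manager.activate_adapter(adapter.id)  # copy its embedding into a slot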