Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-12-20 01:47:11 +00:00
parent 463074fac8
commit d053aa73e1

View File

@ -2,9 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar
import regex as re
@ -47,12 +45,6 @@ T = TypeVar("T")
DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"
@dataclass(frozen=True)
class LoRATarget:
wrapper: PunicaWrapperBase
prefix: str
class AdapterLRUCache(LRUCache[int, T]):
def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
super().__init__(capacity)
@ -120,21 +112,6 @@ class LoRAModelManager:
def _init_punica_wrapper(
self, max_num_batched_tokens: int, vllm_config: VllmConfig
) -> None:
self.punica_wrapper_mapping: OrderedDict[str, PunicaWrapperBase] = OrderedDict()
llm_punica_wrapper = get_punica_wrapper(
max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device,
max_loras=self.lora_config.max_loras,
)
# NOTE This assumes the existence of a language model LoRA
self.punica_wrapper_mapping.setdefault(
DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper
)
self._maybe_init_mm(vllm_config)
def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
# Used to indicate whether the model is a multimodal model
self.supports_mm: bool = (
supports_multimodal(self.model)
@ -142,13 +119,39 @@ class LoRAModelManager:
# text modules (e.g. ChatGLM)
and hasattr(self.model, "get_mm_mapping")
)
if not self.supports_mm:
return
self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {}
if self.supports_mm:
self._maybe_init_mm(vllm_config,max_num_batched_tokens)
else:
llm_punica_wrapper = get_punica_wrapper(
max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device,
max_loras=self.lora_config.max_loras,
)
self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY] = (
llm_punica_wrapper
)
def _maybe_init_mm(self, vllm_config: VllmConfig, max_num_batched_tokens) -> None:
self.supports_tower_connector_lora = False
model_config: ModelConfig = vllm_config.model_config
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
# Only one language model can be included in the model.
assert len(self.mm_mapping.language_model) == 1
# Language model punica wrapper
llm_punica_wrapper = get_punica_wrapper(
max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device,
max_loras=self.lora_config.max_loras,
)
lm_prefix = self.mm_mapping.language_model[0]
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
if self.lora_config.enable_tower_connector_lora:
self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
self.supports_tower_connector_lora = self.supports_mm and hasattr(
@ -156,6 +159,7 @@ class LoRAModelManager:
)
if not self.supports_tower_connector_lora:
return
logger.warning(
"LoRA for the tower and connector of multimodal models is "
"experimental and may contain bugs. Please report any related issues on "
@ -172,20 +176,6 @@ class LoRAModelManager:
mm_budget.get_encoder_budget()
)
# Only one language model can be included in the model.
assert len(self.mm_mapping.language_model) == 1
# Update prefix of language model
lm_prefix = (
self.mm_mapping.language_model[0]
if self.supports_mm
else DEFAULT_LANGUAGE_WRAPPER_KEY
)
llm_punica_wrapper = self.punica_wrapper_mapping.pop(
DEFAULT_LANGUAGE_WRAPPER_KEY
)
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
# Tower wrappers
tower_punica_wrapper = get_punica_wrapper(
num_encoder_tokens,
@ -217,15 +207,6 @@ class LoRAModelManager:
"determine the connector's token budget for LoRA operations."
)
# Longest-prefix-first
self.punica_wrapper_mapping = OrderedDict(
sorted(
self.punica_wrapper_mapping.items(),
key=lambda x: len(x[0]),
reverse=True,
)
)
def __len__(self) -> int:
return len(self._registered_adapters)
@ -366,10 +347,10 @@ class LoRAModelManager:
else:
target_prefix = self.mm_mapping.language_model[0]
target = self._get_lora_target(target_prefix)
assert target is not None
punica_wrapper = self._get_punica_wrapper(target_prefix)
assert punica_wrapper is not None
target.wrapper.update_metadata(
punica_wrapper.wrapper.update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
@ -397,8 +378,8 @@ class LoRAModelManager:
if not self._match_target_modules(module_name):
continue
target = self._get_lora_target(module_name)
if target is None:
punica_wrapper = self._get_punica_wrapper(module_name)
if punica_wrapper is None:
logger.warning(
"Regarding %s, vLLM currently only supports adding LoRA to"
" language model, %s will be ignored.",
@ -464,7 +445,7 @@ class LoRAModelManager:
self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference.
new_module.set_mapping(target.wrapper)
new_module.set_mapping(punica_wrapper)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA), (
@ -485,7 +466,7 @@ class LoRAModelManager:
if (
not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA)
or self._get_lora_target(module_name) is None
or self._get_punica_wrapper(module_name) is None
):
continue
parts = module_name.split(".")
@ -574,23 +555,24 @@ class LoRAModelManager:
for target_module in self.supported_lora_modules
)
def _get_lora_target(self, module_name: str) -> LoRATarget | None:
def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
"""
Determine whether this module supports LoRA and which wrapper to use.
"""
# For language Model (early return)
if not self.supports_mm:
wrapper = list(self.punica_wrapper_mapping.values())[0]
return LoRATarget(wrapper=wrapper, prefix=DEFAULT_LANGUAGE_WRAPPER_KEY)
return self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY]
# For multimodal model
for prefix, wrapper in self.punica_wrapper_mapping.items():
is_language_model = (
prefix == DEFAULT_LANGUAGE_WRAPPER_KEY
and module_name.startswith(self.mm_mapping.language_model[0])
)
if is_language_model or module_name.startswith(prefix):
return LoRATarget(wrapper=wrapper, prefix=prefix)
# for prefix, wrapper in self.punica_wrapper_mapping.items():
# is_language_model = (
# prefix == DEFAULT_LANGUAGE_WRAPPER_KEY
# and module_name.startswith(self.mm_mapping.language_model[0])
# )
# if is_language_model or module_name.startswith(prefix):
# return LoRATarget(wrapper=wrapper, prefix=prefix)
return None