Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-12-20 01:47:11 +00:00
parent 463074fac8
commit d053aa73e1

View File

@ -2,9 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math import math
from collections import OrderedDict
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar from typing import TypeVar
import regex as re import regex as re
@ -47,12 +45,6 @@ T = TypeVar("T")
DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model" DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"
@dataclass(frozen=True)
class LoRATarget:
wrapper: PunicaWrapperBase
prefix: str
class AdapterLRUCache(LRUCache[int, T]): class AdapterLRUCache(LRUCache[int, T]):
def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]): def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
super().__init__(capacity) super().__init__(capacity)
@ -120,21 +112,6 @@ class LoRAModelManager:
def _init_punica_wrapper( def _init_punica_wrapper(
self, max_num_batched_tokens: int, vllm_config: VllmConfig self, max_num_batched_tokens: int, vllm_config: VllmConfig
) -> None: ) -> None:
self.punica_wrapper_mapping: OrderedDict[str, PunicaWrapperBase] = OrderedDict()
llm_punica_wrapper = get_punica_wrapper(
max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device,
max_loras=self.lora_config.max_loras,
)
# NOTE This assumes the existence of a language model LoRA
self.punica_wrapper_mapping.setdefault(
DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper
)
self._maybe_init_mm(vllm_config)
def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
# Used to indicate whether the model is a multimodal model # Used to indicate whether the model is a multimodal model
self.supports_mm: bool = ( self.supports_mm: bool = (
supports_multimodal(self.model) supports_multimodal(self.model)
@ -142,13 +119,39 @@ class LoRAModelManager:
# text modules (e.g. ChatGLM) # text modules (e.g. ChatGLM)
and hasattr(self.model, "get_mm_mapping") and hasattr(self.model, "get_mm_mapping")
) )
if not self.supports_mm: self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {}
return if self.supports_mm:
self._maybe_init_mm(vllm_config,max_num_batched_tokens)
else:
llm_punica_wrapper = get_punica_wrapper(
max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device,
max_loras=self.lora_config.max_loras,
)
self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY] = (
llm_punica_wrapper
)
def _maybe_init_mm(self, vllm_config: VllmConfig, max_num_batched_tokens) -> None:
self.supports_tower_connector_lora = False self.supports_tower_connector_lora = False
model_config: ModelConfig = vllm_config.model_config model_config: ModelConfig = vllm_config.model_config
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping() self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
# Only one language model can be included in the model.
assert len(self.mm_mapping.language_model) == 1
# Language model punica wrapper
llm_punica_wrapper = get_punica_wrapper(
max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device,
max_loras=self.lora_config.max_loras,
)
lm_prefix = self.mm_mapping.language_model[0]
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
if self.lora_config.enable_tower_connector_lora: if self.lora_config.enable_tower_connector_lora:
self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
self.supports_tower_connector_lora = self.supports_mm and hasattr( self.supports_tower_connector_lora = self.supports_mm and hasattr(
@ -156,6 +159,7 @@ class LoRAModelManager:
) )
if not self.supports_tower_connector_lora: if not self.supports_tower_connector_lora:
return return
logger.warning( logger.warning(
"LoRA for the tower and connector of multimodal models is " "LoRA for the tower and connector of multimodal models is "
"experimental and may contain bugs. Please report any related issues on " "experimental and may contain bugs. Please report any related issues on "
@ -172,20 +176,6 @@ class LoRAModelManager:
mm_budget.get_encoder_budget() mm_budget.get_encoder_budget()
) )
# Only one language model can be included in the model.
assert len(self.mm_mapping.language_model) == 1
# Update prefix of language model
lm_prefix = (
self.mm_mapping.language_model[0]
if self.supports_mm
else DEFAULT_LANGUAGE_WRAPPER_KEY
)
llm_punica_wrapper = self.punica_wrapper_mapping.pop(
DEFAULT_LANGUAGE_WRAPPER_KEY
)
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
# Tower wrappers # Tower wrappers
tower_punica_wrapper = get_punica_wrapper( tower_punica_wrapper = get_punica_wrapper(
num_encoder_tokens, num_encoder_tokens,
@ -217,15 +207,6 @@ class LoRAModelManager:
"determine the connector's token budget for LoRA operations." "determine the connector's token budget for LoRA operations."
) )
# Longest-prefix-first
self.punica_wrapper_mapping = OrderedDict(
sorted(
self.punica_wrapper_mapping.items(),
key=lambda x: len(x[0]),
reverse=True,
)
)
def __len__(self) -> int: def __len__(self) -> int:
return len(self._registered_adapters) return len(self._registered_adapters)
@ -366,10 +347,10 @@ class LoRAModelManager:
else: else:
target_prefix = self.mm_mapping.language_model[0] target_prefix = self.mm_mapping.language_model[0]
target = self._get_lora_target(target_prefix) punica_wrapper = self._get_punica_wrapper(target_prefix)
assert target is not None assert punica_wrapper is not None
target.wrapper.update_metadata( punica_wrapper.wrapper.update_metadata(
mapping, mapping,
self.lora_index_to_id, self.lora_index_to_id,
self.lora_slots + 1, self.lora_slots + 1,
@ -397,8 +378,8 @@ class LoRAModelManager:
if not self._match_target_modules(module_name): if not self._match_target_modules(module_name):
continue continue
target = self._get_lora_target(module_name) punica_wrapper = self._get_punica_wrapper(module_name)
if target is None: if punica_wrapper is None:
logger.warning( logger.warning(
"Regarding %s, vLLM currently only supports adding LoRA to" "Regarding %s, vLLM currently only supports adding LoRA to"
" language model, %s will be ignored.", " language model, %s will be ignored.",
@ -464,7 +445,7 @@ class LoRAModelManager:
self._register_packed_modules(module_name) self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference. # All lora layers share the same punica_wrapper based on reference.
new_module.set_mapping(target.wrapper) new_module.set_mapping(punica_wrapper)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA), ( assert isinstance(module, BaseLayerWithLoRA), (
@ -485,7 +466,7 @@ class LoRAModelManager:
if ( if (
not self._match_target_modules(module_name) not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA) or not isinstance(module, BaseLayerWithLoRA)
or self._get_lora_target(module_name) is None or self._get_punica_wrapper(module_name) is None
): ):
continue continue
parts = module_name.split(".") parts = module_name.split(".")
@ -574,23 +555,24 @@ class LoRAModelManager:
for target_module in self.supported_lora_modules for target_module in self.supported_lora_modules
) )
def _get_lora_target(self, module_name: str) -> LoRATarget | None: def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
""" """
Determine whether this module supports LoRA and which wrapper to use. Determine whether this module supports LoRA and which wrapper to use.
""" """
# For language Model (early return) # For language Model (early return)
if not self.supports_mm: if not self.supports_mm:
wrapper = list(self.punica_wrapper_mapping.values())[0] return self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY]
return LoRATarget(wrapper=wrapper, prefix=DEFAULT_LANGUAGE_WRAPPER_KEY)
# For multimodal model # For multimodal model
for prefix, wrapper in self.punica_wrapper_mapping.items(): # for prefix, wrapper in self.punica_wrapper_mapping.items():
is_language_model = ( # is_language_model = (
prefix == DEFAULT_LANGUAGE_WRAPPER_KEY # prefix == DEFAULT_LANGUAGE_WRAPPER_KEY
and module_name.startswith(self.mm_mapping.language_model[0]) # and module_name.startswith(self.mm_mapping.language_model[0])
) # )
if is_language_model or module_name.startswith(prefix): # if is_language_model or module_name.startswith(prefix):
return LoRATarget(wrapper=wrapper, prefix=prefix) # return LoRATarget(wrapper=wrapper, prefix=prefix)
return None return None