remove hacky code

Signed-off-by: bk-201 <joy25810@foxmail.com>
bk-201 2025-12-18 16:45:41 +00:00
parent 1c8e3c4486
commit df3ec22106


@@ -3,6 +3,7 @@
import math
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar
import regex as re
@@ -42,7 +43,13 @@ from vllm.v1.worker.utils import MultiModalBudget
logger = init_logger(__name__)
T = TypeVar("T")
DEFAULT_WRAPPER_KEY = "language_model"
DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"
@dataclass(frozen=True)
class LoRATarget:
wrapper: PunicaWrapperBase
prefix: str
class AdapterLRUCache(LRUCache[int, T]):
@@ -112,17 +119,16 @@ class LoRAModelManager:
def _init_punica_wrapper(
self, max_num_batched_tokens: int, vllm_config: VllmConfig
) -> None:
self._lora_targets: list[tuple[str, PunicaWrapperBase]] = []
llm_punica_wrapper = get_punica_wrapper(
max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device,
max_loras=self.lora_config.max_loras,
)
# NOTE This assumes the existence of a language model LoRA
self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {
DEFAULT_WRAPPER_KEY: llm_punica_wrapper
}
# NOTE This assumes the existence of a language model LoRA
self._lora_targets.append((DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper))
self._maybe_init_mm(vllm_config)
def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
@@ -164,15 +170,27 @@
)
# Only one language model can be included in the model.
assert len(self.mm_mapping.language_model == 1)
assert len(self.mm_mapping.language_model) == 1
# Update prefix of language model
lm_prefix = (
self.mm_mapping.language_model[0]
if self.supports_mm
else DEFAULT_LANGUAGE_WRAPPER_KEY
)
_, llm_punica_wrapper = self._lora_targets.pop()
self._lora_targets.append((lm_prefix, llm_punica_wrapper))
# Tower wrappers
for name in self.mm_mapping.tower_model:
self.punica_wrapper_mapping[name] = get_punica_wrapper(
num_encoder_tokens,
max_batches=self.max_num_seqs * limit_per_prompt,
device=self.device,
max_loras=self.lora_config.max_loras,
)
tower_punica_wrapper = get_punica_wrapper(
num_encoder_tokens,
max_batches=self.max_num_seqs * limit_per_prompt,
device=self.device,
max_loras=self.lora_config.max_loras,
)
for prefix in self.mm_mapping.tower_model:
self._lora_targets.append((prefix, tower_punica_wrapper))
# Use wrapper for connector if present.
if self.mm_mapping.connector:
if hasattr(self.info, "get_num_mm_connector_tokens"):
@@ -185,12 +203,8 @@
device=self.device,
max_loras=self.lora_config.max_loras,
)
self.punica_wrapper_mapping.update(
{
name: connector_punica_wrapper
for name in self.mm_mapping.connector
}
)
for prefix in self.mm_mapping.connector:
self._lora_targets.append((prefix, connector_punica_wrapper))
else:
logger.warning_once(
"Connector LoRA support disabled: model does not implement "
@@ -198,6 +212,11 @@
"determine the connector's token budget for LoRA operations."
)
# Longest-prefix-first
self._lora_targets = sorted(
self._lora_targets, key=lambda x: len(x[0]), reverse=True
)
def __len__(self) -> int:
return len(self._registered_adapters)
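The sort above keeps the (prefix, wrapper) pairs ordered longest-prefix-first, so a simple first-match scan always picks the most specific prefix. A minimal sketch of the effect, with invented prefixes and the wrappers reduced to strings (the real entries pair model prefixes with PunicaWrapperBase instances):

# Illustrative only: invented prefixes, wrappers reduced to strings.
pairs = [("model", "llm_wrapper"), ("model.vision_tower", "tower_wrapper")]
ordered = sorted(pairs, key=lambda x: len(x[0]), reverse=True)
# The more specific prefix now comes first, so a first-match scan cannot route
# "model.vision_tower.*" modules to the language-model wrapper just because
# "model" also matches them.
assert [p for p, _ in ordered] == ["model.vision_tower", "model"]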
@@ -326,20 +345,22 @@
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
# Default to the main language model wrapper
if not (self.supports_mm and self.supports_tower_connector_lora):
target_wrapper = self.punica_wrapper_mapping[DEFAULT_WRAPPER_KEY]
target_prefix = (
self.mm_mapping.language_model[0]
if self.supports_mm
else DEFAULT_LANGUAGE_WRAPPER_KEY
)
elif mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
target_prefix = self.mm_mapping.tower_model[0]
elif mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector:
target_prefix = self.mm_mapping.connector[0]
else:
if mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
target_prefix = self.mm_mapping.tower_model[0]
elif (
mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector
):
target_prefix = self.mm_mapping.connector[0]
else:
target_prefix = self.mm_mapping.language_model[0]
target_prefix = self.mm_mapping.language_model[0]
target_wrapper = self.punica_wrapper_mapping[target_prefix]
target = self._get_lora_target(target_prefix)
assert target is not None
target_wrapper.update_metadata(
target.wrapper.update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
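A hedged sketch of the dispatch above, with LoRAMappingType and the multimodal mapping reduced to stand-ins (the names and lists are illustrative, not the real API): tower mappings resolve to the first tower prefix, connector mappings to the first connector prefix, and everything else falls back to the language-model prefix.

from enum import Enum, auto

class MappingType(Enum):  # stand-in for LoRAMappingType; member names illustrative
    TOWER = auto()
    CONNECTOR = auto()

def pick_prefix(
    mapping_type: MappingType,
    language_model: list[str],
    tower_model: list[str],
    connector: list[str],
) -> str:
    # Mirrors the branch order above: tower first, then connector,
    # otherwise fall back to the language-model prefix.
    if mapping_type is MappingType.TOWER and tower_model:
        return tower_model[0]
    if mapping_type is MappingType.CONNECTOR and connector:
        return connector[0]
    return language_model[0]

assert pick_prefix(MappingType.TOWER, ["language_model"], ["vision_tower"], []) == "vision_tower"
# No connector prefix registered, so a connector mapping falls back to the LLM prefix.
assert pick_prefix(MappingType.CONNECTOR, ["language_model"], ["vision_tower"], []) == "language_model"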
@@ -367,7 +388,8 @@
if not self._match_target_modules(module_name):
continue
if self._filter_unsupported_mm_module(module_name):
target = self._get_lora_target(module_name)
if target is None:
logger.warning(
"Regarding %s, vLLM currently only supports adding LoRA to"
" language model, %s will be ignored.",
@@ -433,10 +455,7 @@
self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference.
wrapper = self._get_punica_wrapper_for_module(module_name)
if wrapper is None:
continue
new_module.set_mapping(wrapper)
new_module.set_mapping(target.wrapper)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA), (
@@ -457,7 +476,7 @@
if (
not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA)
or self._filter_unsupported_mm_module(module_name)
or self._get_lora_target(module_name) is None
):
continue
parts = module_name.split(".")
@@ -546,42 +565,14 @@
for target_module in self.supported_lora_modules
)
def _filter_unsupported_mm_module(self, module_name: str) -> bool:
def _get_lora_target(self, module_name: str) -> LoRATarget | None:
"""
Regarding multimodal models, vLLM currently only supports adding LoRA to
language model. LoRA for other modules, such as the vision tower, will
be filtered out.
Determine whether this module supports LoRA and which wrapper to use.
"""
if not self.supports_mm:
return False
if self.supports_tower_connector_lora:
return self._get_punica_wrapper_for_module(module_name) is None
prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
return any(module_name.startswith(prefix) for prefix in prefix_lst)
def _get_punica_wrapper_for_module(
self, module_name: str
) -> PunicaWrapperBase | None:
"""
Match the corresponding punica_wrapper based on module_name,
and return None if lora is not supported for this module.
"""
best_prefix = None
for prefix in self.punica_wrapper_mapping:
if prefix == DEFAULT_WRAPPER_KEY:
continue
# Ensure matching by the longest prefix.
if module_name.startswith(prefix) and (
best_prefix is None or len(prefix) > len(best_prefix)
):
best_prefix = prefix
if best_prefix is not None:
return self.punica_wrapper_mapping[best_prefix]
return self.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY)
for prefix, wrapper in self._lora_targets:
if module_name.startswith(prefix):
return LoRATarget(wrapper=wrapper, prefix=prefix)
return None
def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".")
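For reference, a self-contained sketch of the lookup that replaces _filter_unsupported_mm_module and _get_punica_wrapper_for_module: walk the pre-sorted prefix list, return a LoRATarget on the first match, or None when no prefix covers the module (the None case is what makes callers skip or warn about unsupported modules). Wrapper objects are plain strings here purely for illustration.

from dataclasses import dataclass

@dataclass(frozen=True)
class LoRATarget:
    wrapper: object
    prefix: str

# Pre-sorted longest-prefix-first, as _init_punica_wrapper does.
lora_targets = sorted(
    [("language_model", "llm_wrapper"), ("vision_tower", "tower_wrapper")],
    key=lambda x: len(x[0]),
    reverse=True,
)

def get_lora_target(module_name: str) -> LoRATarget | None:
    for prefix, wrapper in lora_targets:
        if module_name.startswith(prefix):
            return LoRATarget(wrapper=wrapper, prefix=prefix)
    return None

assert get_lora_target("language_model.layers.0.mlp.up_proj").prefix == "language_model"
assert get_lora_target("audio_tower.encoder.0") is None  # not covered: skipped/ignored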