remove hacky code

Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
bk-201 2025-12-18 16:45:41 +00:00
parent 1c8e3c4486
commit df3ec22106

View File

@ -3,6 +3,7 @@
import math import math
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar from typing import TypeVar
import regex as re import regex as re
@ -42,7 +43,13 @@ from vllm.v1.worker.utils import MultiModalBudget
logger = init_logger(__name__) logger = init_logger(__name__)
T = TypeVar("T") T = TypeVar("T")
DEFAULT_WRAPPER_KEY = "language_model" DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"
@dataclass(frozen=True)
class LoRATarget:
wrapper: PunicaWrapperBase
prefix: str
class AdapterLRUCache(LRUCache[int, T]): class AdapterLRUCache(LRUCache[int, T]):
@ -112,17 +119,16 @@ class LoRAModelManager:
def _init_punica_wrapper( def _init_punica_wrapper(
self, max_num_batched_tokens: int, vllm_config: VllmConfig self, max_num_batched_tokens: int, vllm_config: VllmConfig
) -> None: ) -> None:
self._lora_targets: list[tuple[str, PunicaWrapperBase]] = []
llm_punica_wrapper = get_punica_wrapper( llm_punica_wrapper = get_punica_wrapper(
max_num_batched_tokens, max_num_batched_tokens,
max_batches=self.max_num_seqs, max_batches=self.max_num_seqs,
device=self.device, device=self.device,
max_loras=self.lora_config.max_loras, max_loras=self.lora_config.max_loras,
) )
# NOTE This assumes the existence of a language model LoRA
self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {
DEFAULT_WRAPPER_KEY: llm_punica_wrapper
}
# NOTE This assumes the existence of a language model LoRA
self._lora_targets.append((DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper))
self._maybe_init_mm(vllm_config) self._maybe_init_mm(vllm_config)
def _maybe_init_mm(self, vllm_config: VllmConfig) -> None: def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
@ -164,15 +170,27 @@ class LoRAModelManager:
) )
# Only one language model can be included in the model. # Only one language model can be included in the model.
assert len(self.mm_mapping.language_model == 1) assert len(self.mm_mapping.language_model) == 1
# Update prefix of language model
lm_prefix = (
self.mm_mapping.language_model[0]
if self.supports_mm
else DEFAULT_LANGUAGE_WRAPPER_KEY
)
_, llm_punica_wrapper = self._lora_targets.pop()
self._lora_targets.append((lm_prefix, llm_punica_wrapper))
# Tower wrappers # Tower wrappers
for name in self.mm_mapping.tower_model: tower_punica_wrapper = get_punica_wrapper(
self.punica_wrapper_mapping[name] = get_punica_wrapper( num_encoder_tokens,
num_encoder_tokens, max_batches=self.max_num_seqs * limit_per_prompt,
max_batches=self.max_num_seqs * limit_per_prompt, device=self.device,
device=self.device, max_loras=self.lora_config.max_loras,
max_loras=self.lora_config.max_loras, )
) for prefix in self.mm_mapping.tower_model:
self._lora_targets.append((prefix, tower_punica_wrapper))
# Use wrapper for connector if present. # Use wrapper for connector if present.
if self.mm_mapping.connector: if self.mm_mapping.connector:
if hasattr(self.info, "get_num_mm_connector_tokens"): if hasattr(self.info, "get_num_mm_connector_tokens"):
@ -185,12 +203,8 @@ class LoRAModelManager:
device=self.device, device=self.device,
max_loras=self.lora_config.max_loras, max_loras=self.lora_config.max_loras,
) )
self.punica_wrapper_mapping.update( for prefix in self.mm_mapping.connector:
{ self._lora_targets.append((prefix, connector_punica_wrapper))
name: connector_punica_wrapper
for name in self.mm_mapping.connector
}
)
else: else:
logger.warning_once( logger.warning_once(
"Connector LoRA support disabled: model does not implement " "Connector LoRA support disabled: model does not implement "
@ -198,6 +212,11 @@ class LoRAModelManager:
"determine the connector's token budget for LoRA operations." "determine the connector's token budget for LoRA operations."
) )
# Longest-prefix-first
self._lora_targets = sorted(
self._lora_targets, key=lambda x: len(x[0]), reverse=True
)
def __len__(self) -> int: def __len__(self) -> int:
return len(self._registered_adapters) return len(self._registered_adapters)
@ -326,20 +345,22 @@ class LoRAModelManager:
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
# Default to the main language model wrapper # Default to the main language model wrapper
if not (self.supports_mm and self.supports_tower_connector_lora): if not (self.supports_mm and self.supports_tower_connector_lora):
target_wrapper = self.punica_wrapper_mapping[DEFAULT_WRAPPER_KEY] target_prefix = (
self.mm_mapping.language_model[0]
if self.supports_mm
else DEFAULT_LANGUAGE_WRAPPER_KEY
)
elif mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
target_prefix = self.mm_mapping.tower_model[0]
elif mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector:
target_prefix = self.mm_mapping.connector[0]
else: else:
if mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model: target_prefix = self.mm_mapping.language_model[0]
target_prefix = self.mm_mapping.tower_model[0]
elif (
mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector
):
target_prefix = self.mm_mapping.connector[0]
else:
target_prefix = self.mm_mapping.language_model[0]
target_wrapper = self.punica_wrapper_mapping[target_prefix] target = self._get_lora_target(target_prefix)
assert target is not None
target_wrapper.update_metadata( target.wrapper.update_metadata(
mapping, mapping,
self.lora_index_to_id, self.lora_index_to_id,
self.lora_slots + 1, self.lora_slots + 1,
@ -367,7 +388,8 @@ class LoRAModelManager:
if not self._match_target_modules(module_name): if not self._match_target_modules(module_name):
continue continue
if self._filter_unsupported_mm_module(module_name): target = self._get_lora_target(module_name)
if target is None:
logger.warning( logger.warning(
"Regarding %s, vLLM currently only supports adding LoRA to" "Regarding %s, vLLM currently only supports adding LoRA to"
" language model, %s will be ignored.", " language model, %s will be ignored.",
@ -433,10 +455,7 @@ class LoRAModelManager:
self._register_packed_modules(module_name) self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference. # All lora layers share the same punica_wrapper based on reference.
wrapper = self._get_punica_wrapper_for_module(module_name) new_module.set_mapping(target.wrapper)
if wrapper is None:
continue
new_module.set_mapping(wrapper)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA), ( assert isinstance(module, BaseLayerWithLoRA), (
@ -457,7 +476,7 @@ class LoRAModelManager:
if ( if (
not self._match_target_modules(module_name) not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA) or not isinstance(module, BaseLayerWithLoRA)
or self._filter_unsupported_mm_module(module_name) or self._get_lora_target(module_name) is None
): ):
continue continue
parts = module_name.split(".") parts = module_name.split(".")
@ -546,42 +565,14 @@ class LoRAModelManager:
for target_module in self.supported_lora_modules for target_module in self.supported_lora_modules
) )
def _filter_unsupported_mm_module(self, module_name: str) -> bool: def _get_lora_target(self, module_name: str) -> LoRATarget | None:
""" """
Regarding multimodal models, vLLM currently only supports adding LoRA to Determine whether this module supports LoRA and which wrapper to use.
language model. LoRA for other modules, such as the vision tower, will
be filtered out.
""" """
if not self.supports_mm: for prefix, wrapper in self._lora_targets:
return False if module_name.startswith(prefix):
return LoRATarget(wrapper=wrapper, prefix=prefix)
if self.supports_tower_connector_lora: return None
return self._get_punica_wrapper_for_module(module_name) is None
prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
return any(module_name.startswith(prefix) for prefix in prefix_lst)
def _get_punica_wrapper_for_module(
self, module_name: str
) -> PunicaWrapperBase | None:
"""
Match the corresponding punica_wrapper based on module_name,
and return None if lora is not supported for this module.
"""
best_prefix = None
for prefix in self.punica_wrapper_mapping:
if prefix == DEFAULT_WRAPPER_KEY:
continue
# Ensure matching by the longest prefix.
if module_name.startswith(prefix) and (
best_prefix is None or len(prefix) > len(best_prefix)
):
best_prefix = prefix
if best_prefix is not None:
return self.punica_wrapper_mapping[best_prefix]
return self.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY)
def _register_packed_modules(self, module_full_name: str) -> None: def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".") parts = module_full_name.split(".")