remove hacky code

Signed-off-by: bk-201 <joy25810@foxmail.com>

This commit is contained in:
parent 1c8e3c4486
commit df3ec22106
@@ -3,6 +3,7 @@
 import math
 from collections.abc import Callable
+from dataclasses import dataclass
 from typing import TypeVar

 import regex as re
@@ -42,7 +43,13 @@ from vllm.v1.worker.utils import MultiModalBudget
 logger = init_logger(__name__)

 T = TypeVar("T")
-DEFAULT_WRAPPER_KEY = "language_model"
+DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"
+
+
+@dataclass(frozen=True)
+class LoRATarget:
+    wrapper: PunicaWrapperBase
+    prefix: str


 class AdapterLRUCache(LRUCache[int, T]):
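For illustration (not part of the diff), a minimal self-contained sketch of the LoRATarget record introduced above. PunicaWrapperBase is stubbed out here because the real class lives elsewhere in vllm, so treat this purely as orientation:

    from dataclasses import dataclass


    class FakePunicaWrapper:
        """Stand-in for vllm's PunicaWrapperBase, only so the sketch runs."""


    @dataclass(frozen=True)
    class LoRATarget:
        wrapper: FakePunicaWrapper  # punica wrapper serving this module subtree
        prefix: str                 # module-name prefix the wrapper covers


    target = LoRATarget(wrapper=FakePunicaWrapper(), prefix="language_model")
    print(target.prefix)  # "language_model"; frozen=True keeps the record immutable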
@@ -112,17 +119,16 @@ class LoRAModelManager:
     def _init_punica_wrapper(
         self, max_num_batched_tokens: int, vllm_config: VllmConfig
     ) -> None:
+        self._lora_targets: list[tuple[str, PunicaWrapperBase]] = []
         llm_punica_wrapper = get_punica_wrapper(
             max_num_batched_tokens,
             max_batches=self.max_num_seqs,
             device=self.device,
             max_loras=self.lora_config.max_loras,
         )
-        # NOTE This assumes the existence of a language model LoRA
-        self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {
-            DEFAULT_WRAPPER_KEY: llm_punica_wrapper
-        }

+        # NOTE This assumes the existence of a language model LoRA
+        self._lora_targets.append((DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper))
         self._maybe_init_mm(vllm_config)

     def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
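As a rough illustration (not vllm code), after _init_punica_wrapper the manager holds an ordered list of (prefix, wrapper) pairs instead of the old name-keyed dict; the language-model wrapper is seeded first and multimodal prefixes are appended later by _maybe_init_mm:

    # Hypothetical contents of self._lora_targets for a text-only model,
    # using a stub object in place of a real PunicaWrapperBase instance.
    llm_wrapper = object()
    lora_targets: list[tuple[str, object]] = [
        ("language_model", llm_wrapper),  # DEFAULT_LANGUAGE_WRAPPER_KEY
    ]

    # A multimodal model would later extend this with tower/connector prefixes,
    # e.g. ("model.vision_tower", tower_wrapper) -- the names here are made up.
    print(lora_targets)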
@@ -164,15 +170,27 @@ class LoRAModelManager:
        )

         # Only one language model can be included in the model.
-        assert len(self.mm_mapping.language_model == 1)
+        assert len(self.mm_mapping.language_model) == 1

+        # Update prefix of language model
+        lm_prefix = (
+            self.mm_mapping.language_model[0]
+            if self.supports_mm
+            else DEFAULT_LANGUAGE_WRAPPER_KEY
+        )
+        _, llm_punica_wrapper = self._lora_targets.pop()
+        self._lora_targets.append((lm_prefix, llm_punica_wrapper))

         # Tower wrappers
-        for name in self.mm_mapping.tower_model:
-            self.punica_wrapper_mapping[name] = get_punica_wrapper(
-                num_encoder_tokens,
-                max_batches=self.max_num_seqs * limit_per_prompt,
-                device=self.device,
-                max_loras=self.lora_config.max_loras,
-            )
+        tower_punica_wrapper = get_punica_wrapper(
+            num_encoder_tokens,
+            max_batches=self.max_num_seqs * limit_per_prompt,
+            device=self.device,
+            max_loras=self.lora_config.max_loras,
+        )
+        for prefix in self.mm_mapping.tower_model:
+            self._lora_targets.append((prefix, tower_punica_wrapper))

         # Use wrapper for connector if present.
         if self.mm_mapping.connector:
             if hasattr(self.info, "get_num_mm_connector_tokens"):
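A small sketch (separate from the diff) of the sharing this hunk introduces: one tower wrapper is created once and the same instance is appended for every tower prefix, instead of creating a wrapper per name as before. Prefix names below are invented for illustration:

    tower_wrapper = object()  # stands in for a single get_punica_wrapper(...) result
    tower_prefixes = ["model.vision_tower", "model.audio_tower"]  # hypothetical

    lora_targets = [("language_model", object())]
    for prefix in tower_prefixes:
        lora_targets.append((prefix, tower_wrapper))

    # Both tower entries reference the identical wrapper object.
    assert lora_targets[1][1] is lora_targets[2][1]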
@@ -185,12 +203,8 @@ class LoRAModelManager:
                    device=self.device,
                    max_loras=self.lora_config.max_loras,
                )
-               self.punica_wrapper_mapping.update(
-                   {
-                       name: connector_punica_wrapper
-                       for name in self.mm_mapping.connector
-                   }
-               )
+               for prefix in self.mm_mapping.connector:
+                   self._lora_targets.append((prefix, connector_punica_wrapper))
            else:
                logger.warning_once(
                    "Connector LoRA support disabled: model does not implement "
@@ -198,6 +212,11 @@ class LoRAModelManager:
                    "determine the connector's token budget for LoRA operations."
                )

+        # Longest-prefix-first
+        self._lora_targets = sorted(
+            self._lora_targets, key=lambda x: len(x[0]), reverse=True
+        )
+
     def __len__(self) -> int:
         return len(self._registered_adapters)
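Why the longest-prefix-first sort matters: the lookup later in this commit walks the list and returns the first prefix that the module name starts with, so longer (more specific) prefixes must come before shorter ones that are themselves prefixes of them. A minimal standalone demonstration with made-up prefixes:

    targets = [("model.vision", "outer"), ("model.vision.encoder", "inner")]

    def first_match(module_name, targets):
        for prefix, wrapper in targets:
            if module_name.startswith(prefix):
                return wrapper
        return None

    # Unsorted, the shorter prefix shadows the more specific one.
    assert first_match("model.vision.encoder.layers.0", targets) == "outer"

    # Longest-prefix-first restores the intended match.
    targets = sorted(targets, key=lambda x: len(x[0]), reverse=True)
    assert first_match("model.vision.encoder.layers.0", targets) == "inner"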
@@ -326,20 +345,22 @@ class LoRAModelManager:
     def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
         # Default to the main language model wrapper
         if not (self.supports_mm and self.supports_tower_connector_lora):
-            target_wrapper = self.punica_wrapper_mapping[DEFAULT_WRAPPER_KEY]
+            target_prefix = (
+                self.mm_mapping.language_model[0]
+                if self.supports_mm
+                else DEFAULT_LANGUAGE_WRAPPER_KEY
+            )
+        elif mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
+            target_prefix = self.mm_mapping.tower_model[0]
+        elif mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector:
+            target_prefix = self.mm_mapping.connector[0]
         else:
-            if mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
-                target_prefix = self.mm_mapping.tower_model[0]
-            elif (
-                mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector
-            ):
-                target_prefix = self.mm_mapping.connector[0]
-            else:
-                target_prefix = self.mm_mapping.language_model[0]
+            target_prefix = self.mm_mapping.language_model[0]

-            target_wrapper = self.punica_wrapper_mapping[target_prefix]
+        target = self._get_lora_target(target_prefix)
+        assert target is not None

-        target_wrapper.update_metadata(
+        target.wrapper.update_metadata(
             mapping,
             self.lora_index_to_id,
             self.lora_slots + 1,
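A condensed, standalone sketch of the prefix selection above; the enum and the mapping fields are simplified stand-ins for vllm's LoRAMappingType and mm_mapping, and the supports_mm / supports_tower_connector_lora gate is omitted:

    from enum import Enum, auto

    class MappingType(Enum):  # simplified stand-in for LoRAMappingType
        LANGUAGE = auto()
        TOWER = auto()
        CONNECTOR = auto()

    def pick_prefix(mapping_type, language_model, tower_model, connector):
        # Route by mapping type, falling back to the language-model prefix.
        if mapping_type is MappingType.TOWER and tower_model:
            return tower_model[0]
        if mapping_type is MappingType.CONNECTOR and connector:
            return connector[0]
        return language_model[0]

    prefix = pick_prefix(
        MappingType.TOWER, ["language_model"], ["model.vision_tower"], []
    )
    print(prefix)  # model.vision_tower (prefix names here are hypothetical)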
@@ -367,7 +388,8 @@ class LoRAModelManager:
            if not self._match_target_modules(module_name):
                continue

-           if self._filter_unsupported_mm_module(module_name):
+           target = self._get_lora_target(module_name)
+           if target is None:
                logger.warning(
                    "Regarding %s, vLLM currently only supports adding LoRA to"
                    " language model, %s will be ignored.",
@@ -433,10 +455,7 @@ class LoRAModelManager:

            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
-           wrapper = self._get_punica_wrapper_for_module(module_name)
-           if wrapper is None:
-               continue
-           new_module.set_mapping(wrapper)
+           new_module.set_mapping(target.wrapper)

     def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
         assert isinstance(module, BaseLayerWithLoRA), (
@@ -457,7 +476,7 @@ class LoRAModelManager:
            if (
                not self._match_target_modules(module_name)
                or not isinstance(module, BaseLayerWithLoRA)
-               or self._filter_unsupported_mm_module(module_name)
+               or self._get_lora_target(module_name) is None
            ):
                continue
            parts = module_name.split(".")
@@ -546,42 +565,14 @@ class LoRAModelManager:
            for target_module in self.supported_lora_modules
        )

-    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
+    def _get_lora_target(self, module_name: str) -> LoRATarget | None:
         """
-        Regarding multimodal models, vLLM currently only supports adding LoRA to
-        language model. LoRA for other modules, such as the vision tower, will
-        be filtered out.
+        Determine whether this module supports LoRA and which wrapper to use.
         """
-        if not self.supports_mm:
-            return False
-
-        if self.supports_tower_connector_lora:
-            return self._get_punica_wrapper_for_module(module_name) is None
-
-        prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
-        return any(module_name.startswith(prefix) for prefix in prefix_lst)
-
-    def _get_punica_wrapper_for_module(
-        self, module_name: str
-    ) -> PunicaWrapperBase | None:
-        """
-        Match the corresponding punica_wrapper based on module_name,
-        and return None if lora is not supported for this module.
-        """
-        best_prefix = None
-        for prefix in self.punica_wrapper_mapping:
-            if prefix == DEFAULT_WRAPPER_KEY:
-                continue
-            # Ensure matching by the longest prefix.
-            if module_name.startswith(prefix) and (
-                best_prefix is None or len(prefix) > len(best_prefix)
-            ):
-                best_prefix = prefix
-
-        if best_prefix is not None:
-            return self.punica_wrapper_mapping[best_prefix]
-
-        return self.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY)
+        for prefix, wrapper in self._lora_targets:
+            if module_name.startswith(prefix):
+                return LoRATarget(wrapper=wrapper, prefix=prefix)
+        return None

     def _register_packed_modules(self, module_full_name: str) -> None:
         parts = module_full_name.split(".")
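Finally, a standalone sketch of the lookup that replaces the old _get_punica_wrapper_for_module / _filter_unsupported_mm_module pair: scan the (already longest-prefix-first) target list and return the first match, or None when the module is not covered. The class name, prefixes, and wrapper values below are illustrative, not the real vllm module tree:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Target:  # illustrative stand-in for LoRATarget
        wrapper: object
        prefix: str

    def get_lora_target(module_name, lora_targets):
        # First startswith match wins; the list is sorted longest-prefix-first.
        for prefix, wrapper in lora_targets:
            if module_name.startswith(prefix):
                return Target(wrapper=wrapper, prefix=prefix)
        return None

    lora_targets = [
        ("model.vision_tower", "tower-wrapper"),
        ("language_model", "llm-wrapper"),
    ]
    assert get_lora_target("language_model.layers.0.q_proj", lora_targets).prefix == "language_model"
    assert get_lora_target("model.audio_tower.blocks.0", lora_targets) is None  # unsupported module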