mirror of https://git.datalinker.icu/vllm-project/vllm.git
commit d053aa73e1 (parent 463074fac8)

Fix

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -2,9 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import math
-from collections import OrderedDict
 from collections.abc import Callable
-from dataclasses import dataclass
 from typing import TypeVar
 
 import regex as re
@@ -47,12 +45,6 @@ T = TypeVar("T")
 DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"
 
 
-@dataclass(frozen=True)
-class LoRATarget:
-    wrapper: PunicaWrapperBase
-    prefix: str
-
-
 class AdapterLRUCache(LRUCache[int, T]):
     def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
         super().__init__(capacity)
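For context on the hunk above: AdapterLRUCache caps the number of resident adapters and deactivates the least-recently-used one when the cache overflows. A minimal, self-contained sketch of that eviction pattern (EvictingLRUCache and its methods are hypothetical names, not vLLM's actual LRUCache API):

    from collections import OrderedDict
    from collections.abc import Callable

    class EvictingLRUCache:
        # LRU cache that calls deactivate_fn(key) whenever an entry is evicted.
        def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
            self.capacity = capacity
            self.deactivate_fn = deactivate_fn
            self._data: OrderedDict[int, object] = OrderedDict()

        def put(self, key: int, value: object) -> None:
            if key in self._data:
                self._data.move_to_end(key)  # refresh recency on re-insert
            self._data[key] = value
            if len(self._data) > self.capacity:
                evicted_key, _ = self._data.popitem(last=False)  # drop LRU entry
                self.deactivate_fn(evicted_key)  # e.g. free the adapter's GPU slot

        def get(self, key: int) -> object | None:
            if key not in self._data:
                return None
            self._data.move_to_end(key)  # mark as most recently used
            return self._data[key]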
@@ -120,21 +112,6 @@ class LoRAModelManager:
     def _init_punica_wrapper(
         self, max_num_batched_tokens: int, vllm_config: VllmConfig
     ) -> None:
-        self.punica_wrapper_mapping: OrderedDict[str, PunicaWrapperBase] = OrderedDict()
-        llm_punica_wrapper = get_punica_wrapper(
-            max_num_batched_tokens,
-            max_batches=self.max_num_seqs,
-            device=self.device,
-            max_loras=self.lora_config.max_loras,
-        )
-
-        # NOTE This assumes the existence of a language model LoRA
-        self.punica_wrapper_mapping.setdefault(
-            DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper
-        )
-        self._maybe_init_mm(vllm_config)
-
-    def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
         # Used to indicate whether the model is a multimodal model
         self.supports_mm: bool = (
             supports_multimodal(self.model)
@@ -142,13 +119,39 @@ class LoRAModelManager:
             # text modules (e.g. ChatGLM)
             and hasattr(self.model, "get_mm_mapping")
         )
-        if not self.supports_mm:
-            return
+        self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {}
+        if self.supports_mm:
+            self._maybe_init_mm(vllm_config, max_num_batched_tokens)
+        else:
+            llm_punica_wrapper = get_punica_wrapper(
+                max_num_batched_tokens,
+                max_batches=self.max_num_seqs,
+                device=self.device,
+                max_loras=self.lora_config.max_loras,
+            )
+
+            self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY] = (
+                llm_punica_wrapper
+            )
+
+    def _maybe_init_mm(self, vllm_config: VllmConfig, max_num_batched_tokens) -> None:
         self.supports_tower_connector_lora = False
         model_config: ModelConfig = vllm_config.model_config
         self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
 
+        # Only one language model can be included in the model.
+        assert len(self.mm_mapping.language_model) == 1
+
+        # Language model punica wrapper
+        llm_punica_wrapper = get_punica_wrapper(
+            max_num_batched_tokens,
+            max_batches=self.max_num_seqs,
+            device=self.device,
+            max_loras=self.lora_config.max_loras,
+        )
+        lm_prefix = self.mm_mapping.language_model[0]
+        self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
 
         if self.lora_config.enable_tower_connector_lora:
             self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
             self.supports_tower_connector_lora = self.supports_mm and hasattr(
@@ -156,6 +159,7 @@ class LoRAModelManager:
             )
             if not self.supports_tower_connector_lora:
                 return
+
             logger.warning(
                 "LoRA for the tower and connector of multimodal models is "
                 "experimental and may contain bugs. Please report any related issues on "
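The two hunks above split wrapper setup by model type: a text-only model gets a single wrapper under DEFAULT_LANGUAGE_WRAPPER_KEY, while a multimodal model keys one wrapper per component prefix (language model, tower, connector), each sized by that component's token budget. A minimal sketch of the resulting mapping shape, with a hypothetical make_wrapper standing in for get_punica_wrapper and illustrative prefixes and budgets:

    from dataclasses import dataclass

    @dataclass
    class FakeWrapper:
        # Stand-in for PunicaWrapperBase; only the token budget matters here.
        max_tokens: int

    def make_wrapper(max_tokens: int) -> FakeWrapper:
        # Hypothetical stand-in for get_punica_wrapper(...).
        return FakeWrapper(max_tokens)

    # One wrapper per component prefix; names and budgets are illustrative.
    punica_wrapper_mapping: dict[str, FakeWrapper] = {
        "model.language_model": make_wrapper(max_tokens=8192),  # decoder budget
        "model.vision_tower": make_wrapper(max_tokens=4096),    # encoder budget
        "model.connector": make_wrapper(max_tokens=4096),       # connector budget
    }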
@@ -172,20 +176,6 @@ class LoRAModelManager:
             mm_budget.get_encoder_budget()
         )
 
-        # Only one language model can be included in the model.
-        assert len(self.mm_mapping.language_model) == 1
-
-        # Update prefix of language model
-        lm_prefix = (
-            self.mm_mapping.language_model[0]
-            if self.supports_mm
-            else DEFAULT_LANGUAGE_WRAPPER_KEY
-        )
-        llm_punica_wrapper = self.punica_wrapper_mapping.pop(
-            DEFAULT_LANGUAGE_WRAPPER_KEY
-        )
-        self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
-
         # Tower wrappers
         tower_punica_wrapper = get_punica_wrapper(
             num_encoder_tokens,
@@ -217,15 +207,6 @@ class LoRAModelManager:
                 "determine the connector's token budget for LoRA operations."
             )
 
-        # Longest-prefix-first
-        self.punica_wrapper_mapping = OrderedDict(
-            sorted(
-                self.punica_wrapper_mapping.items(),
-                key=lambda x: len(x[0]),
-                reverse=True,
-            )
-        )
-
     def __len__(self) -> int:
         return len(self._registered_adapters)
 
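The block removed above reordered punica_wrapper_mapping longest-prefix-first. That ordering matters when lookups use startswith: a short prefix would otherwise shadow a longer, more specific one. A standalone illustration with plain strings in place of wrappers (hypothetical prefixes); note that with the sort removed, any replacement lookup has to handle prefix specificity itself:

    from collections import OrderedDict

    mapping = {
        "model": "language wrapper",
        "model.vision_tower": "tower wrapper",
    }
    module = "model.vision_tower.layer0"

    # Insertion order: the short prefix "model" matches first -- wrong wrapper.
    first_hit = next(w for p, w in mapping.items() if module.startswith(p))

    # Longest-prefix-first: the most specific prefix is checked first.
    ordered = OrderedDict(
        sorted(mapping.items(), key=lambda x: len(x[0]), reverse=True)
    )
    best_hit = next(w for p, w in ordered.items() if module.startswith(p))

    print(first_hit)  # language wrapper (shadowed by the short prefix)
    print(best_hit)   # tower wrapper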
@@ -366,10 +347,10 @@ class LoRAModelManager:
         else:
             target_prefix = self.mm_mapping.language_model[0]
 
-        target = self._get_lora_target(target_prefix)
-        assert target is not None
+        punica_wrapper = self._get_punica_wrapper(target_prefix)
+        assert punica_wrapper is not None
 
-        target.wrapper.update_metadata(
+        punica_wrapper.update_metadata(
             mapping,
             self.lora_index_to_id,
             self.lora_slots + 1,
@@ -397,8 +378,8 @@ class LoRAModelManager:
             if not self._match_target_modules(module_name):
                 continue
 
-            target = self._get_lora_target(module_name)
-            if target is None:
+            punica_wrapper = self._get_punica_wrapper(module_name)
+            if punica_wrapper is None:
                 logger.warning(
                     "Regarding %s, vLLM currently only supports adding LoRA to"
                     " language model, %s will be ignored.",
@@ -464,7 +445,7 @@ class LoRAModelManager:
 
             self._register_packed_modules(module_name)
             # All lora layers share the same punica_wrapper based on reference.
-            new_module.set_mapping(target.wrapper)
+            new_module.set_mapping(punica_wrapper)
 
     def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
         assert isinstance(module, BaseLayerWithLoRA), (
@@ -485,7 +466,7 @@ class LoRAModelManager:
         if (
             not self._match_target_modules(module_name)
             or not isinstance(module, BaseLayerWithLoRA)
-            or self._get_lora_target(module_name) is None
+            or self._get_punica_wrapper(module_name) is None
         ):
             continue
         parts = module_name.split(".")
@@ -574,23 +555,24 @@ class LoRAModelManager:
             for target_module in self.supported_lora_modules
         )
 
-    def _get_lora_target(self, module_name: str) -> LoRATarget | None:
+    def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
         """
         Determine whether this module supports LoRA and which wrapper to use.
         """
         # For language Model (early return)
         if not self.supports_mm:
-            wrapper = list(self.punica_wrapper_mapping.values())[0]
-            return LoRATarget(wrapper=wrapper, prefix=DEFAULT_LANGUAGE_WRAPPER_KEY)
-        # For multimodal model
-        for prefix, wrapper in self.punica_wrapper_mapping.items():
-            is_language_model = (
-                prefix == DEFAULT_LANGUAGE_WRAPPER_KEY
-                and module_name.startswith(self.mm_mapping.language_model[0])
-            )
-            if is_language_model or module_name.startswith(prefix):
-                return LoRATarget(wrapper=wrapper, prefix=prefix)
+            return self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY]
+
+
+
+        # for prefix, wrapper in self.punica_wrapper_mapping.items():
+        #     is_language_model = (
+        #         prefix == DEFAULT_LANGUAGE_WRAPPER_KEY
+        #         and module_name.startswith(self.mm_mapping.language_model[0])
+        #     )
+        #     if is_language_model or module_name.startswith(prefix):
+        #         return LoRATarget(wrapper=wrapper, prefix=prefix)
 
         return None
 
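Taken together with the earlier hunks, module resolution now funnels through _get_punica_wrapper, whose multimodal branch the commit leaves commented out. For reference, the prefix scan that the removed _get_lora_target performed (and that the commented block mirrors) reduces to the following standalone sketch, with plain strings standing in for wrapper objects and hypothetical module prefixes:

    DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"

    punica_wrapper_mapping = {
        "model.vision_tower": "tower wrapper",
        DEFAULT_LANGUAGE_WRAPPER_KEY: "language wrapper",
    }

    def get_wrapper(module_name: str, lm_prefix: str) -> str | None:
        # Scan the prefix-keyed mapping; the default key also matches modules
        # living under the model's actual language-model prefix.
        for prefix, wrapper in punica_wrapper_mapping.items():
            is_language_model = (
                prefix == DEFAULT_LANGUAGE_WRAPPER_KEY
                and module_name.startswith(lm_prefix)
            )
            if is_language_model or module_name.startswith(prefix):
                return wrapper
        return None

    print(get_wrapper("model.vision_tower.blocks.0", "model.language_model"))
    # tower wrapper
    print(get_wrapper("model.language_model.layers.0", "model.language_model"))
    # language wrapper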