mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 18:04:27 +08:00
Fix
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
463074fac8
commit
d053aa73e1
@ -2,9 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from collections import OrderedDict
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
@ -47,12 +45,6 @@ T = TypeVar("T")
|
|||||||
DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"
|
DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class LoRATarget:
|
|
||||||
wrapper: PunicaWrapperBase
|
|
||||||
prefix: str
|
|
||||||
|
|
||||||
|
|
||||||
class AdapterLRUCache(LRUCache[int, T]):
|
class AdapterLRUCache(LRUCache[int, T]):
|
||||||
def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
|
def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
|
||||||
super().__init__(capacity)
|
super().__init__(capacity)
|
||||||
@ -120,21 +112,6 @@ class LoRAModelManager:
|
|||||||
def _init_punica_wrapper(
|
def _init_punica_wrapper(
|
||||||
self, max_num_batched_tokens: int, vllm_config: VllmConfig
|
self, max_num_batched_tokens: int, vllm_config: VllmConfig
|
||||||
) -> None:
|
) -> None:
|
||||||
self.punica_wrapper_mapping: OrderedDict[str, PunicaWrapperBase] = OrderedDict()
|
|
||||||
llm_punica_wrapper = get_punica_wrapper(
|
|
||||||
max_num_batched_tokens,
|
|
||||||
max_batches=self.max_num_seqs,
|
|
||||||
device=self.device,
|
|
||||||
max_loras=self.lora_config.max_loras,
|
|
||||||
)
|
|
||||||
|
|
||||||
# NOTE This assumes the existence of a language model LoRA
|
|
||||||
self.punica_wrapper_mapping.setdefault(
|
|
||||||
DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper
|
|
||||||
)
|
|
||||||
self._maybe_init_mm(vllm_config)
|
|
||||||
|
|
||||||
def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
|
|
||||||
# Used to indicate whether the model is a multimodal model
|
# Used to indicate whether the model is a multimodal model
|
||||||
self.supports_mm: bool = (
|
self.supports_mm: bool = (
|
||||||
supports_multimodal(self.model)
|
supports_multimodal(self.model)
|
||||||
@ -142,13 +119,39 @@ class LoRAModelManager:
|
|||||||
# text modules (e.g. ChatGLM)
|
# text modules (e.g. ChatGLM)
|
||||||
and hasattr(self.model, "get_mm_mapping")
|
and hasattr(self.model, "get_mm_mapping")
|
||||||
)
|
)
|
||||||
if not self.supports_mm:
|
self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {}
|
||||||
return
|
if self.supports_mm:
|
||||||
|
self._maybe_init_mm(vllm_config,max_num_batched_tokens)
|
||||||
|
else:
|
||||||
|
llm_punica_wrapper = get_punica_wrapper(
|
||||||
|
max_num_batched_tokens,
|
||||||
|
max_batches=self.max_num_seqs,
|
||||||
|
device=self.device,
|
||||||
|
max_loras=self.lora_config.max_loras,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY] = (
|
||||||
|
llm_punica_wrapper
|
||||||
|
)
|
||||||
|
|
||||||
|
def _maybe_init_mm(self, vllm_config: VllmConfig, max_num_batched_tokens) -> None:
|
||||||
self.supports_tower_connector_lora = False
|
self.supports_tower_connector_lora = False
|
||||||
model_config: ModelConfig = vllm_config.model_config
|
model_config: ModelConfig = vllm_config.model_config
|
||||||
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
|
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
|
||||||
|
|
||||||
|
# Only one language model can be included in the model.
|
||||||
|
assert len(self.mm_mapping.language_model) == 1
|
||||||
|
|
||||||
|
# Language model punica wrapper
|
||||||
|
llm_punica_wrapper = get_punica_wrapper(
|
||||||
|
max_num_batched_tokens,
|
||||||
|
max_batches=self.max_num_seqs,
|
||||||
|
device=self.device,
|
||||||
|
max_loras=self.lora_config.max_loras,
|
||||||
|
)
|
||||||
|
lm_prefix = self.mm_mapping.language_model[0]
|
||||||
|
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
|
||||||
|
|
||||||
if self.lora_config.enable_tower_connector_lora:
|
if self.lora_config.enable_tower_connector_lora:
|
||||||
self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
|
self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
|
||||||
self.supports_tower_connector_lora = self.supports_mm and hasattr(
|
self.supports_tower_connector_lora = self.supports_mm and hasattr(
|
||||||
@ -156,6 +159,7 @@ class LoRAModelManager:
|
|||||||
)
|
)
|
||||||
if not self.supports_tower_connector_lora:
|
if not self.supports_tower_connector_lora:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"LoRA for the tower and connector of multimodal models is "
|
"LoRA for the tower and connector of multimodal models is "
|
||||||
"experimental and may contain bugs. Please report any related issues on "
|
"experimental and may contain bugs. Please report any related issues on "
|
||||||
@ -172,20 +176,6 @@ class LoRAModelManager:
|
|||||||
mm_budget.get_encoder_budget()
|
mm_budget.get_encoder_budget()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only one language model can be included in the model.
|
|
||||||
assert len(self.mm_mapping.language_model) == 1
|
|
||||||
|
|
||||||
# Update prefix of language model
|
|
||||||
lm_prefix = (
|
|
||||||
self.mm_mapping.language_model[0]
|
|
||||||
if self.supports_mm
|
|
||||||
else DEFAULT_LANGUAGE_WRAPPER_KEY
|
|
||||||
)
|
|
||||||
llm_punica_wrapper = self.punica_wrapper_mapping.pop(
|
|
||||||
DEFAULT_LANGUAGE_WRAPPER_KEY
|
|
||||||
)
|
|
||||||
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
|
|
||||||
|
|
||||||
# Tower wrappers
|
# Tower wrappers
|
||||||
tower_punica_wrapper = get_punica_wrapper(
|
tower_punica_wrapper = get_punica_wrapper(
|
||||||
num_encoder_tokens,
|
num_encoder_tokens,
|
||||||
@ -217,15 +207,6 @@ class LoRAModelManager:
|
|||||||
"determine the connector's token budget for LoRA operations."
|
"determine the connector's token budget for LoRA operations."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Longest-prefix-first
|
|
||||||
self.punica_wrapper_mapping = OrderedDict(
|
|
||||||
sorted(
|
|
||||||
self.punica_wrapper_mapping.items(),
|
|
||||||
key=lambda x: len(x[0]),
|
|
||||||
reverse=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self._registered_adapters)
|
return len(self._registered_adapters)
|
||||||
|
|
||||||
@ -366,10 +347,10 @@ class LoRAModelManager:
|
|||||||
else:
|
else:
|
||||||
target_prefix = self.mm_mapping.language_model[0]
|
target_prefix = self.mm_mapping.language_model[0]
|
||||||
|
|
||||||
target = self._get_lora_target(target_prefix)
|
punica_wrapper = self._get_punica_wrapper(target_prefix)
|
||||||
assert target is not None
|
assert punica_wrapper is not None
|
||||||
|
|
||||||
target.wrapper.update_metadata(
|
punica_wrapper.wrapper.update_metadata(
|
||||||
mapping,
|
mapping,
|
||||||
self.lora_index_to_id,
|
self.lora_index_to_id,
|
||||||
self.lora_slots + 1,
|
self.lora_slots + 1,
|
||||||
@ -397,8 +378,8 @@ class LoRAModelManager:
|
|||||||
if not self._match_target_modules(module_name):
|
if not self._match_target_modules(module_name):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
target = self._get_lora_target(module_name)
|
punica_wrapper = self._get_punica_wrapper(module_name)
|
||||||
if target is None:
|
if punica_wrapper is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Regarding %s, vLLM currently only supports adding LoRA to"
|
"Regarding %s, vLLM currently only supports adding LoRA to"
|
||||||
" language model, %s will be ignored.",
|
" language model, %s will be ignored.",
|
||||||
@ -464,7 +445,7 @@ class LoRAModelManager:
|
|||||||
|
|
||||||
self._register_packed_modules(module_name)
|
self._register_packed_modules(module_name)
|
||||||
# All lora layers share the same punica_wrapper based on reference.
|
# All lora layers share the same punica_wrapper based on reference.
|
||||||
new_module.set_mapping(target.wrapper)
|
new_module.set_mapping(punica_wrapper)
|
||||||
|
|
||||||
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
|
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
|
||||||
assert isinstance(module, BaseLayerWithLoRA), (
|
assert isinstance(module, BaseLayerWithLoRA), (
|
||||||
@ -485,7 +466,7 @@ class LoRAModelManager:
|
|||||||
if (
|
if (
|
||||||
not self._match_target_modules(module_name)
|
not self._match_target_modules(module_name)
|
||||||
or not isinstance(module, BaseLayerWithLoRA)
|
or not isinstance(module, BaseLayerWithLoRA)
|
||||||
or self._get_lora_target(module_name) is None
|
or self._get_punica_wrapper(module_name) is None
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
parts = module_name.split(".")
|
parts = module_name.split(".")
|
||||||
@ -574,23 +555,24 @@ class LoRAModelManager:
|
|||||||
for target_module in self.supported_lora_modules
|
for target_module in self.supported_lora_modules
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_lora_target(self, module_name: str) -> LoRATarget | None:
|
def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
|
||||||
"""
|
"""
|
||||||
Determine whether this module supports LoRA and which wrapper to use.
|
Determine whether this module supports LoRA and which wrapper to use.
|
||||||
"""
|
"""
|
||||||
# For language Model (early return)
|
# For language Model (early return)
|
||||||
if not self.supports_mm:
|
if not self.supports_mm:
|
||||||
wrapper = list(self.punica_wrapper_mapping.values())[0]
|
return self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY]
|
||||||
return LoRATarget(wrapper=wrapper, prefix=DEFAULT_LANGUAGE_WRAPPER_KEY)
|
|
||||||
|
|
||||||
|
|
||||||
# For multimodal model
|
# For multimodal model
|
||||||
for prefix, wrapper in self.punica_wrapper_mapping.items():
|
# for prefix, wrapper in self.punica_wrapper_mapping.items():
|
||||||
is_language_model = (
|
# is_language_model = (
|
||||||
prefix == DEFAULT_LANGUAGE_WRAPPER_KEY
|
# prefix == DEFAULT_LANGUAGE_WRAPPER_KEY
|
||||||
and module_name.startswith(self.mm_mapping.language_model[0])
|
# and module_name.startswith(self.mm_mapping.language_model[0])
|
||||||
)
|
# )
|
||||||
if is_language_model or module_name.startswith(prefix):
|
# if is_language_model or module_name.startswith(prefix):
|
||||||
return LoRATarget(wrapper=wrapper, prefix=prefix)
|
# return LoRATarget(wrapper=wrapper, prefix=prefix)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user