mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 18:09:07 +08:00
remove hacky code
Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
parent
1c8e3c4486
commit
df3ec22106
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
@ -42,7 +43,13 @@ from vllm.v1.worker.utils import MultiModalBudget
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
DEFAULT_WRAPPER_KEY = "language_model"
|
DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LoRATarget:
|
||||||
|
wrapper: PunicaWrapperBase
|
||||||
|
prefix: str
|
||||||
|
|
||||||
|
|
||||||
class AdapterLRUCache(LRUCache[int, T]):
|
class AdapterLRUCache(LRUCache[int, T]):
|
||||||
@ -112,17 +119,16 @@ class LoRAModelManager:
|
|||||||
def _init_punica_wrapper(
|
def _init_punica_wrapper(
|
||||||
self, max_num_batched_tokens: int, vllm_config: VllmConfig
|
self, max_num_batched_tokens: int, vllm_config: VllmConfig
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self._lora_targets: list[tuple[str, PunicaWrapperBase]] = []
|
||||||
llm_punica_wrapper = get_punica_wrapper(
|
llm_punica_wrapper = get_punica_wrapper(
|
||||||
max_num_batched_tokens,
|
max_num_batched_tokens,
|
||||||
max_batches=self.max_num_seqs,
|
max_batches=self.max_num_seqs,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
max_loras=self.lora_config.max_loras,
|
max_loras=self.lora_config.max_loras,
|
||||||
)
|
)
|
||||||
# NOTE This assumes the existence of a language model LoRA
|
|
||||||
self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {
|
|
||||||
DEFAULT_WRAPPER_KEY: llm_punica_wrapper
|
|
||||||
}
|
|
||||||
|
|
||||||
|
# NOTE This assumes the existence of a language model LoRA
|
||||||
|
self._lora_targets.append((DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper))
|
||||||
self._maybe_init_mm(vllm_config)
|
self._maybe_init_mm(vllm_config)
|
||||||
|
|
||||||
def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
|
def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
|
||||||
@ -164,15 +170,27 @@ class LoRAModelManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Only one language model can be included in the model.
|
# Only one language model can be included in the model.
|
||||||
assert len(self.mm_mapping.language_model == 1)
|
assert len(self.mm_mapping.language_model) == 1
|
||||||
|
|
||||||
|
# Update prefix of language model
|
||||||
|
lm_prefix = (
|
||||||
|
self.mm_mapping.language_model[0]
|
||||||
|
if self.supports_mm
|
||||||
|
else DEFAULT_LANGUAGE_WRAPPER_KEY
|
||||||
|
)
|
||||||
|
_, llm_punica_wrapper = self._lora_targets.pop()
|
||||||
|
self._lora_targets.append((lm_prefix, llm_punica_wrapper))
|
||||||
|
|
||||||
# Tower wrappers
|
# Tower wrappers
|
||||||
for name in self.mm_mapping.tower_model:
|
tower_punica_wrapper = get_punica_wrapper(
|
||||||
self.punica_wrapper_mapping[name] = get_punica_wrapper(
|
num_encoder_tokens,
|
||||||
num_encoder_tokens,
|
max_batches=self.max_num_seqs * limit_per_prompt,
|
||||||
max_batches=self.max_num_seqs * limit_per_prompt,
|
device=self.device,
|
||||||
device=self.device,
|
max_loras=self.lora_config.max_loras,
|
||||||
max_loras=self.lora_config.max_loras,
|
)
|
||||||
)
|
for prefix in self.mm_mapping.tower_model:
|
||||||
|
self._lora_targets.append((prefix, tower_punica_wrapper))
|
||||||
|
|
||||||
# Use wrapper for connector if present.
|
# Use wrapper for connector if present.
|
||||||
if self.mm_mapping.connector:
|
if self.mm_mapping.connector:
|
||||||
if hasattr(self.info, "get_num_mm_connector_tokens"):
|
if hasattr(self.info, "get_num_mm_connector_tokens"):
|
||||||
@ -185,12 +203,8 @@ class LoRAModelManager:
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
max_loras=self.lora_config.max_loras,
|
max_loras=self.lora_config.max_loras,
|
||||||
)
|
)
|
||||||
self.punica_wrapper_mapping.update(
|
for prefix in self.mm_mapping.connector:
|
||||||
{
|
self._lora_targets.append((prefix, connector_punica_wrapper))
|
||||||
name: connector_punica_wrapper
|
|
||||||
for name in self.mm_mapping.connector
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"Connector LoRA support disabled: model does not implement "
|
"Connector LoRA support disabled: model does not implement "
|
||||||
@ -198,6 +212,11 @@ class LoRAModelManager:
|
|||||||
"determine the connector's token budget for LoRA operations."
|
"determine the connector's token budget for LoRA operations."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Longest-prefix-first
|
||||||
|
self._lora_targets = sorted(
|
||||||
|
self._lora_targets, key=lambda x: len(x[0]), reverse=True
|
||||||
|
)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self._registered_adapters)
|
return len(self._registered_adapters)
|
||||||
|
|
||||||
@ -326,20 +345,22 @@ class LoRAModelManager:
|
|||||||
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
|
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
|
||||||
# Default to the main language model wrapper
|
# Default to the main language model wrapper
|
||||||
if not (self.supports_mm and self.supports_tower_connector_lora):
|
if not (self.supports_mm and self.supports_tower_connector_lora):
|
||||||
target_wrapper = self.punica_wrapper_mapping[DEFAULT_WRAPPER_KEY]
|
target_prefix = (
|
||||||
|
self.mm_mapping.language_model[0]
|
||||||
|
if self.supports_mm
|
||||||
|
else DEFAULT_LANGUAGE_WRAPPER_KEY
|
||||||
|
)
|
||||||
|
elif mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
|
||||||
|
target_prefix = self.mm_mapping.tower_model[0]
|
||||||
|
elif mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector:
|
||||||
|
target_prefix = self.mm_mapping.connector[0]
|
||||||
else:
|
else:
|
||||||
if mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
|
target_prefix = self.mm_mapping.language_model[0]
|
||||||
target_prefix = self.mm_mapping.tower_model[0]
|
|
||||||
elif (
|
|
||||||
mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector
|
|
||||||
):
|
|
||||||
target_prefix = self.mm_mapping.connector[0]
|
|
||||||
else:
|
|
||||||
target_prefix = self.mm_mapping.language_model[0]
|
|
||||||
|
|
||||||
target_wrapper = self.punica_wrapper_mapping[target_prefix]
|
target = self._get_lora_target(target_prefix)
|
||||||
|
assert target is not None
|
||||||
|
|
||||||
target_wrapper.update_metadata(
|
target.wrapper.update_metadata(
|
||||||
mapping,
|
mapping,
|
||||||
self.lora_index_to_id,
|
self.lora_index_to_id,
|
||||||
self.lora_slots + 1,
|
self.lora_slots + 1,
|
||||||
@ -367,7 +388,8 @@ class LoRAModelManager:
|
|||||||
if not self._match_target_modules(module_name):
|
if not self._match_target_modules(module_name):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if self._filter_unsupported_mm_module(module_name):
|
target = self._get_lora_target(module_name)
|
||||||
|
if target is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Regarding %s, vLLM currently only supports adding LoRA to"
|
"Regarding %s, vLLM currently only supports adding LoRA to"
|
||||||
" language model, %s will be ignored.",
|
" language model, %s will be ignored.",
|
||||||
@ -433,10 +455,7 @@ class LoRAModelManager:
|
|||||||
|
|
||||||
self._register_packed_modules(module_name)
|
self._register_packed_modules(module_name)
|
||||||
# All lora layers share the same punica_wrapper based on reference.
|
# All lora layers share the same punica_wrapper based on reference.
|
||||||
wrapper = self._get_punica_wrapper_for_module(module_name)
|
new_module.set_mapping(target.wrapper)
|
||||||
if wrapper is None:
|
|
||||||
continue
|
|
||||||
new_module.set_mapping(wrapper)
|
|
||||||
|
|
||||||
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
|
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
|
||||||
assert isinstance(module, BaseLayerWithLoRA), (
|
assert isinstance(module, BaseLayerWithLoRA), (
|
||||||
@ -457,7 +476,7 @@ class LoRAModelManager:
|
|||||||
if (
|
if (
|
||||||
not self._match_target_modules(module_name)
|
not self._match_target_modules(module_name)
|
||||||
or not isinstance(module, BaseLayerWithLoRA)
|
or not isinstance(module, BaseLayerWithLoRA)
|
||||||
or self._filter_unsupported_mm_module(module_name)
|
or self._get_lora_target(module_name) is None
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
parts = module_name.split(".")
|
parts = module_name.split(".")
|
||||||
@ -546,42 +565,14 @@ class LoRAModelManager:
|
|||||||
for target_module in self.supported_lora_modules
|
for target_module in self.supported_lora_modules
|
||||||
)
|
)
|
||||||
|
|
||||||
def _filter_unsupported_mm_module(self, module_name: str) -> bool:
|
def _get_lora_target(self, module_name: str) -> LoRATarget | None:
|
||||||
"""
|
"""
|
||||||
Regarding multimodal models, vLLM currently only supports adding LoRA to
|
Determine whether this module supports LoRA and which wrapper to use.
|
||||||
language model. LoRA for other modules, such as the vision tower, will
|
|
||||||
be filtered out.
|
|
||||||
"""
|
"""
|
||||||
if not self.supports_mm:
|
for prefix, wrapper in self._lora_targets:
|
||||||
return False
|
if module_name.startswith(prefix):
|
||||||
|
return LoRATarget(wrapper=wrapper, prefix=prefix)
|
||||||
if self.supports_tower_connector_lora:
|
return None
|
||||||
return self._get_punica_wrapper_for_module(module_name) is None
|
|
||||||
|
|
||||||
prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
|
|
||||||
return any(module_name.startswith(prefix) for prefix in prefix_lst)
|
|
||||||
|
|
||||||
def _get_punica_wrapper_for_module(
|
|
||||||
self, module_name: str
|
|
||||||
) -> PunicaWrapperBase | None:
|
|
||||||
"""
|
|
||||||
Match the corresponding punica_wrapper based on module_name,
|
|
||||||
and return None if lora is not supported for this module.
|
|
||||||
"""
|
|
||||||
best_prefix = None
|
|
||||||
for prefix in self.punica_wrapper_mapping:
|
|
||||||
if prefix == DEFAULT_WRAPPER_KEY:
|
|
||||||
continue
|
|
||||||
# Ensure matching by the longest prefix.
|
|
||||||
if module_name.startswith(prefix) and (
|
|
||||||
best_prefix is None or len(prefix) > len(best_prefix)
|
|
||||||
):
|
|
||||||
best_prefix = prefix
|
|
||||||
|
|
||||||
if best_prefix is not None:
|
|
||||||
return self.punica_wrapper_mapping[best_prefix]
|
|
||||||
|
|
||||||
return self.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY)
|
|
||||||
|
|
||||||
def _register_packed_modules(self, module_full_name: str) -> None:
|
def _register_packed_modules(self, module_full_name: str) -> None:
|
||||||
parts = module_full_name.split(".")
|
parts = module_full_name.split(".")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user