Signed-off-by: bk-201 <joy25810@foxmail.com>
bk-201 2025-12-19 16:57:25 +00:00
parent df3ec22106
commit 764aa45140
2 changed files with 45 additions and 16 deletions

View File

@@ -18,7 +18,7 @@ from vllm.lora.layers import (
 from vllm.lora.lora_model import LoRAModel
 from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.model_manager import (
-    DEFAULT_WRAPPER_KEY,
+    DEFAULT_LANGUAGE_WRAPPER_KEY,
     LoRAMapping,
     LoRAModelManager,
     LRUCacheLoRAModelManager,
@@ -185,7 +185,10 @@ def test_lora_model_manager(dist_init, dummy_model, device):
     assert manager.lora_index_to_id[0] == 3
     assert manager.lora_index_to_id[1] == 2
     assert manager.device == device
-    assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
+    assert (
+        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
+        == device
+    )
     assert hasattr(manager, "supported_lora_modules")
     assert sorted(manager.supported_lora_modules) == [
         "dense1",
@@ -278,7 +281,10 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
     assert manager.remove_adapter(3)
     with pytest.raises(ValueError):
         assert manager.pin_adapter(3)
-    assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
+    assert (
+        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
+        == device
+    )
     assert manager.device == device
@@ -401,7 +407,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
     assert manager.remove_oldest_adapter()
     assert set(manager.list_adapters()) == {1}
-    assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
+    assert (
+        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
+        == device
+    )
     assert manager.device == device
@@ -514,7 +523,7 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
     assert worker_adapter_manager.device == device
     punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
-        DEFAULT_WRAPPER_KEY
+        DEFAULT_LANGUAGE_WRAPPER_KEY
     )
     assert punica_wrapper.device == device
@@ -621,7 +630,7 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
     assert worker_adapter_manager.device == device
     punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
-        DEFAULT_WRAPPER_KEY
+        DEFAULT_LANGUAGE_WRAPPER_KEY
    )
    assert punica_wrapper.device == device

View File

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
+from collections import OrderedDict
 from collections.abc import Callable
 from dataclasses import dataclass
 from typing import TypeVar
@@ -119,7 +120,7 @@ class LoRAModelManager:
     def _init_punica_wrapper(
         self, max_num_batched_tokens: int, vllm_config: VllmConfig
     ) -> None:
-        self._lora_targets: list[tuple[str, PunicaWrapperBase]] = []
+        self.punica_wrapper_mapping: OrderedDict[str, PunicaWrapperBase] = OrderedDict()
         llm_punica_wrapper = get_punica_wrapper(
             max_num_batched_tokens,
             max_batches=self.max_num_seqs,
@@ -128,7 +129,9 @@
         )
         # NOTE This assumes the existence of a language model LoRA
-        self._lora_targets.append((DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper))
+        self.punica_wrapper_mapping.setdefault(
+            DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper
+        )
         self._maybe_init_mm(vllm_config)

     def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
@@ -178,8 +181,10 @@
             if self.supports_mm
             else DEFAULT_LANGUAGE_WRAPPER_KEY
         )
-        _, llm_punica_wrapper = self._lora_targets.pop()
-        self._lora_targets.append((lm_prefix, llm_punica_wrapper))
+        llm_punica_wrapper = self.punica_wrapper_mapping.pop(
+            DEFAULT_LANGUAGE_WRAPPER_KEY
+        )
+        self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper

         # Tower wrappers
         tower_punica_wrapper = get_punica_wrapper(
@@ -189,7 +194,7 @@
             max_loras=self.lora_config.max_loras,
         )
         for prefix in self.mm_mapping.tower_model:
-            self._lora_targets.append((prefix, tower_punica_wrapper))
+            self.punica_wrapper_mapping[prefix] = tower_punica_wrapper

         # Use wrapper for connector if present.
         if self.mm_mapping.connector:
@@ -204,7 +209,7 @@
                 max_loras=self.lora_config.max_loras,
             )
             for prefix in self.mm_mapping.connector:
-                self._lora_targets.append((prefix, connector_punica_wrapper))
+                self.punica_wrapper_mapping[prefix] = connector_punica_wrapper
         else:
             logger.warning_once(
                 "Connector LoRA support disabled: model does not implement "
@@ -213,8 +218,12 @@
             )

         # Longest-prefix-first
-        self._lora_targets = sorted(
-            self._lora_targets, key=lambda x: len(x[0]), reverse=True
+        self.punica_wrapper_mapping = OrderedDict(
+            sorted(
+                self.punica_wrapper_mapping.items(),
+                key=lambda x: len(x[0]),
+                reverse=True,
+            )
         )

     def __len__(self) -> int:
@@ -569,9 +578,20 @@
         """
         Determine whether this module supports LoRA and which wrapper to use.
         """
-        for prefix, wrapper in self._lora_targets:
-            if module_name.startswith(prefix):
+        # For language Model (early return)
+        if not self.supports_mm:
+            wrapper = list(self.punica_wrapper_mapping.values())[0]
+            return LoRATarget(wrapper=wrapper, prefix=DEFAULT_LANGUAGE_WRAPPER_KEY)
+
+        # For multimodal model
+        for prefix, wrapper in self.punica_wrapper_mapping.items():
+            is_language_model = (
+                prefix == DEFAULT_LANGUAGE_WRAPPER_KEY
+                and module_name.startswith(self.mm_mapping.language_model[0])
+            )
+            if is_language_model or module_name.startswith(prefix):
                 return LoRATarget(wrapper=wrapper, prefix=prefix)
         return None

     def _register_packed_modules(self, module_full_name: str) -> None:
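
For context, the change replaces the _lora_targets list with an OrderedDict keyed by module-name prefix and re-sorted longest-prefix-first, so that the first matching prefix during lookup is always the most specific one. A minimal, self-contained sketch of that lookup pattern is below; the prefix names, the string stand-ins for the Punica wrappers, and the helper wrapper_for_module are illustrative only and not part of the vLLM API.

from collections import OrderedDict

# Toy mapping from module-name prefix to a wrapper (strings stand in for
# PunicaWrapperBase instances in the real code).
mapping = OrderedDict(
    {
        "model": "llm_wrapper",
        "model.vision_tower": "tower_wrapper",
        "model.multi_modal_projector": "connector_wrapper",
    }
)
# Longest prefix first, mirroring the sort at the end of the manager's init,
# so the most specific prefix wins when several prefixes match.
mapping = OrderedDict(
    sorted(mapping.items(), key=lambda kv: len(kv[0]), reverse=True)
)


def wrapper_for_module(module_name: str) -> str | None:
    # The first match is the longest prefix because of the sort above.
    for prefix, wrapper in mapping.items():
        if module_name.startswith(prefix):
            return wrapper
    return None


assert wrapper_for_module("model.vision_tower.layers.0.attn.q_proj") == "tower_wrapper"
assert wrapper_for_module("model.layers.0.mlp.gate_proj") == "llm_wrapper"
assert wrapper_for_module("lm_head") is None

Without the longest-prefix-first order, "model.vision_tower..." would match the shorter "model" prefix first and be routed to the language-model wrapper.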