Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
bk-201 2025-12-19 16:57:25 +00:00
parent df3ec22106
commit 764aa45140
2 changed files with 45 additions and 16 deletions

View File

@ -18,7 +18,7 @@ from vllm.lora.layers import (
from vllm.lora.lora_model import LoRAModel from vllm.lora.lora_model import LoRAModel
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.model_manager import ( from vllm.lora.model_manager import (
DEFAULT_WRAPPER_KEY, DEFAULT_LANGUAGE_WRAPPER_KEY,
LoRAMapping, LoRAMapping,
LoRAModelManager, LoRAModelManager,
LRUCacheLoRAModelManager, LRUCacheLoRAModelManager,
@ -185,7 +185,10 @@ def test_lora_model_manager(dist_init, dummy_model, device):
assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 2 assert manager.lora_index_to_id[1] == 2
assert manager.device == device assert manager.device == device
assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device assert (
manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
== device
)
assert hasattr(manager, "supported_lora_modules") assert hasattr(manager, "supported_lora_modules")
assert sorted(manager.supported_lora_modules) == [ assert sorted(manager.supported_lora_modules) == [
"dense1", "dense1",
@ -278,7 +281,10 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
assert manager.remove_adapter(3) assert manager.remove_adapter(3)
with pytest.raises(ValueError): with pytest.raises(ValueError):
assert manager.pin_adapter(3) assert manager.pin_adapter(3)
assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device assert (
manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
== device
)
assert manager.device == device assert manager.device == device
@ -401,7 +407,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
assert manager.remove_oldest_adapter() assert manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == {1} assert set(manager.list_adapters()) == {1}
assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device assert (
manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
== device
)
assert manager.device == device assert manager.device == device
@ -514,7 +523,7 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
assert worker_adapter_manager.device == device assert worker_adapter_manager.device == device
punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get( punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
DEFAULT_WRAPPER_KEY DEFAULT_LANGUAGE_WRAPPER_KEY
) )
assert punica_wrapper.device == device assert punica_wrapper.device == device
@ -621,7 +630,7 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
assert worker_adapter_manager.device == device assert worker_adapter_manager.device == device
punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get( punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
DEFAULT_WRAPPER_KEY DEFAULT_LANGUAGE_WRAPPER_KEY
) )
assert punica_wrapper.device == device assert punica_wrapper.device == device

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math import math
from collections import OrderedDict
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import TypeVar from typing import TypeVar
@ -119,7 +120,7 @@ class LoRAModelManager:
def _init_punica_wrapper( def _init_punica_wrapper(
self, max_num_batched_tokens: int, vllm_config: VllmConfig self, max_num_batched_tokens: int, vllm_config: VllmConfig
) -> None: ) -> None:
self._lora_targets: list[tuple[str, PunicaWrapperBase]] = [] self.punica_wrapper_mapping: OrderedDict[str, PunicaWrapperBase] = OrderedDict()
llm_punica_wrapper = get_punica_wrapper( llm_punica_wrapper = get_punica_wrapper(
max_num_batched_tokens, max_num_batched_tokens,
max_batches=self.max_num_seqs, max_batches=self.max_num_seqs,
@ -128,7 +129,9 @@ class LoRAModelManager:
) )
# NOTE This assumes the existence of a language model LoRA # NOTE This assumes the existence of a language model LoRA
self._lora_targets.append((DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper)) self.punica_wrapper_mapping.setdefault(
DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper
)
self._maybe_init_mm(vllm_config) self._maybe_init_mm(vllm_config)
def _maybe_init_mm(self, vllm_config: VllmConfig) -> None: def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
@ -178,8 +181,10 @@ class LoRAModelManager:
if self.supports_mm if self.supports_mm
else DEFAULT_LANGUAGE_WRAPPER_KEY else DEFAULT_LANGUAGE_WRAPPER_KEY
) )
_, llm_punica_wrapper = self._lora_targets.pop() llm_punica_wrapper = self.punica_wrapper_mapping.pop(
self._lora_targets.append((lm_prefix, llm_punica_wrapper)) DEFAULT_LANGUAGE_WRAPPER_KEY
)
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
# Tower wrappers # Tower wrappers
tower_punica_wrapper = get_punica_wrapper( tower_punica_wrapper = get_punica_wrapper(
@ -189,7 +194,7 @@ class LoRAModelManager:
max_loras=self.lora_config.max_loras, max_loras=self.lora_config.max_loras,
) )
for prefix in self.mm_mapping.tower_model: for prefix in self.mm_mapping.tower_model:
self._lora_targets.append((prefix, tower_punica_wrapper)) self.punica_wrapper_mapping[prefix] = tower_punica_wrapper
# Use wrapper for connector if present. # Use wrapper for connector if present.
if self.mm_mapping.connector: if self.mm_mapping.connector:
@ -204,7 +209,7 @@ class LoRAModelManager:
max_loras=self.lora_config.max_loras, max_loras=self.lora_config.max_loras,
) )
for prefix in self.mm_mapping.connector: for prefix in self.mm_mapping.connector:
self._lora_targets.append((prefix, connector_punica_wrapper)) self.punica_wrapper_mapping[prefix] = connector_punica_wrapper
else: else:
logger.warning_once( logger.warning_once(
"Connector LoRA support disabled: model does not implement " "Connector LoRA support disabled: model does not implement "
@ -213,8 +218,12 @@ class LoRAModelManager:
) )
# Longest-prefix-first # Longest-prefix-first
self._lora_targets = sorted( self.punica_wrapper_mapping = OrderedDict(
self._lora_targets, key=lambda x: len(x[0]), reverse=True sorted(
self.punica_wrapper_mapping.items(),
key=lambda x: len(x[0]),
reverse=True,
)
) )
def __len__(self) -> int: def __len__(self) -> int:
@ -569,9 +578,20 @@ class LoRAModelManager:
""" """
Determine whether this module supports LoRA and which wrapper to use. Determine whether this module supports LoRA and which wrapper to use.
""" """
for prefix, wrapper in self._lora_targets: # For language Model (early return)
if module_name.startswith(prefix): if not self.supports_mm:
wrapper = list(self.punica_wrapper_mapping.values())[0]
return LoRATarget(wrapper=wrapper, prefix=DEFAULT_LANGUAGE_WRAPPER_KEY)
# For multimodal model
for prefix, wrapper in self.punica_wrapper_mapping.items():
is_language_model = (
prefix == DEFAULT_LANGUAGE_WRAPPER_KEY
and module_name.startswith(self.mm_mapping.language_model[0])
)
if is_language_model or module_name.startswith(prefix):
return LoRATarget(wrapper=wrapper, prefix=prefix) return LoRATarget(wrapper=wrapper, prefix=prefix)
return None return None
def _register_packed_modules(self, module_full_name: str) -> None: def _register_packed_modules(self, module_full_name: str) -> None: