Mirror of https://git.datalinker.icu/vllm-project/vllm.git
commit 764aa45140
parent df3ec22106

fix bug

Signed-off-by: bk-201 <joy25810@foxmail.com>
@@ -18,7 +18,7 @@ from vllm.lora.layers import (
 from vllm.lora.lora_model import LoRAModel
 from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.model_manager import (
-    DEFAULT_WRAPPER_KEY,
+    DEFAULT_LANGUAGE_WRAPPER_KEY,
     LoRAMapping,
     LoRAModelManager,
     LRUCacheLoRAModelManager,
@@ -185,7 +185,10 @@ def test_lora_model_manager(dist_init, dummy_model, device):
     assert manager.lora_index_to_id[0] == 3
     assert manager.lora_index_to_id[1] == 2
     assert manager.device == device
-    assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
+    assert (
+        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
+        == device
+    )
     assert hasattr(manager, "supported_lora_modules")
     assert sorted(manager.supported_lora_modules) == [
         "dense1",
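For readers skimming the test changes: the one-line assertion no longer fits once the longer constant name is in, so it becomes a parenthesized comparison. A self-contained sketch of the shape, with a dummy mapping standing in for the manager fixture (all names and values here are illustrative, not the real fixtures):

    # Stand-in for the fixture-built manager; the real test drives
    # LoRAModelManager with a dummy model and a device fixture.
    DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"

    class DummyWrapper:
        def __init__(self, device: str) -> None:
            self.device = device

    device = "cuda:0"
    punica_wrapper_mapping = {DEFAULT_LANGUAGE_WRAPPER_KEY: DummyWrapper(device)}

    # Same assertion shape as the updated test; the parentheses exist
    # only to keep the comparison within the line-length limit.
    assert (
        punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
        == device
    )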
@@ -278,7 +281,10 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
     assert manager.remove_adapter(3)
     with pytest.raises(ValueError):
         assert manager.pin_adapter(3)
-    assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
+    assert (
+        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
+        == device
+    )
     assert manager.device == device

@@ -401,7 +407,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
     assert manager.remove_oldest_adapter()

     assert set(manager.list_adapters()) == {1}
-    assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
+    assert (
+        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
+        == device
+    )
     assert manager.device == device

@@ -514,7 +523,7 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa

     assert worker_adapter_manager.device == device
     punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
-        DEFAULT_WRAPPER_KEY
+        DEFAULT_LANGUAGE_WRAPPER_KEY
     )
     assert punica_wrapper.device == device

@@ -621,7 +630,7 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path

     assert worker_adapter_manager.device == device
     punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
-        DEFAULT_WRAPPER_KEY
+        DEFAULT_LANGUAGE_WRAPPER_KEY
    )
     assert punica_wrapper.device == device

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import math
+from collections import OrderedDict
 from collections.abc import Callable
 from dataclasses import dataclass
 from typing import TypeVar
@@ -119,7 +120,7 @@ class LoRAModelManager:
     def _init_punica_wrapper(
         self, max_num_batched_tokens: int, vllm_config: VllmConfig
     ) -> None:
-        self._lora_targets: list[tuple[str, PunicaWrapperBase]] = []
+        self.punica_wrapper_mapping: OrderedDict[str, PunicaWrapperBase] = OrderedDict()
         llm_punica_wrapper = get_punica_wrapper(
             max_num_batched_tokens,
             max_batches=self.max_num_seqs,
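The structural change in this hunk: a `list[tuple[str, PunicaWrapperBase]]` becomes an `OrderedDict[str, PunicaWrapperBase]`, which keeps at most one wrapper per prefix, gives O(1) lookup by key, and still preserves insertion order. A minimal sketch of the before/after shapes, with the wrapper class stubbed out (the real one is vLLM's PunicaWrapperBase):

    from collections import OrderedDict


    class StubWrapper:  # stand-in for vllm's PunicaWrapperBase
        pass


    # Before: prefix/wrapper pairs in a plain list; duplicate prefixes
    # are possible and lookup is a linear scan.
    lora_targets: list[tuple[str, StubWrapper]] = []
    lora_targets.append(("language_model", StubWrapper()))

    # After: an insertion-ordered mapping; setdefault registers the
    # wrapper only if the key is not already present.
    punica_wrapper_mapping: OrderedDict[str, StubWrapper] = OrderedDict()
    punica_wrapper_mapping.setdefault("language_model", StubWrapper())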
@@ -128,7 +129,9 @@
         )

         # NOTE This assumes the existence of a language model LoRA
-        self._lora_targets.append((DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper))
+        self.punica_wrapper_mapping.setdefault(
+            DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper
+        )
         self._maybe_init_mm(vllm_config)

     def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
@@ -178,8 +181,10 @@
             if self.supports_mm
             else DEFAULT_LANGUAGE_WRAPPER_KEY
         )
-        _, llm_punica_wrapper = self._lora_targets.pop()
-        self._lora_targets.append((lm_prefix, llm_punica_wrapper))
+        llm_punica_wrapper = self.punica_wrapper_mapping.pop(
+            DEFAULT_LANGUAGE_WRAPPER_KEY
+        )
+        self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper

         # Tower wrappers
         tower_punica_wrapper = get_punica_wrapper(
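Where the old code popped the last list element and relied on it implicitly being the language-model entry, the mapping version pops by its known key and reinserts under the model-specific prefix, so the move no longer depends on insertion order. A sketch of the idiom, with illustrative keys and values:

    from collections import OrderedDict

    wrappers: OrderedDict[str, str] = OrderedDict()
    wrappers["language_model"] = "llm_wrapper"  # illustrative value

    # Re-key the entry: pop by the default key, reinsert under the
    # model's real prefix. A KeyError here would flag a missing
    # registration instead of silently moving the wrong entry.
    wrapper = wrappers.pop("language_model")
    wrappers["model.language_model"] = wrapper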
@@ -189,7 +194,7 @@
             max_loras=self.lora_config.max_loras,
         )
         for prefix in self.mm_mapping.tower_model:
-            self._lora_targets.append((prefix, tower_punica_wrapper))
+            self.punica_wrapper_mapping[prefix] = tower_punica_wrapper

         # Use wrapper for connector if present.
         if self.mm_mapping.connector:
@@ -204,7 +209,7 @@
                 max_loras=self.lora_config.max_loras,
             )
             for prefix in self.mm_mapping.connector:
-                self._lora_targets.append((prefix, connector_punica_wrapper))
+                self.punica_wrapper_mapping[prefix] = connector_punica_wrapper
         else:
             logger.warning_once(
                 "Connector LoRA support disabled: model does not implement "
@@ -213,8 +218,12 @@
             )

         # Longest-prefix-first
-        self._lora_targets = sorted(
-            self._lora_targets, key=lambda x: len(x[0]), reverse=True
-        )
+        self.punica_wrapper_mapping = OrderedDict(
+            sorted(
+                self.punica_wrapper_mapping.items(),
+                key=lambda x: len(x[0]),
+                reverse=True,
+            )
+        )

     def __len__(self) -> int:
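Rebuilding the OrderedDict from items sorted by key length, descending, preserves the old list's longest-prefix-first guarantee: a linear scan then matches the most specific registered prefix before any shorter ancestor. A self-contained sketch (prefix strings invented for illustration):

    from collections import OrderedDict

    wrappers = OrderedDict([
        ("model", "fallback_wrapper"),
        ("model.vision_tower", "tower_wrapper"),
        ("model.language_model", "llm_wrapper"),
    ])

    # Rebuild with the longest keys first so a linear scan returns
    # the most specific matching prefix.
    wrappers = OrderedDict(
        sorted(wrappers.items(), key=lambda kv: len(kv[0]), reverse=True)
    )

    def resolve(module_name: str) -> str | None:
        for prefix, wrapper in wrappers.items():
            if module_name.startswith(prefix):
                return wrapper
        return None

    assert resolve("model.language_model.layers.0") == "llm_wrapper"
    assert resolve("model.embed_tokens") == "fallback_wrapper"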
@@ -569,9 +578,20 @@
         """
         Determine whether this module supports LoRA and which wrapper to use.
         """
-        for prefix, wrapper in self._lora_targets:
-            if module_name.startswith(prefix):
+        # For language model (early return)
+        if not self.supports_mm:
+            wrapper = list(self.punica_wrapper_mapping.values())[0]
+            return LoRATarget(wrapper=wrapper, prefix=DEFAULT_LANGUAGE_WRAPPER_KEY)
+
+        # For multimodal model
+        for prefix, wrapper in self.punica_wrapper_mapping.items():
+            is_language_model = (
+                prefix == DEFAULT_LANGUAGE_WRAPPER_KEY
+                and module_name.startswith(self.mm_mapping.language_model[0])
+            )
+            if is_language_model or module_name.startswith(prefix):
                 return LoRATarget(wrapper=wrapper, prefix=prefix)

         return None

     def _register_packed_modules(self, module_full_name: str) -> None:
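In the rewritten resolution logic, a non-multimodal model short-circuits to its single wrapper, while the multimodal path must treat the default key as an alias for the model's actual language-model prefix, since modules live under the real prefix but the wrapper may still be registered under the default key. A hedged sketch of that dispatch; the registry contents and prefixes are simplified stand-ins, not vLLM's real types or values:

    from dataclasses import dataclass

    DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"

    @dataclass
    class LoRATarget:
        wrapper: str
        prefix: str

    # Wrapper registry and the model's reported language-model prefix,
    # both illustrative.
    wrappers = {
        DEFAULT_LANGUAGE_WRAPPER_KEY: "llm_wrapper",
        "model.vision_tower": "tower_wrapper",
    }
    language_model_prefixes = ["model.language_model"]

    def get_lora_target(module_name: str) -> LoRATarget | None:
        for prefix, wrapper in wrappers.items():
            # The default key doubles as an alias for the real
            # language-model prefix reported by the model.
            is_language_model = (
                prefix == DEFAULT_LANGUAGE_WRAPPER_KEY
                and module_name.startswith(language_model_prefixes[0])
            )
            if is_language_model or module_name.startswith(prefix):
                return LoRATarget(wrapper=wrapper, prefix=prefix)
        return None

    assert get_lora_target("model.language_model.layers.0.mlp").wrapper == "llm_wrapper"
    assert get_lora_target("model.vision_tower.blocks.0").wrapper == "tower_wrapper"
    assert get_lora_target("model.unrelated") is None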