mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 23:11:27 +08:00
fix bug
Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
parent
df3ec22106
commit
764aa45140
@ -18,7 +18,7 @@ from vllm.lora.layers import (
|
|||||||
from vllm.lora.lora_model import LoRAModel
|
from vllm.lora.lora_model import LoRAModel
|
||||||
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
|
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
|
||||||
from vllm.lora.model_manager import (
|
from vllm.lora.model_manager import (
|
||||||
DEFAULT_WRAPPER_KEY,
|
DEFAULT_LANGUAGE_WRAPPER_KEY,
|
||||||
LoRAMapping,
|
LoRAMapping,
|
||||||
LoRAModelManager,
|
LoRAModelManager,
|
||||||
LRUCacheLoRAModelManager,
|
LRUCacheLoRAModelManager,
|
||||||
@ -185,7 +185,10 @@ def test_lora_model_manager(dist_init, dummy_model, device):
|
|||||||
assert manager.lora_index_to_id[0] == 3
|
assert manager.lora_index_to_id[0] == 3
|
||||||
assert manager.lora_index_to_id[1] == 2
|
assert manager.lora_index_to_id[1] == 2
|
||||||
assert manager.device == device
|
assert manager.device == device
|
||||||
assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
|
assert (
|
||||||
|
manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
|
||||||
|
== device
|
||||||
|
)
|
||||||
assert hasattr(manager, "supported_lora_modules")
|
assert hasattr(manager, "supported_lora_modules")
|
||||||
assert sorted(manager.supported_lora_modules) == [
|
assert sorted(manager.supported_lora_modules) == [
|
||||||
"dense1",
|
"dense1",
|
||||||
@ -278,7 +281,10 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
|
|||||||
assert manager.remove_adapter(3)
|
assert manager.remove_adapter(3)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
assert manager.pin_adapter(3)
|
assert manager.pin_adapter(3)
|
||||||
assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
|
assert (
|
||||||
|
manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
|
||||||
|
== device
|
||||||
|
)
|
||||||
assert manager.device == device
|
assert manager.device == device
|
||||||
|
|
||||||
|
|
||||||
@ -401,7 +407,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
|
|||||||
assert manager.remove_oldest_adapter()
|
assert manager.remove_oldest_adapter()
|
||||||
|
|
||||||
assert set(manager.list_adapters()) == {1}
|
assert set(manager.list_adapters()) == {1}
|
||||||
assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
|
assert (
|
||||||
|
manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
|
||||||
|
== device
|
||||||
|
)
|
||||||
assert manager.device == device
|
assert manager.device == device
|
||||||
|
|
||||||
|
|
||||||
@ -514,7 +523,7 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
|
|||||||
|
|
||||||
assert worker_adapter_manager.device == device
|
assert worker_adapter_manager.device == device
|
||||||
punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
|
punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
|
||||||
DEFAULT_WRAPPER_KEY
|
DEFAULT_LANGUAGE_WRAPPER_KEY
|
||||||
)
|
)
|
||||||
assert punica_wrapper.device == device
|
assert punica_wrapper.device == device
|
||||||
|
|
||||||
@ -621,7 +630,7 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
|
|||||||
|
|
||||||
assert worker_adapter_manager.device == device
|
assert worker_adapter_manager.device == device
|
||||||
punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
|
punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
|
||||||
DEFAULT_WRAPPER_KEY
|
DEFAULT_LANGUAGE_WRAPPER_KEY
|
||||||
)
|
)
|
||||||
assert punica_wrapper.device == device
|
assert punica_wrapper.device == device
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from collections import OrderedDict
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
@ -119,7 +120,7 @@ class LoRAModelManager:
|
|||||||
def _init_punica_wrapper(
|
def _init_punica_wrapper(
|
||||||
self, max_num_batched_tokens: int, vllm_config: VllmConfig
|
self, max_num_batched_tokens: int, vllm_config: VllmConfig
|
||||||
) -> None:
|
) -> None:
|
||||||
self._lora_targets: list[tuple[str, PunicaWrapperBase]] = []
|
self.punica_wrapper_mapping: OrderedDict[str, PunicaWrapperBase] = OrderedDict()
|
||||||
llm_punica_wrapper = get_punica_wrapper(
|
llm_punica_wrapper = get_punica_wrapper(
|
||||||
max_num_batched_tokens,
|
max_num_batched_tokens,
|
||||||
max_batches=self.max_num_seqs,
|
max_batches=self.max_num_seqs,
|
||||||
@ -128,7 +129,9 @@ class LoRAModelManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# NOTE This assumes the existence of a language model LoRA
|
# NOTE This assumes the existence of a language model LoRA
|
||||||
self._lora_targets.append((DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper))
|
self.punica_wrapper_mapping.setdefault(
|
||||||
|
DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper
|
||||||
|
)
|
||||||
self._maybe_init_mm(vllm_config)
|
self._maybe_init_mm(vllm_config)
|
||||||
|
|
||||||
def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
|
def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
|
||||||
@ -178,8 +181,10 @@ class LoRAModelManager:
|
|||||||
if self.supports_mm
|
if self.supports_mm
|
||||||
else DEFAULT_LANGUAGE_WRAPPER_KEY
|
else DEFAULT_LANGUAGE_WRAPPER_KEY
|
||||||
)
|
)
|
||||||
_, llm_punica_wrapper = self._lora_targets.pop()
|
llm_punica_wrapper = self.punica_wrapper_mapping.pop(
|
||||||
self._lora_targets.append((lm_prefix, llm_punica_wrapper))
|
DEFAULT_LANGUAGE_WRAPPER_KEY
|
||||||
|
)
|
||||||
|
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
|
||||||
|
|
||||||
# Tower wrappers
|
# Tower wrappers
|
||||||
tower_punica_wrapper = get_punica_wrapper(
|
tower_punica_wrapper = get_punica_wrapper(
|
||||||
@ -189,7 +194,7 @@ class LoRAModelManager:
|
|||||||
max_loras=self.lora_config.max_loras,
|
max_loras=self.lora_config.max_loras,
|
||||||
)
|
)
|
||||||
for prefix in self.mm_mapping.tower_model:
|
for prefix in self.mm_mapping.tower_model:
|
||||||
self._lora_targets.append((prefix, tower_punica_wrapper))
|
self.punica_wrapper_mapping[prefix] = tower_punica_wrapper
|
||||||
|
|
||||||
# Use wrapper for connector if present.
|
# Use wrapper for connector if present.
|
||||||
if self.mm_mapping.connector:
|
if self.mm_mapping.connector:
|
||||||
@ -204,7 +209,7 @@ class LoRAModelManager:
|
|||||||
max_loras=self.lora_config.max_loras,
|
max_loras=self.lora_config.max_loras,
|
||||||
)
|
)
|
||||||
for prefix in self.mm_mapping.connector:
|
for prefix in self.mm_mapping.connector:
|
||||||
self._lora_targets.append((prefix, connector_punica_wrapper))
|
self.punica_wrapper_mapping[prefix] = connector_punica_wrapper
|
||||||
else:
|
else:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"Connector LoRA support disabled: model does not implement "
|
"Connector LoRA support disabled: model does not implement "
|
||||||
@ -213,8 +218,12 @@ class LoRAModelManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Longest-prefix-first
|
# Longest-prefix-first
|
||||||
self._lora_targets = sorted(
|
self.punica_wrapper_mapping = OrderedDict(
|
||||||
self._lora_targets, key=lambda x: len(x[0]), reverse=True
|
sorted(
|
||||||
|
self.punica_wrapper_mapping.items(),
|
||||||
|
key=lambda x: len(x[0]),
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
@ -569,9 +578,20 @@ class LoRAModelManager:
|
|||||||
"""
|
"""
|
||||||
Determine whether this module supports LoRA and which wrapper to use.
|
Determine whether this module supports LoRA and which wrapper to use.
|
||||||
"""
|
"""
|
||||||
for prefix, wrapper in self._lora_targets:
|
# For language Model (early return)
|
||||||
if module_name.startswith(prefix):
|
if not self.supports_mm:
|
||||||
|
wrapper = list(self.punica_wrapper_mapping.values())[0]
|
||||||
|
return LoRATarget(wrapper=wrapper, prefix=DEFAULT_LANGUAGE_WRAPPER_KEY)
|
||||||
|
|
||||||
|
# For multimodal model
|
||||||
|
for prefix, wrapper in self.punica_wrapper_mapping.items():
|
||||||
|
is_language_model = (
|
||||||
|
prefix == DEFAULT_LANGUAGE_WRAPPER_KEY
|
||||||
|
and module_name.startswith(self.mm_mapping.language_model[0])
|
||||||
|
)
|
||||||
|
if is_language_model or module_name.startswith(prefix):
|
||||||
return LoRATarget(wrapper=wrapper, prefix=prefix)
|
return LoRATarget(wrapper=wrapper, prefix=prefix)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _register_packed_modules(self, module_full_name: str) -> None:
|
def _register_packed_modules(self, module_full_name: str) -> None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user