Fix

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

parent 94dce5c3d9
commit 3d39188d38
@@ -18,6 +18,7 @@ from vllm.lora.layers import (
 from vllm.lora.lora_model import LoRAModel
 from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.model_manager import (
+    DEFAULT_WRAPPER_KEY,
     LoRAMapping,
     LoRAModelManager,
     LRUCacheLoRAModelManager,
@@ -183,9 +184,8 @@ def test_lora_model_manager(dist_init, dummy_model, device):
     assert manager.activate_adapter(2)
     assert manager.lora_index_to_id[0] == 3
     assert manager.lora_index_to_id[1] == 2
-
     assert manager.device == device
-    assert manager.punica_wrapper.device == device
+    assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
     assert hasattr(manager, "supported_lora_modules")
     assert sorted(manager.supported_lora_modules) == [
         "dense1",
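For reference, the pattern the updated assertions exercise, as a minimal sketch. Here `manager` stands in for the LoRAModelManager built by the test fixtures, and the `is not None` guard is added for illustration only; it is not part of the diff:

    # Sketch only: DEFAULT_WRAPPER_KEY is imported from vllm.lora.model_manager
    # (see the import hunk above); `manager` is the fixture-built manager.
    wrapper = manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY)
    assert wrapper is not None  # illustrative guard, not asserted in the diff
    assert wrapper.device == device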
@@ -278,8 +278,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
     assert manager.remove_adapter(3)
     with pytest.raises(ValueError):
         assert manager.pin_adapter(3)

-    assert manager.punica_wrapper.device == device
+    assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
     assert manager.device == device
-

@@ -402,7 +401,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
     assert manager.remove_oldest_adapter()

     assert set(manager.list_adapters()) == {1}
-    assert manager.punica_wrapper.device == device
+    assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
     assert manager.device == device


@@ -514,7 +513,10 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
     )

     assert worker_adapter_manager.device == device
-    assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device
+    punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
+        DEFAULT_WRAPPER_KEY
+    )
+    assert punica_wrapper.device == device


@pytest.mark.parametrize("device", DEVICES)
@@ -618,7 +620,10 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
     )

     assert worker_adapter_manager.device == device
-    assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device
+    punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
+        DEFAULT_WRAPPER_KEY
+    )
+    assert punica_wrapper.device == device


@pytest.mark.parametrize("device", DEVICES)

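Both worker-manager tests make the same substitution: the wrapper is now fetched from `punica_wrapper_mapping` on the inner `_adapter_manager` rather than read off the removed `punica_wrapper` attribute. A hypothetical helper (not in the diff) that could deduplicate the two call sites:

    def _default_wrapper(worker_adapter_manager):
        # Hypothetical convenience helper: fetch the wrapper registered under
        # DEFAULT_WRAPPER_KEY from the wrapped adapter manager; returns None
        # if no default entry exists.
        inner = worker_adapter_manager._adapter_manager
        return inner.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY)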
@@ -42,7 +42,7 @@ from vllm.v1.worker.utils import MultiModalBudget
 logger = init_logger(__name__)

 T = TypeVar("T")
-DEFAULT_WRAPPER_KEY = "__default__"
+DEFAULT_WRAPPER_KEY = "language_model"


 class AdapterLRUCache(LRUCache[int, T]):
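Changing the sentinel from "__default__" to "language_model" makes the default mapping entry coincide with the language-model entry, which is what allows the explicit language-wrapper registration to be deleted from _maybe_init_mm further down. A self-contained toy demonstration (the wrapper object is a stand-in, not a real PunicaWrapperBase):

    DEFAULT_WRAPPER_KEY = "language_model"
    llm_punica_wrapper = object()  # stand-in for a real PunicaWrapperBase
    punica_wrapper_mapping = {DEFAULT_WRAPPER_KEY: llm_punica_wrapper}
    # Lookups by the default key and by the language-model module name now
    # resolve to the same entry, so a single registration covers both:
    assert punica_wrapper_mapping.get("language_model") is llm_punica_wrapper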
@@ -112,15 +112,15 @@ class LoRAModelManager:
     def _init_punica_wrapper(
         self, max_num_batched_tokens: int, vllm_config: VllmConfig
     ) -> None:
-        self.punica_wrapper = get_punica_wrapper(
+        llm_punica_wrapper = get_punica_wrapper(
             max_num_batched_tokens,
             max_batches=self.max_num_seqs,
             device=self.device,
             max_loras=self.lora_config.max_loras,
         )
-
+        # NOTE This assumes the existence of a language model LoRA
         self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {
-            DEFAULT_WRAPPER_KEY: self.punica_wrapper
+            DEFAULT_WRAPPER_KEY: llm_punica_wrapper
         }

         self._maybe_init_mm(vllm_config)
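With the wrapper held only in the mapping, consumers select a wrapper by module-group name. A minimal, hypothetical dispatch sketch; the fall-back-to-default behavior is an assumption for illustration and is not shown anywhere in this diff:

    def wrapper_for(punica_wrapper_mapping, module_group):
        # Hypothetical helper: pick the Punica wrapper for a module group
        # (e.g. a vision tower name), falling back to the language-model
        # wrapper registered under DEFAULT_WRAPPER_KEY ("language_model").
        wrapper = punica_wrapper_mapping.get(module_group)
        if wrapper is None:
            wrapper = punica_wrapper_mapping[DEFAULT_WRAPPER_KEY]
        return wrapper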
@@ -163,8 +163,8 @@
             mm_budget.get_encoder_budget()
         )

-        self.punica_wrapper_mapping = {}
-
+        # Only one language model can be included in the model.
+        assert len(self.mm_mapping.language_model) == 1
         # Tower wrappers
         for name in self.mm_mapping.tower_model:
             self.punica_wrapper_mapping[name] = get_punica_wrapper(
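The new assert encodes the invariant that the multimodal mapping carries exactly one language-model entry. A toy sketch of the shape mm_mapping is assumed to have; the field names are taken from the diff, while the class and its values are invented for illustration:

    from dataclasses import dataclass, field

    @dataclass
    class FakeMMMapping:
        # Stand-in for the real multimodal module mapping used above.
        language_model: list[str] = field(default_factory=lambda: ["language_model"])
        tower_model: list[str] = field(default_factory=lambda: ["vision_tower"])
        connector: list[str] = field(default_factory=list)

    mm_mapping = FakeMMMapping()
    assert len(mm_mapping.language_model) == 1  # the invariant the new assert enforces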
@@ -173,12 +173,6 @@
                 device=self.device,
                 max_loras=self.lora_config.max_loras,
             )
-
-        # Language wrapper
-        self.punica_wrapper_mapping[self.mm_mapping.language_model[0]] = (
-            self.punica_wrapper
-        )
-
         # Use wrapper for connector if present.
         if self.mm_mapping.connector:
             if hasattr(self.info, "get_num_mm_connector_tokens"):