Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-12-17 04:24:25 +00:00
parent 94dce5c3d9
commit 3d39188d38
2 changed files with 18 additions and 19 deletions

View File

@ -18,6 +18,7 @@ from vllm.lora.layers import (
from vllm.lora.lora_model import LoRAModel
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.model_manager import (
DEFAULT_WRAPPER_KEY,
LoRAMapping,
LoRAModelManager,
LRUCacheLoRAModelManager,
@ -183,9 +184,8 @@ def test_lora_model_manager(dist_init, dummy_model, device):
assert manager.activate_adapter(2)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 2
assert manager.device == device
assert manager.punica_wrapper.device == device
assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
assert hasattr(manager, "supported_lora_modules")
assert sorted(manager.supported_lora_modules) == [
"dense1",
@ -278,8 +278,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
assert manager.remove_adapter(3)
with pytest.raises(ValueError):
assert manager.pin_adapter(3)
assert manager.punica_wrapper.device == device
assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
assert manager.device == device
@ -402,7 +401,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
assert manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == {1}
assert manager.punica_wrapper.device == device
assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
assert manager.device == device
@ -514,7 +513,10 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
)
assert worker_adapter_manager.device == device
assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device
punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
DEFAULT_WRAPPER_KEY
)
assert punica_wrapper.device == device
@pytest.mark.parametrize("device", DEVICES)
@ -618,7 +620,10 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
)
assert worker_adapter_manager.device == device
assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device
punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
DEFAULT_WRAPPER_KEY
)
assert punica_wrapper.device == device
@pytest.mark.parametrize("device", DEVICES)

View File

@ -42,7 +42,7 @@ from vllm.v1.worker.utils import MultiModalBudget
logger = init_logger(__name__)
T = TypeVar("T")
DEFAULT_WRAPPER_KEY = "__default__"
DEFAULT_WRAPPER_KEY = "language_model"
class AdapterLRUCache(LRUCache[int, T]):
@ -112,15 +112,15 @@ class LoRAModelManager:
def _init_punica_wrapper(
self, max_num_batched_tokens: int, vllm_config: VllmConfig
) -> None:
self.punica_wrapper = get_punica_wrapper(
llm_punica_wrapper = get_punica_wrapper(
max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device,
max_loras=self.lora_config.max_loras,
)
# NOTE This assumes the existence of a language model LoRA
self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {
DEFAULT_WRAPPER_KEY: self.punica_wrapper
DEFAULT_WRAPPER_KEY: llm_punica_wrapper
}
self._maybe_init_mm(vllm_config)
@ -163,8 +163,8 @@ class LoRAModelManager:
mm_budget.get_encoder_budget()
)
self.punica_wrapper_mapping = {}
# Only one language model can be included in the model.
assert len(self.mm_mapping.language_model == 1)
# Tower wrappers
for name in self.mm_mapping.tower_model:
self.punica_wrapper_mapping[name] = get_punica_wrapper(
@ -173,12 +173,6 @@ class LoRAModelManager:
device=self.device,
max_loras=self.lora_config.max_loras,
)
# Language wrapper
self.punica_wrapper_mapping[self.mm_mapping.language_model[0]] = (
self.punica_wrapper
)
# Use wrapper for connector if present.
if self.mm_mapping.connector:
if hasattr(self.info, "get_num_mm_connector_tokens"):