Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-12-17 04:24:25 +00:00
parent 94dce5c3d9
commit 3d39188d38
2 changed files with 18 additions and 19 deletions

View File

@ -18,6 +18,7 @@ from vllm.lora.layers import (
from vllm.lora.lora_model import LoRAModel from vllm.lora.lora_model import LoRAModel
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.model_manager import ( from vllm.lora.model_manager import (
DEFAULT_WRAPPER_KEY,
LoRAMapping, LoRAMapping,
LoRAModelManager, LoRAModelManager,
LRUCacheLoRAModelManager, LRUCacheLoRAModelManager,
@ -183,9 +184,8 @@ def test_lora_model_manager(dist_init, dummy_model, device):
assert manager.activate_adapter(2) assert manager.activate_adapter(2)
assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 2 assert manager.lora_index_to_id[1] == 2
assert manager.device == device assert manager.device == device
assert manager.punica_wrapper.device == device assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
assert hasattr(manager, "supported_lora_modules") assert hasattr(manager, "supported_lora_modules")
assert sorted(manager.supported_lora_modules) == [ assert sorted(manager.supported_lora_modules) == [
"dense1", "dense1",
@ -278,8 +278,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
assert manager.remove_adapter(3) assert manager.remove_adapter(3)
with pytest.raises(ValueError): with pytest.raises(ValueError):
assert manager.pin_adapter(3) assert manager.pin_adapter(3)
assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
assert manager.punica_wrapper.device == device
assert manager.device == device assert manager.device == device
@ -402,7 +401,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
assert manager.remove_oldest_adapter() assert manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == {1} assert set(manager.list_adapters()) == {1}
assert manager.punica_wrapper.device == device assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
assert manager.device == device assert manager.device == device
@ -514,7 +513,10 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
) )
assert worker_adapter_manager.device == device assert worker_adapter_manager.device == device
assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
DEFAULT_WRAPPER_KEY
)
assert punica_wrapper.device == device
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
@ -618,7 +620,10 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
) )
assert worker_adapter_manager.device == device assert worker_adapter_manager.device == device
assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
DEFAULT_WRAPPER_KEY
)
assert punica_wrapper.device == device
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)

View File

@ -42,7 +42,7 @@ from vllm.v1.worker.utils import MultiModalBudget
logger = init_logger(__name__) logger = init_logger(__name__)
T = TypeVar("T") T = TypeVar("T")
DEFAULT_WRAPPER_KEY = "__default__" DEFAULT_WRAPPER_KEY = "language_model"
class AdapterLRUCache(LRUCache[int, T]): class AdapterLRUCache(LRUCache[int, T]):
@ -112,15 +112,15 @@ class LoRAModelManager:
def _init_punica_wrapper( def _init_punica_wrapper(
self, max_num_batched_tokens: int, vllm_config: VllmConfig self, max_num_batched_tokens: int, vllm_config: VllmConfig
) -> None: ) -> None:
self.punica_wrapper = get_punica_wrapper( llm_punica_wrapper = get_punica_wrapper(
max_num_batched_tokens, max_num_batched_tokens,
max_batches=self.max_num_seqs, max_batches=self.max_num_seqs,
device=self.device, device=self.device,
max_loras=self.lora_config.max_loras, max_loras=self.lora_config.max_loras,
) )
# NOTE This assumes the existence of a language model LoRA
self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = { self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {
DEFAULT_WRAPPER_KEY: self.punica_wrapper DEFAULT_WRAPPER_KEY: llm_punica_wrapper
} }
self._maybe_init_mm(vllm_config) self._maybe_init_mm(vllm_config)
@ -163,8 +163,8 @@ class LoRAModelManager:
mm_budget.get_encoder_budget() mm_budget.get_encoder_budget()
) )
self.punica_wrapper_mapping = {} # Only one language model can be included in the model.
assert len(self.mm_mapping.language_model == 1)
# Tower wrappers # Tower wrappers
for name in self.mm_mapping.tower_model: for name in self.mm_mapping.tower_model:
self.punica_wrapper_mapping[name] = get_punica_wrapper( self.punica_wrapper_mapping[name] = get_punica_wrapper(
@ -173,12 +173,6 @@ class LoRAModelManager:
device=self.device, device=self.device,
max_loras=self.lora_config.max_loras, max_loras=self.lora_config.max_loras,
) )
# Language wrapper
self.punica_wrapper_mapping[self.mm_mapping.language_model[0]] = (
self.punica_wrapper
)
# Use wrapper for connector if present. # Use wrapper for connector if present.
if self.mm_mapping.connector: if self.mm_mapping.connector:
if hasattr(self.info, "get_num_mm_connector_tokens"): if hasattr(self.info, "get_num_mm_connector_tokens"):