mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 18:07:56 +08:00
Fix: look up the language-model punica wrapper via punica_wrapper_mapping[DEFAULT_WRAPPER_KEY]
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
94dce5c3d9
commit
3d39188d38
@ -18,6 +18,7 @@ from vllm.lora.layers import (
|
|||||||
from vllm.lora.lora_model import LoRAModel
|
from vllm.lora.lora_model import LoRAModel
|
||||||
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
|
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
|
||||||
from vllm.lora.model_manager import (
|
from vllm.lora.model_manager import (
|
||||||
|
DEFAULT_WRAPPER_KEY,
|
||||||
LoRAMapping,
|
LoRAMapping,
|
||||||
LoRAModelManager,
|
LoRAModelManager,
|
||||||
LRUCacheLoRAModelManager,
|
LRUCacheLoRAModelManager,
|
||||||
@ -183,9 +184,8 @@ def test_lora_model_manager(dist_init, dummy_model, device):
|
|||||||
assert manager.activate_adapter(2)
|
assert manager.activate_adapter(2)
|
||||||
assert manager.lora_index_to_id[0] == 3
|
assert manager.lora_index_to_id[0] == 3
|
||||||
assert manager.lora_index_to_id[1] == 2
|
assert manager.lora_index_to_id[1] == 2
|
||||||
|
|
||||||
assert manager.device == device
|
assert manager.device == device
|
||||||
assert manager.punica_wrapper.device == device
|
assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
|
||||||
assert hasattr(manager, "supported_lora_modules")
|
assert hasattr(manager, "supported_lora_modules")
|
||||||
assert sorted(manager.supported_lora_modules) == [
|
assert sorted(manager.supported_lora_modules) == [
|
||||||
"dense1",
|
"dense1",
|
||||||
@ -278,8 +278,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
|
|||||||
assert manager.remove_adapter(3)
|
assert manager.remove_adapter(3)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
assert manager.pin_adapter(3)
|
assert manager.pin_adapter(3)
|
||||||
|
assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
|
||||||
assert manager.punica_wrapper.device == device
|
|
||||||
assert manager.device == device
|
assert manager.device == device
|
||||||
|
|
||||||
|
|
||||||
@ -402,7 +401,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
|
|||||||
assert manager.remove_oldest_adapter()
|
assert manager.remove_oldest_adapter()
|
||||||
|
|
||||||
assert set(manager.list_adapters()) == {1}
|
assert set(manager.list_adapters()) == {1}
|
||||||
assert manager.punica_wrapper.device == device
|
assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
|
||||||
assert manager.device == device
|
assert manager.device == device
|
||||||
|
|
||||||
|
|
||||||
@ -514,7 +513,10 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert worker_adapter_manager.device == device
|
assert worker_adapter_manager.device == device
|
||||||
assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device
|
punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
|
||||||
|
DEFAULT_WRAPPER_KEY
|
||||||
|
)
|
||||||
|
assert punica_wrapper.device == device
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", DEVICES)
|
@pytest.mark.parametrize("device", DEVICES)
|
||||||
@ -618,7 +620,10 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert worker_adapter_manager.device == device
|
assert worker_adapter_manager.device == device
|
||||||
assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device
|
punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
|
||||||
|
DEFAULT_WRAPPER_KEY
|
||||||
|
)
|
||||||
|
assert punica_wrapper.device == device
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", DEVICES)
|
@pytest.mark.parametrize("device", DEVICES)
|
||||||
|
|||||||
@ -42,7 +42,7 @@ from vllm.v1.worker.utils import MultiModalBudget
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
DEFAULT_WRAPPER_KEY = "__default__"
|
DEFAULT_WRAPPER_KEY = "language_model"
|
||||||
|
|
||||||
|
|
||||||
class AdapterLRUCache(LRUCache[int, T]):
|
class AdapterLRUCache(LRUCache[int, T]):
|
||||||
@ -112,15 +112,15 @@ class LoRAModelManager:
|
|||||||
def _init_punica_wrapper(
|
def _init_punica_wrapper(
|
||||||
self, max_num_batched_tokens: int, vllm_config: VllmConfig
|
self, max_num_batched_tokens: int, vllm_config: VllmConfig
|
||||||
) -> None:
|
) -> None:
|
||||||
self.punica_wrapper = get_punica_wrapper(
|
llm_punica_wrapper = get_punica_wrapper(
|
||||||
max_num_batched_tokens,
|
max_num_batched_tokens,
|
||||||
max_batches=self.max_num_seqs,
|
max_batches=self.max_num_seqs,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
max_loras=self.lora_config.max_loras,
|
max_loras=self.lora_config.max_loras,
|
||||||
)
|
)
|
||||||
|
# NOTE: This assumes that a language-model LoRA exists.
|
||||||
self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {
|
self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {
|
||||||
DEFAULT_WRAPPER_KEY: self.punica_wrapper
|
DEFAULT_WRAPPER_KEY: llm_punica_wrapper
|
||||||
}
|
}
|
||||||
|
|
||||||
self._maybe_init_mm(vllm_config)
|
self._maybe_init_mm(vllm_config)
|
||||||
@ -163,8 +163,8 @@ class LoRAModelManager:
|
|||||||
mm_budget.get_encoder_budget()
|
mm_budget.get_encoder_budget()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.punica_wrapper_mapping = {}
|
# Only one language model can be included in the model.
|
||||||
|
assert len(self.mm_mapping.language_model) == 1
|
||||||
# Tower wrappers
|
# Tower wrappers
|
||||||
for name in self.mm_mapping.tower_model:
|
for name in self.mm_mapping.tower_model:
|
||||||
self.punica_wrapper_mapping[name] = get_punica_wrapper(
|
self.punica_wrapper_mapping[name] = get_punica_wrapper(
|
||||||
@ -173,12 +173,6 @@ class LoRAModelManager:
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
max_loras=self.lora_config.max_loras,
|
max_loras=self.lora_config.max_loras,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Language wrapper
|
|
||||||
self.punica_wrapper_mapping[self.mm_mapping.language_model[0]] = (
|
|
||||||
self.punica_wrapper
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use wrapper for connector if present.
|
# Use wrapper for connector if present.
|
||||||
if self.mm_mapping.connector:
|
if self.mm_mapping.connector:
|
||||||
if hasattr(self.info, "get_num_mm_connector_tokens"):
|
if hasattr(self.info, "get_num_mm_connector_tokens"):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user