Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-03-19 21:27:31 +08:00)

Merge 57d7267fee3c21f36547f3a6cff4675552879ae2 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
@@ -275,6 +275,10 @@ The new format of `--lora-modules` is mainly to support the display of parent mo
}
```

## LoRA Support for the Tower and Connector of Multi-Modal Models

Currently, vLLM experimentally supports LoRA for the tower and connector components of multi-modal models. To enable this feature, you need to implement the corresponding token helper functions for the tower and connector. For more details on the rationale behind this approach, please refer to [PR 26674](https://github.com/vllm-project/vllm/pull/26674). We welcome contributions that extend LoRA support to the towers and connectors of additional models.
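
As a rough sketch of what these helper functions involve, the following mirrors the Qwen2-VL-style implementations added in this change; the class name is illustrative, and the `spatial_merge_size` attribute is an assumption that holds for the Qwen VL family:

```python
import torch.nn as nn

from vllm.model_executor.models.interfaces import SupportsMultiModal


class MyVLForConditionalGeneration(nn.Module, SupportsMultiModal):
    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        # Each placeholder token in the prompt expands to
        # spatial_merge_size**2 patch tokens inside the vision tower.
        merge_size = self.config.vision_config.spatial_merge_size
        return num_image_tokens * merge_size**2

    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
        # The connector merges the patch tokens back down by the same factor.
        merge_size = self.config.vision_config.spatial_merge_size
        return num_vision_tokens // merge_size**2
```
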
## Default LoRA Models For Multimodal Models

Some multimodal models, e.g., [Granite Speech](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) and [Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct), contain LoRA adapter(s) that are expected to always be applied when a given modality is present. This can be tedious to manage with the approaches above, since the user must either send the `LoRARequest` (offline) or route requests between the base model and the LoRA model (server) depending on the content of each request's multimodal data.
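
For illustration, the manual offline pattern described above looks roughly like this; the prompts and adapter path are placeholders, and `LoRARequest` takes a name, an integer ID, and the adapter path:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="microsoft/Phi-4-multimodal-instruct", enable_lora=True, trust_remote_code=True)

# The caller has to remember to attach the bundled adapter whenever the
# request actually contains the corresponding modality.
outputs = llm.generate(
    prompts_with_images,  # placeholder: multimodal prompts built elsewhere
    SamplingParams(temperature=0, max_tokens=64),
    lora_request=LoRARequest("vision", 1, "/path/to/vision-lora"),
)
```
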
@@ -208,6 +208,31 @@ def qwen25vl_lora_files():
    return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon")


@pytest.fixture(scope="session")
def qwen2vl_language_lora_files():
    return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-language")


@pytest.fixture(scope="session")
def qwen2vl_vision_tower_connector_lora_files():
    return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-tower-connector")


@pytest.fixture(scope="session")
def qwen2vl_vision_tower_lora_files():
    return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-tower")


@pytest.fixture(scope="session")
def qwen25vl_vision_lora_files():
    return snapshot_download(repo_id="prashanth058/qwen2.5-3b-vl-flickr-lora-vision")


@pytest.fixture(scope="session")
def qwen3vl_vision_lora_files():
    return snapshot_download(repo_id="prashanth058/qwen3-4b-vl-lora-vision-connector")


@pytest.fixture(scope="session")
def tinyllama_lora_files():
    return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
@@ -18,6 +18,7 @@ from vllm.lora.layers import (
from vllm.lora.lora_model import LoRAModel
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.model_manager import (
    DEFAULT_LANGUAGE_WRAPPER_KEY,
    LoRAMapping,
    LoRAModelManager,
    LRUCacheLoRAModelManager,
@@ -183,9 +184,11 @@ def test_lora_model_manager(dist_init, dummy_model, device):
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    assert manager.device == device
    assert manager.punica_wrapper.device == device
    assert (
        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
        == device
    )
    assert hasattr(manager, "supported_lora_modules")
    assert sorted(manager.supported_lora_modules) == [
        "dense1",
@@ -278,8 +281,10 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        assert manager.pin_adapter(3)

    assert manager.punica_wrapper.device == device
    assert (
        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
        == device
    )
    assert manager.device == device


@@ -402,7 +407,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
    assert manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}
    assert manager.punica_wrapper.device == device
    assert (
        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
        == device
    )
    assert manager.device == device


@@ -514,7 +522,10 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
    )

    assert worker_adapter_manager.device == device
    assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device
    punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
        DEFAULT_LANGUAGE_WRAPPER_KEY
    )
    assert punica_wrapper.device == device


@pytest.mark.parametrize("device", DEVICES)
@@ -618,7 +629,10 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
    )

    assert worker_adapter_manager.device == device
    assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device
    punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
        DEFAULT_LANGUAGE_WRAPPER_KEY
    )
    assert punica_wrapper.device == device


@pytest.mark.parametrize("device", DEVICES)
@@ -14,9 +14,12 @@ class TestConfig:
    lora_path: str
    max_num_seqs: int = 2
    max_loras: int = 2
    max_lora_rank: int = 16
    max_model_len: int = 4096
    max_lora_rank: int = 32
    enable_tower_connector_lora: bool = False
    max_model_len: int = 8192
    gpu_memory_utilization: float = 0.85
    mm_processor_kwargs: dict[str, int] | None = None
    mm_processor_cache_gb: float = 4

    def __post_init__(self):
        if self.mm_processor_kwargs is None:
@@ -48,8 +51,11 @@ class Qwen2VLTester:
            enable_lora=True,
            max_loras=self.config.max_loras,
            max_lora_rank=self.config.max_lora_rank,
            enable_tower_connector_lora=self.config.enable_tower_connector_lora,
            trust_remote_code=True,
            gpu_memory_utilization=self.config.gpu_memory_utilization,
            mm_processor_kwargs=self.config.mm_processor_kwargs,
            mm_processor_cache_gb=self.config.mm_processor_cache_gb,
            max_model_len=self.config.max_model_len,
        )

@@ -58,6 +64,7 @@ class Qwen2VLTester:
        images: list[ImageAsset],
        expected_outputs: list[str],
        lora_id: int | None = None,
        lora_name: str | None = None,
        temperature: float = 0,
        max_tokens: int = 5,
    ):
@@ -73,10 +80,11 @@ class Qwen2VLTester:
            for asset in images
        ]

        lora_request = LoRARequest(str(lora_id), lora_id, self.config.lora_path)
        lora_request = LoRARequest(
            lora_name if lora_name else str(lora_id), lora_id, self.config.lora_path
        )
        outputs = self.llm.generate(inputs, sampling_params, lora_request=lora_request)
        generated_texts = [output.outputs[0].text.strip() for output in outputs]

        # Validate outputs
        for generated, expected in zip(generated_texts, expected_outputs):
            assert expected.startswith(generated), (
@@ -127,6 +135,32 @@ EXPECTED_OUTPUTS = [
    "A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.",  # noqa: E501
]

EXPECTED_OUTPUTS_LANGUAGE = [
    "A stop sign is shown in an Asian city, with buildings and a car in the "
    "background.",
    "The Tokyo Skytree can be seen behind the pink blossoms of the cherry trees.",
]

EXPECTED_OUTPUTS_VISION = [
    "A stop sign in front of oriental buildings.",
    "A tree with pink flowers in front of it and a blue sky behind the flowers.",
]

EXPECTED_OUTPUTS_VISION_NO_CONNECTOR = [
    "A stop sign is located on the street of a Chinese neighborhood.",
    "A closeup shot of the Tokyo Skytree with pink flowers in the foreground.",
]

EXPECTED_OUTPUTS_VISION_QWEN2_5_VL = [
    "A black car is driving past a stop sign and a large red and gold arch.",
    "A view of the Tokyo Skytree through the branches of a cherry blossom tree.",
]

EXPECTED_OUTPUTS_VISION_QWEN3_VL = [
    "A black SUV drives past a stop sign in front of a Chinese gate.",
    "A tall white tower is seen through pink flowers.",
]

# NOTE - beam search .text contains the whole text
EXPECTED_BEAM_SEARCH_OUTPUTS = [
    [
@@ -137,6 +171,7 @@ EXPECTED_BEAM_SEARCH_OUTPUTS = [

QWEN2VL_MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"
QWEN3VL_MODEL_PATH = "Qwen/Qwen3-VL-4B-Instruct"


def test_qwen2vl_lora(qwen2vl_lora_files):
@@ -175,3 +210,99 @@ def test_qwen25vl_lora(qwen25vl_lora_files):
    # Test with different LoRA IDs
    for lora_id in [1, 2]:
        tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id)


def test_qwen25vl_vision_lora(qwen25vl_vision_lora_files):
    config = TestConfig(
        model_path=QWEN25VL_MODEL_PATH,
        lora_path=qwen25vl_vision_lora_files,
        # Currently, tower_connector_lora is incompatible with
        # the multi-modal processor cache.
        # TODO: Remove this restriction
        mm_processor_cache_gb=0,
        enable_tower_connector_lora=True,
    )
    tester = Qwen2VLTester(config)
    for lora_id in [1, 2]:
        tester.run_test(
            TEST_IMAGES,
            expected_outputs=EXPECTED_OUTPUTS_VISION_QWEN2_5_VL,
            lora_id=lora_id,
        )


def test_qwen3vl_vision_lora(qwen3vl_vision_lora_files):
    config = TestConfig(
        model_path=QWEN3VL_MODEL_PATH,
        lora_path=qwen3vl_vision_lora_files,
        # Currently, tower_connector_lora is incompatible with
        # the multi-modal processor cache.
        # TODO: Remove this restriction
        mm_processor_cache_gb=0,
        enable_tower_connector_lora=True,
    )
    tester = Qwen2VLTester(config)
    for lora_id in [1, 2]:
        tester.run_test(
            TEST_IMAGES,
            expected_outputs=EXPECTED_OUTPUTS_VISION_QWEN3_VL,
            lora_id=lora_id,
        )


def test_qwen2vl_multiple_lora_types(
    qwen2vl_language_lora_files,
    qwen2vl_vision_tower_connector_lora_files,
    qwen2vl_vision_tower_lora_files,
):
    """
    Test multiple LoRA adapter types (language, vision tower + connector,
    vision tower only) using the same LLM instance to verify mm_encoder_cache
    behavior with different LoRA requests.

    By reusing the same LLM instance across different LoRA requests, we ensure that
    the multimodal encoder cache correctly manages state transitions between
    language-only and vision-enabled LoRA adapters.
    """
    config = TestConfig(
        model_path=QWEN2VL_MODEL_PATH,
        # We'll override the lora_path for each specific test, but need to provide
        # an initial path for initialization
        lora_path=qwen2vl_language_lora_files,
        # Currently, tower_connector_lora is incompatible with
        # the multi-modal processor cache.
        # TODO: Remove this restriction
        mm_processor_cache_gb=0,
        enable_tower_connector_lora=True,
    )
    tester = Qwen2VLTester(config)

    # Test 1: Language-only LoRA adapter
    tester.config.lora_path = qwen2vl_language_lora_files
    for lora_id in [1, 2]:
        tester.run_test(
            TEST_IMAGES,
            expected_outputs=EXPECTED_OUTPUTS_LANGUAGE,
            lora_id=lora_id,
            lora_name="language_only",
        )

    # Test 2: Vision tower + connector LoRA adapter
    tester.config.lora_path = qwen2vl_vision_tower_connector_lora_files
    for lora_id in [3, 4]:
        tester.run_test(
            TEST_IMAGES,
            expected_outputs=EXPECTED_OUTPUTS_VISION,
            lora_id=lora_id,
            lora_name="vision_tower_connector",
        )

    # Test 3: Vision tower only LoRA adapter (no connector)
    tester.config.lora_path = qwen2vl_vision_tower_lora_files
    for lora_id in [5, 6]:
        tester.run_test(
            TEST_IMAGES,
            expected_outputs=EXPECTED_OUTPUTS_VISION_NO_CONNECTOR,
            lora_id=lora_id,
            lora_name="vision_tower",
        )
@@ -55,6 +55,11 @@ class LoRAConfig:
    per prompt. When run in offline mode, the lora IDs for n modalities
    will be automatically assigned to 1-n with the names of the modalities
    in alphabetic order."""
    enable_tower_connector_lora: bool = False
    """If `True`, enables LoRA support for the tower (vision encoder) and
    connector of multimodal models. This is an experimental feature that
    currently supports only some multimodal models, such as the Qwen VL
    series. Defaults to `False`."""

    def compute_hash(self) -> str:
        """
@@ -73,6 +78,7 @@ class LoRAConfig:
        factors.append(self.max_loras)
        factors.append(self.fully_sharded_loras)
        factors.append(self.lora_dtype)
        factors.append(self.enable_tower_connector_lora)

        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str
@@ -483,6 +483,7 @@ class EngineArgs:
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
    max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
    lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
    enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora

    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
    num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
@@ -996,6 +997,10 @@ class EngineArgs:
            "--lora-dtype",
            **lora_kwargs["lora_dtype"],
        )
        lora_group.add_argument(
            "--enable-tower-connector-lora",
            **lora_kwargs["enable_tower_connector_lora"],
        )
        lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument(
            "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
        )
@@ -1631,6 +1636,7 @@ class EngineArgs:
                default_mm_loras=self.default_mm_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_dtype=self.lora_dtype,
                enable_tower_connector_lora=self.enable_tower_connector_lora,
                max_cpu_loras=self.max_cpu_loras
                if self.max_cpu_loras and self.max_cpu_loras > 0
                else None,
@@ -1639,6 +1645,19 @@ class EngineArgs:
            else None
        )

        if (
            lora_config is not None
            and lora_config.enable_tower_connector_lora
            and self.mm_processor_cache_gb != 0
        ):
            raise ValueError(
                "Currently, enable_tower_connector_lora is "
                "incompatible with the multi-modal processor cache. "
                "When enable_tower_connector_lora is set, "
                f"mm_processor_cache_gb must be 0, got {self.mm_processor_cache_gb}."
            )

        if (
            lora_config is not None
            and speculative_config is not None
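
With the argument plumbed through as above, enabling the feature offline looks roughly like the following sketch; note that the check above forces the processor cache off:

```python
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    enable_lora=True,
    enable_tower_connector_lora=True,
    # Required while tower/connector LoRA is incompatible with the
    # multi-modal processor cache (see the ValueError above).
    mm_processor_cache_gb=0,
)
```
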
@@ -18,7 +18,7 @@ from vllm.lora.layers.row_parallel_linear import (
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.utils import LoRAMapping
from vllm.lora.layers.utils import LoRAMapping, LoRAMappingType
from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA

__all__ = [
@@ -37,6 +37,7 @@ __all__ = [
    "RowParallelLinearWithShardedLoRA",
    "ReplicatedLinearWithLoRA",
    "LoRAMapping",
    "LoRAMappingType",
    "FusedMoEWithLoRA",
    "FusedMoE3DWithLoRA",
]
@@ -122,7 +122,9 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        # In Transformers modeling backend, x and output have extra batch dimension like
        original_shape = output.shape if output.ndim == 3 else None

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        if x.ndim == 3 and output.ndim == 3:
@@ -135,6 +137,11 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
        if not current_platform.can_update_inplace():
            output = lora_output

        # Reshape the flattened output back to its original shape,
        # as some MM encoders cannot handle flattened inputs.
        if original_shape is not None:
            output = output.reshape(original_shape)

        return output

    @property
@@ -2,16 +2,24 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from enum import Enum

import torch
import torch.nn as nn


class LoRAMappingType(Enum):
    LANGUAGE = 1
    TOWER = 2
    CONNECTOR = 3


@dataclass
class LoRAMapping:
    index_mapping: tuple[int, ...]
    prompt_mapping: tuple[int, ...]
    is_prefill: bool = False
    type: LoRAMappingType = LoRAMappingType.LANGUAGE

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
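
For a feel of how the new `type` field is used downstream (the values here are arbitrary), a tower-side mapping tags each encoder token with the LoRA ID of the request that owns it:

```python
from vllm.lora.layers import LoRAMapping, LoRAMappingType

# One multimodal item owned by LoRA id 1 that expands to four encoder tokens.
tower_mapping = LoRAMapping(
    index_mapping=(1, 1, 1, 1),
    prompt_mapping=(1,),
    is_prefill=True,
    type=LoRAMappingType.TOWER,
)
```
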
@@ -9,12 +9,18 @@ import regex as re
import torch
from torch import nn

from vllm.config.lora import LoRAConfig
from vllm.config import VllmConfig
from vllm.config.lora import LoRAConfig, ModelConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoE3DWithLoRA, LoRAMapping
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    FusedMoE3DWithLoRA,
    LoRAMapping,
    LoRAMappingType,
)
from vllm.lora.lora_model import LoRAModel
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.punica_wrapper import PunicaWrapperBase, get_punica_wrapper
from vllm.lora.utils import (
    from_layer,
    from_layer_logits_processor,
@@ -28,12 +34,15 @@ from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils.cache import LRUCache
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.worker.utils import MultiModalBudget

logger = init_logger(__name__)

T = TypeVar("T")
DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"


class AdapterLRUCache(LRUCache[int, T]):
@@ -58,6 +67,7 @@ class LoRAModelManager:
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        vllm_config: VllmConfig | None = None,
    ):
        """Create a LoRAModelManager and adapter for a given model.

@@ -71,6 +81,11 @@ class LoRAModelManager:
            lora_config: the LoRA configuration.
        """
        self.model: SupportsLoRA = model
        self.supported_lora_modules = get_supported_lora_modules(self.model)
        assert self.supported_lora_modules, (
            f"No supported LoRA modules found in {self.model.__class__.__name__}."
        )

        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
@@ -82,18 +97,22 @@ class LoRAModelManager:
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras,
        )

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        assert self.supported_lora_modules, "No supported LoRA modules found in"
        f" {self.model.__class__.__name__}."

        self.packed_modules_mapping = process_packed_modules_mapping(self.model)

        self.is_pooling_model = is_pooling_model(self.model)
        self.packed_modules: dict[str, list[str]] = {}
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._last_mapping: LoRAMapping | None = None
        self._is_3d_moe_model = is_moe_model(self.model) and self.model.is_3d_moe_weight
        self._init_punica_wrapper(max_num_batched_tokens, vllm_config)
        self._create_lora_modules()

        self.model.lora_manager = self

    def _init_punica_wrapper(
        self, max_num_batched_tokens: int, vllm_config: VllmConfig
    ) -> None:
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
@@ -101,15 +120,97 @@ class LoRAModelManager:
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping")
        )
        self.is_pooling_model = is_pooling_model(self.model)
        self.packed_modules: dict[str, list[str]] = {}
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._last_mapping: LoRAMapping | None = None
        self._is_3d_moe_model = is_moe_model(self.model) and self.model.is_3d_moe_weight
        self._create_lora_modules()
        self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {}
        if self.supports_mm:
            self._maybe_init_mm(vllm_config, max_num_batched_tokens)
        else:
            llm_punica_wrapper = get_punica_wrapper(
                max_num_batched_tokens,
                max_batches=self.max_num_seqs,
                device=self.device,
                max_loras=self.lora_config.max_loras,
            )

        self.model.lora_manager = self
            self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY] = (
                llm_punica_wrapper
            )

    def _maybe_init_mm(self, vllm_config: VllmConfig, max_num_batched_tokens) -> None:
        self.supports_tower_connector_lora = False
        model_config: ModelConfig = vllm_config.model_config
        self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()

        # Only one language model can be included in the model.
        assert len(self.mm_mapping.language_model) == 1

        # Language model punica wrapper
        llm_punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras,
        )
        lm_prefix = self.mm_mapping.language_model[0]
        self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper

        if self.lora_config.enable_tower_connector_lora:
            self.mm_processor_info = MULTIMODAL_REGISTRY.create_processor(
                model_config
            ).info
            self.supports_tower_connector_lora = self.supports_mm and hasattr(
                self.model, "get_num_mm_encoder_tokens"
            )
        if not self.supports_tower_connector_lora:
            return

        logger.warning(
            "LoRA for the tower and connector of multimodal models is "
            "experimental and may contain bugs. Please report any related issues on "
            "GitHub if you encounter them."
        )

        mm_budget = MultiModalBudget(
            model_config,
            vllm_config.scheduler_config,
            MULTIMODAL_REGISTRY,
        )
        limit_per_prompt: int = max(
            self.mm_processor_info.get_allowed_mm_limits().values()
        )
        num_encoder_tokens = self.model.get_num_mm_encoder_tokens(
            mm_budget.get_encoder_budget()
        )

        # Tower wrappers
        tower_punica_wrapper = get_punica_wrapper(
            num_encoder_tokens,
            max_batches=self.max_num_seqs * limit_per_prompt,
            device=self.device,
            max_loras=self.lora_config.max_loras,
        )
        for prefix in self.mm_mapping.tower_model:
            self.punica_wrapper_mapping[prefix] = tower_punica_wrapper

        # Use wrapper for connector if present.
        if self.mm_mapping.connector:
            if hasattr(self.model, "get_num_mm_connector_tokens"):
                connector_tokens = self.model.get_num_mm_connector_tokens(
                    num_encoder_tokens
                )
                connector_punica_wrapper = get_punica_wrapper(
                    connector_tokens,
                    max_batches=self.max_num_seqs * limit_per_prompt,
                    device=self.device,
                    max_loras=self.lora_config.max_loras,
                )
                for prefix in self.mm_mapping.connector:
                    self.punica_wrapper_mapping[prefix] = connector_punica_wrapper
            else:
                logger.warning_once(
                    "Connector LoRA support disabled: model does not implement "
                    "get_num_mm_connector_tokens(). This method is required to "
                    "determine the connector's token budget for LoRA operations."
                )

    def __len__(self) -> int:
        return len(self._registered_adapters)
@@ -237,8 +338,24 @@ class LoRAModelManager:
        )  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        # update lora states
        self.punica_wrapper.update_metadata(
        # Default to the main language model wrapper
        if not (self.supports_mm and self.supports_tower_connector_lora):
            target_prefix = (
                self.mm_mapping.language_model[0]
                if self.supports_mm
                else DEFAULT_LANGUAGE_WRAPPER_KEY
            )
        elif mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
            target_prefix = self.mm_mapping.tower_model[0]
        elif mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector:
            target_prefix = self.mm_mapping.connector[0]
        else:
            target_prefix = self.mm_mapping.language_model[0]

        punica_wrapper = self._get_punica_wrapper(target_prefix)
        assert punica_wrapper is not None

        punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
@@ -265,15 +382,17 @@ class LoRAModelManager:

            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):

            punica_wrapper = self._get_punica_wrapper(module_name)
            if punica_wrapper is None:
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    "Regarding %s, vLLM currently only supports adding LoRA to"
                    " language model, %s will be ignored.",
                    self.model.__class__.__name__,
                    module_name,
                )
                continue

            parts = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
            if isinstance(module, FusedMoE):
@@ -328,10 +447,10 @@ class LoRAModelManager:
            if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)

            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)
            pass
            new_module.set_mapping(punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA), (
@@ -352,7 +471,7 @@ class LoRAModelManager:
            if (
                not self._match_target_modules(module_name)
                or not isinstance(module, BaseLayerWithLoRA)
                or self._filter_unsupported_mm_module(module_name)
                or self._get_punica_wrapper(module_name) is None
            ):
                continue
            parts = module_name.split(".")
@@ -441,17 +560,22 @@ class LoRAModelManager:
            for target_module in self.supported_lora_modules
        )

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
    def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        Determine whether this module supports LoRA and which wrapper to use.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any([module_name.startswith(prefix) for prefix in prefix_lst])
        return False
        # For language model (early return)
        if not self.supports_mm:
            return self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY]

        # For multimodal model
        # NOTE Sort by prefix length (descending) to match the longest prefix first,
        # e.g., 'visual.merger' should match 'visual.merger' instead of 'visual.'
        for prefix in sorted(self.punica_wrapper_mapping.keys(), key=len, reverse=True):
            if module_name.startswith(prefix):
                return self.punica_wrapper_mapping[prefix]

        return None

    def _register_packed_modules(self, module_full_name: str) -> None:
        parts = module_full_name.split(".")
@@ -596,9 +720,16 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        vllm_config: VllmConfig | None = None,
    ):
        super().__init__(
            model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device
            model,
            max_num_seqs,
            max_num_batched_tokens,
            vocab_size,
            lora_config,
            device,
            vllm_config,
        )
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter
@@ -671,6 +802,7 @@ def create_lora_manager(
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    vllm_config: VllmConfig,
    device: torch.device,
    lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
    **kwargs,
@@ -684,6 +816,7 @@ def create_lora_manager(
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        vllm_config=vllm_config,
        device=device,
        **kwargs,
    )
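
A minimal standalone illustration of the longest-prefix rule in `_get_punica_wrapper`, with prefixes borrowed from the Qwen VL mappings elsewhere in this change:

```python
wrapper_keys = ["language_model", "visual.", "visual.merger"]


def resolve(module_name: str) -> str | None:
    # Longest prefix first, as in _get_punica_wrapper above.
    for prefix in sorted(wrapper_keys, key=len, reverse=True):
        if module_name.startswith(prefix):
            return prefix
    return None


assert resolve("visual.merger.mlp.0") == "visual.merger"  # connector wrapper
assert resolve("visual.blocks.0.attn.qkv") == "visual."  # tower wrapper
assert resolve("language_model.layers.0.self_attn.q_proj") == "language_model"
```
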
@@ -69,6 +69,7 @@ class WorkerLoRAManager:
    def create_lora_manager(
        self,
        model: torch.nn.Module,
        vllm_config: VllmConfig | None = None,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
@@ -78,6 +79,7 @@ class WorkerLoRAManager:
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
            vllm_config=vllm_config,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model
@@ -161,6 +163,12 @@ class WorkerLoRAManager:
        if mapping is not None:
            self._adapter_manager.set_adapter_mapping(mapping)

    def supports_tower_connector_lora(self) -> bool:
        return (
            self._adapter_manager.supports_mm
            and self._adapter_manager.supports_tower_connector_lora
        )

    def _apply_adapters(self, adapter_requests: set[Any]) -> None:
        existing_adapters = self.list_adapters()
        models_map = {
@@ -210,6 +218,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    def create_lora_manager(
        self,
        model: torch.nn.Module,
        vllm_config: VllmConfig | None = None,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
@@ -219,6 +228,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
            lora_config=self.lora_config,
            device=self.device,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vllm_config=vllm_config,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model
@@ -714,3 +714,21 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo
            connector="model.connector",
            tower_model="model.vision_model",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        hf_config = self.config
        scale_factor = hf_config.scale_factor

        return num_image_tokens * scale_factor**2

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        hf_config = self.config
        scale_factor = hf_config.scale_factor

        return num_vision_tokens // scale_factor**2
@@ -136,6 +136,24 @@ class SupportsMultiModal(Protocol):
        """
        ...

    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        """
        Implement this function to enable LoRA support
        for the tower module of the multi-modal model.
        Given the number of image tokens, output the number of
        multi-modal encoder tokens.
        """
        ...

    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
        """
        Implement this function to enable LoRA support
        for the connector module of the multi-modal model.
        Given the number of vision tokens, output the number of
        multi-modal connector tokens.
        """
        ...

    @classmethod
    def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
        """
@@ -1026,6 +1026,7 @@ class Qwen2_5_VLForConditionalGeneration(
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "qkv": ["qkv"],  # For vision tower's already-packed QKV
    }

    # To ensure correct weight loading and mapping.
@@ -1568,6 +1569,25 @@ class Qwen2_5_VLForConditionalGeneration(
            tower_model="visual.",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        hf_config = self.config
        vision_config = hf_config.vision_config
        merge_size = vision_config.spatial_merge_size

        return num_image_tokens * merge_size**2

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        hf_config = self.config
        vision_config = hf_config.vision_config
        merge_size = vision_config.spatial_merge_size
        return num_vision_tokens // merge_size**2

    @classmethod
    def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
        """
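
To make the token accounting concrete, a quick standalone check assuming `spatial_merge_size = 2`, the usual value for the Qwen VL series:

```python
merge_size = 2
num_image_tokens = 256  # placeholder tokens in the language prompt

num_encoder_tokens = num_image_tokens * merge_size**2  # 1024 tower tokens
num_connector_tokens = num_encoder_tokens // merge_size**2  # merged back to 256

assert num_connector_tokens == num_image_tokens
```
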
@@ -1491,6 +1491,25 @@ class Qwen2VLForConditionalGeneration(
            tower_model="visual.",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        hf_config = self.config
        vision_config = hf_config.vision_config
        merge_size = vision_config.spatial_merge_size

        return num_image_tokens * merge_size**2

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        hf_config = self.config
        vision_config = hf_config.vision_config
        merge_size = vision_config.spatial_merge_size
        return num_vision_tokens // merge_size**2


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    pass
@@ -1240,6 +1240,7 @@ class Qwen3VLForConditionalGeneration(
            "gate_proj",
            "up_proj",
        ],
        "qkv": ["qkv"],  # For vision tower's already-packed QKV
    }

    supports_encoder_tp_data = True
@@ -2087,10 +2088,29 @@ class Qwen3VLForConditionalGeneration(
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger",
            connector=["visual.merger", "visual.deepstack_merger_list"],
            tower_model="visual.",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        hf_config = self.config
        vision_config = hf_config.vision_config
        merge_size = vision_config.spatial_merge_size

        return num_image_tokens * merge_size**2

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        hf_config = self.config
        vision_config = hf_config.vision_config
        merge_size = vision_config.spatial_merge_size
        return num_vision_tokens // merge_size**2

    @classmethod
    def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
        """
@@ -406,6 +406,24 @@ class InputProcessor:
            mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
        return mm_uuids

    def _get_mm_identifier(
        self,
        mm_hash: str,
        lora_request: LoRARequest | None,
    ) -> str:
        """
        When enable_tower_connector_lora is True, multi-modal embeddings
        vary depending on the LoRA request. Therefore, the mm_hash must be
        generated based on the LoRA request to prevent incorrect cache hits.
        """
        if (
            lora_request is None
            or self.lora_config is None
            or not self.lora_config.enable_tower_connector_lora
        ):
            return mm_hash
        return f"{lora_request.lora_name}:{mm_hash}"

    @staticmethod
    def assign_request_id(request: EngineCoreRequest):
        """Replace the externally supplied request ID with an internal request ID
@@ -539,7 +557,10 @@ class InputProcessor:
                MultiModalFeatureSpec(
                    data=decoder_mm_inputs[modality][idx],
                    modality=modality,
                    identifier=decoder_mm_hashes[modality][idx],
                    identifier=self._get_mm_identifier(
                        decoder_mm_hashes[modality][idx],
                        lora_request,
                    ),
                    mm_position=decoder_mm_positions[modality][idx],
                )
            )
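
Roughly, the keying behaves as below: the same content hash yields distinct cache identifiers under different adapters, so encoder outputs computed with one adapter are never served for another (the hash and adapter name are made up):

```python
def mm_identifier(mm_hash: str, lora_name: str | None, enabled: bool) -> str:
    # Mirrors the logic of InputProcessor._get_mm_identifier above.
    if lora_name is None or not enabled:
        return mm_hash
    return f"{lora_name}:{mm_hash}"


assert mm_identifier("d41d8cd9", None, enabled=True) == "d41d8cd9"
assert mm_identifier("d41d8cd9", "vision_tower", enabled=True) == "vision_tower:d41d8cd9"
```
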
@@ -54,6 +54,7 @@ from vllm.forward_context import (
    set_forward_context,
)
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping, LoRAMappingType
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.rotary_embedding import (
    MRotaryEmbedding,
@@ -79,7 +80,11 @@ from vllm.model_executor.models.interfaces_base import (
    is_text_generation_model,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import BatchedTensorInputs, MultiModalKwargsItem
from vllm.multimodal.inputs import (
    BatchedTensorInputs,
    MultiModalKwargsItem,
    PlaceholderRange,
)
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
@@ -2095,7 +2100,11 @@ class GPUModelRunner(
    def _batch_mm_inputs_from_scheduler(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[list[str], list[MultiModalKwargsItem]]:
    ) -> tuple[
        list[str],
        list[MultiModalKwargsItem],
        list[tuple[str, PlaceholderRange]],
    ]:
        """Batch multimodal inputs from scheduled encoder inputs.

        Args:
@@ -2103,16 +2112,20 @@ class GPUModelRunner(
                inputs.

        Returns:
            A tuple of (mm_hashes, mm_kwargs) where:
            A tuple of (mm_hashes, mm_kwargs, mm_lora_refs) where:
            - mm_hashes: List of multimodal hashes for each item
            - mm_kwargs: List of multimodal kwargs for each item
            - mm_lora_refs: List of (req_id, placeholder_range) for each item
        """
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return [], []
            return [], [], []

        mm_hashes = list[str]()
        mm_kwargs = list[MultiModalKwargsItem]()
        # Multimodal LoRA reference info to map each multimodal item
        # back to its request & position
        mm_lora_refs = list[tuple[str, PlaceholderRange]]()
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]

@@ -2123,13 +2136,16 @@ class GPUModelRunner(

                mm_hashes.append(mm_feature.identifier)
                mm_kwargs.append(mm_feature.data)
                mm_lora_refs.append((req_id, mm_feature.mm_position))

        return mm_hashes, mm_kwargs
        return mm_hashes, mm_kwargs, mm_lora_refs

    def _execute_mm_encoder(
        self, scheduler_output: "SchedulerOutput"
    ) -> list[torch.Tensor]:
        mm_hashes, mm_kwargs = self._batch_mm_inputs_from_scheduler(scheduler_output)
        mm_hashes, mm_kwargs, mm_lora_refs = self._batch_mm_inputs_from_scheduler(
            scheduler_output
        )

        if not mm_kwargs:
            return []
@@ -2142,6 +2158,63 @@ class GPUModelRunner(
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        model = cast(SupportsMultiModal, self.model)

        if self.lora_config and self.lora_manager.supports_tower_connector_lora():
            # Build LoRA mappings independently for encoder inputs
            # (encoder batch structure is different from main batch)
            prompt_lora_mapping = []
            token_lora_mapping = []
            lora_requests = set()
            encoder_token_counts = []

            for req_id, pos_info in mm_lora_refs:
                req_idx = self.input_batch.req_id_to_index[req_id]
                lora_id = int(self.input_batch.request_lora_mapping[req_idx])

                # Prefer pos_info.get_num_embeds to count precise MM embedding tokens.
                num_tokens = self.model.get_num_mm_encoder_tokens(  # type: ignore[attr-defined]
                    pos_info.get_num_embeds
                )
                prompt_lora_mapping.append(lora_id)
                token_lora_mapping.extend([lora_id] * num_tokens)
                encoder_token_counts.append(num_tokens)

                if lora_id > 0:
                    lora_request = self.input_batch.lora_id_to_lora_request.get(lora_id)
                    if lora_request is not None:
                        lora_requests.add(lora_request)

            # Set tower adapter mapping
            tower_mapping = LoRAMapping(
                tuple(token_lora_mapping),
                tuple(prompt_lora_mapping),
                is_prefill=True,
                type=LoRAMappingType.TOWER,
            )
            self.lora_manager.set_active_adapters(lora_requests, tower_mapping)

            if hasattr(self.model, "get_num_mm_connector_tokens"):
                post_op_counts = [
                    self.model.get_num_mm_connector_tokens(num_tokens)  # type: ignore[attr-defined]
                    for num_tokens in encoder_token_counts
                ]

                connector_token_mapping = np.repeat(
                    np.array(prompt_lora_mapping, dtype=np.int32),
                    np.array(post_op_counts, dtype=np.int32),
                )
                connector_mapping = LoRAMapping(
                    index_mapping=tuple(connector_token_mapping.tolist()),
                    prompt_mapping=tuple(prompt_lora_mapping),
                    is_prefill=True,
                    type=LoRAMappingType.CONNECTOR,
                )

                self.lora_manager.set_active_adapters(
                    lora_requests,
                    connector_mapping,
                )

        encoder_outputs: list[torch.Tensor] = []
        for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
            mm_kwargs,
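
The `np.repeat` construction above expands one LoRA ID per multimodal item into one ID per connector token; a standalone check:

```python
import numpy as np

prompt_lora_mapping = [1, 2]  # two scheduled items, using LoRA ids 1 and 2
post_op_counts = [3, 2]  # connector tokens per item after merging

connector_token_mapping = np.repeat(
    np.array(prompt_lora_mapping, dtype=np.int32),
    np.array(post_op_counts, dtype=np.int32),
)
assert connector_token_mapping.tolist() == [1, 1, 1, 2, 2]
```
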
@@ -5,6 +5,7 @@ Define LoRA functionality mixin for model runners.
"""

from contextlib import contextmanager
from typing import TypeAlias

import numpy as np
import torch
@@ -13,14 +14,14 @@ import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.layers import LoRAMapping, LoRAMappingType
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.model_executor.models import supports_lora
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch

InputBatch = TPUInputBatch | GPUInputBatch
InputBatch: TypeAlias = TPUInputBatch | GPUInputBatch

logger = init_logger(__name__)

@@ -28,29 +29,28 @@ logger = init_logger(__name__)
# Defined as a mixin for GPUModelRunner
class LoRAModelRunnerMixin:
    def load_lora_model(
        self, model: nn.Module, vllm_config: VllmConfig, device: torch.device
        self,
        model: nn.Module,
        vllm_config: VllmConfig,
        device: torch.device,
    ) -> nn.Module:
        if not supports_lora(model):
            raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.")

        if supports_multimodal(model):
            logger.warning(
                "Regarding multimodal models, vLLM currently "
                "only supports adding LoRA to language model."
            )
        # Add LoRA Manager to the Model Runner
        self.lora_manager = LRUCacheWorkerLoRAManager(
            vllm_config,
            device,
            model.embedding_modules,
        )
        return self.lora_manager.create_lora_manager(model)
        return self.lora_manager.create_lora_manager(model, vllm_config)

    def _set_active_loras(
        self,
        prompt_lora_mapping: tuple[int, ...],
        token_lora_mapping: tuple[int, ...],
        lora_requests: set[LoRARequest],
        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
    ) -> None:
        self._ensure_lora_enabled()

@@ -59,7 +59,10 @@ class LoRAModelRunnerMixin:
        # On cuda platforms we use the same kernels for prefill and
        # decode and this flag is generally ignored.
        lora_mapping = LoRAMapping(
            token_lora_mapping, prompt_lora_mapping, is_prefill=True
            token_lora_mapping,
            prompt_lora_mapping,
            is_prefill=True,
            type=mapping_type,
        )
        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)

@@ -72,6 +75,7 @@ class LoRAModelRunnerMixin:
        input_batch: InputBatch,
        num_scheduled_tokens: np.ndarray,
        num_sampled_tokens: np.ndarray | None = None,
        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
    ) -> None:
        if num_sampled_tokens is None:
            num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
@@ -83,7 +87,7 @@ class LoRAModelRunnerMixin:
            input_batch.make_lora_inputs(num_scheduled_tokens, num_sampled_tokens)
        )
        return self._set_active_loras(
            prompt_lora_mapping, token_lora_mapping, lora_requests
            prompt_lora_mapping, token_lora_mapping, lora_requests, mapping_type
        )

    @contextmanager
@@ -127,6 +131,7 @@ class LoRAModelRunnerMixin:
        self,
        lora_config: LoRAConfig | None,
        num_scheduled_tokens: np.ndarray,
        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
        num_sampled_tokens: np.ndarray | None = None,
        activate_lora: bool = True,
    ):
@@ -168,7 +173,10 @@ class LoRAModelRunnerMixin:
        }

        self._set_active_loras(
            tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests
            tuple(sample_lora_mapping),
            tuple(token_lora_mapping),
            lora_requests,
            mapping_type,
        )

        yield
@@ -181,11 +189,16 @@ class LoRAModelRunnerMixin:
        num_sampled_tokens: np.ndarray,
        activate_lora: bool = True,
        remove_lora: bool = True,
        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
    ):
        with (
            self.maybe_setup_dummy_loras(lora_config, remove_lora),
            self.maybe_select_dummy_loras(
                lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora
                lora_config,
                num_scheduled_tokens,
                mapping_type,
                num_sampled_tokens,
                activate_lora,
            ),
        ):
            yield
@@ -31,7 +31,8 @@ from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMappingType
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.tpu import TPUModelLoader
@@ -1468,11 +1469,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self._hidden_states_dtype = out.dtype

    def _set_active_loras(
        self, prompt_lora_mapping, token_lora_mapping, lora_requests
        self,
        prompt_lora_mapping: tuple[int, ...],
        token_lora_mapping: tuple[int, ...],
        lora_requests: set[LoRARequest],
        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
    ) -> None:
        torch_xla.sync(wait=False)  # Captures input updates
        super()._set_active_loras(
            prompt_lora_mapping, token_lora_mapping, lora_requests
            prompt_lora_mapping, token_lora_mapping, lora_requests, mapping_type
        )
        torch_xla.sync(wait=False)  # Captures metadata updates