Merge branch 'mlm-full-lora-support' of https://github.com/jeejeelee/vllm into mlm-full-lora-support

This commit is contained in:
Anexdeus 2025-12-20 20:40:31 +03:00
commit 2b03137fca
6 changed files with 61 additions and 19 deletions

View File

@ -246,6 +246,7 @@ class InputPreprocessor:
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
lora_kwargs: dict[str, Any] | None = None,
) -> MultiModalInputs:
"""
Apply the model's multi-modal processor to a multi-modal prompt,
@ -262,6 +263,7 @@ class InputPreprocessor:
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
lora_kwargs=lora_kwargs,
)
mm_hashes = mm_input["mm_hashes"]
@ -359,6 +361,7 @@ class InputPreprocessor:
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
lora_kwargs: dict[str, Any] | None = None,
) -> TokenInputs | MultiModalInputs:
prompt_text = parsed_content["prompt"]
@ -370,6 +373,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs") or {},
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
lora_kwargs=lora_kwargs,
)
else:
prompt_token_ids = self._tokenize_prompt(
@ -389,6 +393,7 @@ class InputPreprocessor:
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
lora_kwargs: dict[str, Any] | None = None,
) -> SingletonInputs:
"""
Extract the singleton inputs from a prompt.
@ -415,6 +420,7 @@ class InputPreprocessor:
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
lora_kwargs=lora_kwargs,
)
if parsed["type"] == "str":
return self._process_text(
@ -626,6 +632,7 @@ class InputPreprocessor:
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
lora_kwargs: dict[str, Any] | None = None,
) -> DecoderOnlyInputs:
"""
For decoder-only models:
@ -645,6 +652,7 @@ class InputPreprocessor:
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
lora_kwargs=lora_kwargs,
)
return self._build_decoder_only_llm_inputs(prompt_comps)
@ -655,6 +663,7 @@ class InputPreprocessor:
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
lora_kwargs: dict[str, Any] | None = None,
) -> ProcessorInputs:
if self.model_config.is_encoder_decoder:
# Encoder-decoder model requires special mapping of
@ -676,6 +685,7 @@ class InputPreprocessor:
cast(SingletonPrompt, prompt),
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
lora_kwargs=lora_kwargs,
)
def preprocess(
@ -684,12 +694,14 @@ class InputPreprocessor:
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
lora_kwargs: dict[str, Any] | None = None,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
res = self._preprocess(
prompt,
tokenization_kwargs,
mm_uuids=mm_uuids,
lora_kwargs=lora_kwargs,
)
if self.mm_processor_cache and self.mm_cache_stats is not None:

View File

@ -154,8 +154,11 @@ class LoRAModelManager:
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
if self.lora_config.enable_tower_connector_lora:
self.mm_processor_info = MULTIMODAL_REGISTRY.create_processor(
model_config
).info
self.supports_tower_connector_lora = self.supports_mm and hasattr(
self.model, "get_num_mm_encoder_tokens"
self.mm_processor_info, "get_num_mm_encoder_tokens"
)
if not self.supports_tower_connector_lora:
return
@ -171,8 +174,10 @@ class LoRAModelManager:
vllm_config.scheduler_config,
MULTIMODAL_REGISTRY,
)
limit_per_prompt: int = max(self.model.get_allowed_mm_limits().values())
num_encoder_tokens = self.model.get_num_mm_encoder_tokens(
limit_per_prompt: int = max(
self.mm_processor_info.get_allowed_mm_limits().values()
)
num_encoder_tokens = self.mm_processor_info.get_num_mm_encoder_tokens(
mm_budget.get_encoder_budget()
)
@ -188,8 +193,8 @@ class LoRAModelManager:
# Use wrapper for connector if present.
if self.mm_mapping.connector:
if hasattr(self.model, "get_num_mm_connector_tokens"):
connector_tokens = self.model.get_num_mm_connector_tokens(
if hasattr(self.mm_processor_info, "get_num_mm_connector_tokens"):
connector_tokens = self.mm_processor_info.get_num_mm_connector_tokens(
num_encoder_tokens
)
connector_punica_wrapper = get_punica_wrapper(

View File

@ -163,6 +163,12 @@ class WorkerLoRAManager:
if mapping is not None:
self._adapter_manager.set_adapter_mapping(mapping)
def supports_tower_connector_lora(self) -> bool:
return (
self._adapter_manager.supports_mm
and self._adapter_manager.supports_tower_connector_lora
)
def _apply_adapters(self, adapter_requests: set[Any]) -> None:
existing_adapters = self.list_adapters()
models_map = {

View File

@ -1672,6 +1672,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs: Mapping[str, object],
*,
mm_uuids: MultiModalUUIDDict | None = None,
lora_kwargs: dict[str, Any] | None = None,
) -> MultiModalHashes:
"""Create MM hashes to be returned.
@ -1683,6 +1684,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hashes: MultiModalHashes = {}
mm_uuids = mm_uuids or {}
lora_kwargs = lora_kwargs or {}
for modality, items in mm_items.items():
if modality in mm_uuids:
@ -1703,6 +1705,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
item_uuid is None
or hf_processor_mm_kwargs
or tokenization_kwargs
or lora_kwargs
):
# NOTE: use provided hash string to hash with kwargs
# if available for better performance.
@ -1713,6 +1716,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
**{modality: item},
**hf_processor_mm_kwargs,
**tokenization_kwargs,
**lora_kwargs,
)
)
else:
@ -1725,6 +1729,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
**{modality: item},
**hf_processor_mm_kwargs,
**tokenization_kwargs,
**lora_kwargs,
)
for item in items
]
@ -1883,6 +1888,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs: Mapping[str, object],
*,
mm_uuids: MultiModalUUIDDict | None = None,
lora_kwargs: dict[str, Any] | None = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
"""
Apply the HF processor on the full prompt text,
@ -1905,6 +1911,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_uuids=mm_uuids,
lora_kwargs=lora_kwargs,
)
mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
@ -2115,6 +2122,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs: Mapping[str, object] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
lora_kwargs: dict[str, Any] | None = None,
) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.
@ -2144,6 +2152,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
lora_kwargs=lora_kwargs,
)
# NOTE: tokenization_kwargs are not required to init processor
@ -2224,6 +2233,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
tokenization_kwargs: Mapping[str, object] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
lora_kwargs: dict[str, Any] | None = None,
) -> MultiModalEncDecInputs:
"""
Process multi-modal inputs to be used in vLLM.
@ -2239,6 +2249,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_uuids=mm_uuids,
lora_kwargs=lora_kwargs,
)
return self._get_enc_dec_inputs(

View File

@ -5,6 +5,8 @@ import time
from collections.abc import Mapping
from typing import Any, Literal, cast
import msgspec
from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
@ -458,6 +460,17 @@ class InputProcessor:
else:
mm_uuids = None
# When enable_tower_connector_lora is True, multi-modal embeddings
# vary depending on the LoRA request. Therefore, the mm_hash must be
# generated based on the LoRA request to prevent incorrect cache hits.
lora_config = self.lora_config
lora_kwargs = (
msgspec.structs.asdict(lora_request)
if lora_request and lora_config and lora_config.enable_tower_connector_lora
else {}
)
lora_kwargs = {k: v for k, v in lora_kwargs.items() if v is not None}
# Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess
@ -466,6 +479,7 @@ class InputProcessor:
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
lora_kwargs=lora_kwargs,
)
from vllm.platforms import current_platform

View File

@ -590,15 +590,6 @@ class GPUModelRunner(
pin_memory=self.pin_memory,
)
# Multimodal LoRA support
self.enable_tower_connector_lora = False
if self.supports_mm_inputs and self.lora_config:
self.mm_model_cls = self.mm_registry._get_model_cls(model_config)
self.enable_tower_connector_lora = (
hasattr(self.mm_model_cls, "get_num_mm_encoder_tokens")
and self.lora_config.enable_tower_connector_lora
)
# Pre-allocated tensor for copying valid sampled token counts to CPU,
# with dedicated stream for overlapping and event for coordination.
self.valid_sampled_token_count_event: torch.Event | None = None
@ -2169,12 +2160,15 @@ class GPUModelRunner(
# encoder outputs.
model = cast(SupportsMultiModal, self.model)
if self.enable_tower_connector_lora:
if self.lora_manager.supports_tower_connector_lora():
# Build LoRA mappings independently for encoder inputs
# (encoder batch structure is different from main batch)
prompt_lora_mapping = []
token_lora_mapping = []
lora_requests = set()
# This implementation is a bit hacky, but it's mainly to retrieve
# the get_num_mm_*_tokens helper functions from ProcessingInfo.
mm_processor_info = self.lora_manager._adapter_manager.mm_processor_info
for req_id, (_, pos_info) in zip(encoder_req_ids, mm_hashes_pos):
req_idx = self.input_batch.req_id_to_index[req_id]
@ -2183,7 +2177,7 @@ class GPUModelRunner(
# Prefer pos_info.is_embed to count actual MM embedding tokens.
# pos_info.length may overcount (e.g., special tokens in Qwen-VL).
# Fall back to length if is_embed is None.
num_tokens = model.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
num_tokens = mm_processor_info.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
pos_info.get_num_embeds
)
prompt_lora_mapping.append(lora_id)
@ -2202,13 +2196,13 @@ class GPUModelRunner(
)
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
if hasattr(model, "get_num_mm_connector_tokens"):
if hasattr(mm_processor_info, "get_num_mm_connector_tokens"):
num_post_op_tokens = []
for _, pos_info in mm_hashes_pos:
mm_token_count = model.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
mm_token_count = mm_processor_info.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
pos_info.length
)
post_op_count = model.get_num_mm_connector_tokens( # type: ignore[attr-defined]
post_op_count = mm_processor_info.get_num_mm_connector_tokens( # type: ignore[attr-defined]
mm_token_count
)
num_post_op_tokens.append(post_op_count)