Merge branch 'mlm-full-lora-support' of https://github.com/jeejeelee/vllm into mlm-full-lora-support

Anexdeus 2025-12-20 20:40:31 +03:00
commit 2b03137fca
6 changed files with 61 additions and 19 deletions

View File

@@ -246,6 +246,7 @@ class InputPreprocessor:
         tokenization_kwargs: dict[str, Any] | None = None,
         *,
         mm_uuids: MultiModalUUIDDict | None = None,
+        lora_kwargs: dict[str, Any] | None = None,
     ) -> MultiModalInputs:
         """
         Apply the model's multi-modal processor to a multi-modal prompt,
@@ -262,6 +263,7 @@ class InputPreprocessor:
             hf_processor_mm_kwargs=mm_processor_kwargs,
             tokenization_kwargs=tokenization_kwargs,
             mm_uuids=mm_uuids,
+            lora_kwargs=lora_kwargs,
         )

         mm_hashes = mm_input["mm_hashes"]
@@ -359,6 +361,7 @@ class InputPreprocessor:
         tokenization_kwargs: dict[str, Any] | None = None,
         *,
         mm_uuids: MultiModalUUIDDict | None = None,
+        lora_kwargs: dict[str, Any] | None = None,
     ) -> TokenInputs | MultiModalInputs:
         prompt_text = parsed_content["prompt"]
@@ -370,6 +373,7 @@ class InputPreprocessor:
                 parsed_content.get("mm_processor_kwargs") or {},
                 tokenization_kwargs=tokenization_kwargs,
                 mm_uuids=mm_uuids,
+                lora_kwargs=lora_kwargs,
             )
         else:
             prompt_token_ids = self._tokenize_prompt(
@@ -389,6 +393,7 @@ class InputPreprocessor:
         tokenization_kwargs: dict[str, Any] | None = None,
         *,
         mm_uuids: MultiModalUUIDDict | None = None,
+        lora_kwargs: dict[str, Any] | None = None,
     ) -> SingletonInputs:
         """
         Extract the singleton inputs from a prompt.
@@ -415,6 +420,7 @@ class InputPreprocessor:
                 parsed["content"],
                 tokenization_kwargs=tokenization_kwargs,
                 mm_uuids=mm_uuids,
+                lora_kwargs=lora_kwargs,
             )
         if parsed["type"] == "str":
             return self._process_text(
@@ -626,6 +632,7 @@ class InputPreprocessor:
         tokenization_kwargs: dict[str, Any] | None = None,
         *,
         mm_uuids: MultiModalUUIDDict | None = None,
+        lora_kwargs: dict[str, Any] | None = None,
     ) -> DecoderOnlyInputs:
         """
         For decoder-only models:
@@ -645,6 +652,7 @@ class InputPreprocessor:
             prompt,
             tokenization_kwargs=tokenization_kwargs,
             mm_uuids=mm_uuids,
+            lora_kwargs=lora_kwargs,
         )

         return self._build_decoder_only_llm_inputs(prompt_comps)
@@ -655,6 +663,7 @@ class InputPreprocessor:
         tokenization_kwargs: dict[str, Any] | None = None,
         *,
         mm_uuids: MultiModalUUIDDict | None = None,
+        lora_kwargs: dict[str, Any] | None = None,
     ) -> ProcessorInputs:
         if self.model_config.is_encoder_decoder:
             # Encoder-decoder model requires special mapping of
@@ -676,6 +685,7 @@ class InputPreprocessor:
             cast(SingletonPrompt, prompt),
             tokenization_kwargs=tokenization_kwargs,
             mm_uuids=mm_uuids,
+            lora_kwargs=lora_kwargs,
         )

     def preprocess(
@@ -684,12 +694,14 @@ class InputPreprocessor:
         tokenization_kwargs: dict[str, Any] | None = None,
         *,
         mm_uuids: MultiModalUUIDDict | None = None,
+        lora_kwargs: dict[str, Any] | None = None,
     ) -> ProcessorInputs:
         """Preprocess the input prompt."""
         res = self._preprocess(
             prompt,
             tokenization_kwargs,
             mm_uuids=mm_uuids,
+            lora_kwargs=lora_kwargs,
         )

         if self.mm_processor_cache and self.mm_cache_stats is not None:

View File

@@ -154,8 +154,11 @@ class LoRAModelManager:
         self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper

         if self.lora_config.enable_tower_connector_lora:
+            self.mm_processor_info = MULTIMODAL_REGISTRY.create_processor(
+                model_config
+            ).info
             self.supports_tower_connector_lora = self.supports_mm and hasattr(
-                self.model, "get_num_mm_encoder_tokens"
+                self.mm_processor_info, "get_num_mm_encoder_tokens"
             )
             if not self.supports_tower_connector_lora:
                 return
@@ -171,8 +174,10 @@
                 vllm_config.scheduler_config,
                 MULTIMODAL_REGISTRY,
             )

-            limit_per_prompt: int = max(self.model.get_allowed_mm_limits().values())
-            num_encoder_tokens = self.model.get_num_mm_encoder_tokens(
+            limit_per_prompt: int = max(
+                self.mm_processor_info.get_allowed_mm_limits().values()
+            )
+            num_encoder_tokens = self.mm_processor_info.get_num_mm_encoder_tokens(
                 mm_budget.get_encoder_budget()
             )
@@ -188,8 +193,8 @@

             # Use wrapper for connector if present.
             if self.mm_mapping.connector:
-                if hasattr(self.model, "get_num_mm_connector_tokens"):
-                    connector_tokens = self.model.get_num_mm_connector_tokens(
+                if hasattr(self.mm_processor_info, "get_num_mm_connector_tokens"):
+                    connector_tokens = self.mm_processor_info.get_num_mm_connector_tokens(
                         num_encoder_tokens
                     )
                     connector_punica_wrapper = get_punica_wrapper(

View File

@@ -163,6 +163,12 @@ class WorkerLoRAManager:
         if mapping is not None:
             self._adapter_manager.set_adapter_mapping(mapping)

+    def supports_tower_connector_lora(self) -> bool:
+        return (
+            self._adapter_manager.supports_mm
+            and self._adapter_manager.supports_tower_connector_lora
+        )
+
     def _apply_adapters(self, adapter_requests: set[Any]) -> None:
         existing_adapters = self.list_adapters()
         models_map = {

View File

@@ -1672,6 +1672,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         tokenization_kwargs: Mapping[str, object],
         *,
         mm_uuids: MultiModalUUIDDict | None = None,
+        lora_kwargs: dict[str, Any] | None = None,
     ) -> MultiModalHashes:
         """Create MM hashes to be returned.
@@ -1683,6 +1684,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         hashes: MultiModalHashes = {}
         mm_uuids = mm_uuids or {}
+        lora_kwargs = lora_kwargs or {}

         for modality, items in mm_items.items():
             if modality in mm_uuids:
@@ -1703,6 +1705,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
                     item_uuid is None
                     or hf_processor_mm_kwargs
                     or tokenization_kwargs
+                    or lora_kwargs
                 ):
                     # NOTE: use provided hash string to hash with kwargs
                     # if available for better performance.
@@ -1713,6 +1716,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
                             **{modality: item},
                             **hf_processor_mm_kwargs,
                             **tokenization_kwargs,
+                            **lora_kwargs,
                         )
                     )
                 else:
@@ -1725,6 +1729,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
                         **{modality: item},
                         **hf_processor_mm_kwargs,
                         **tokenization_kwargs,
+                        **lora_kwargs,
                     )
                     for item in items
                 ]
@@ -1883,6 +1888,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         tokenization_kwargs: Mapping[str, object],
         *,
         mm_uuids: MultiModalUUIDDict | None = None,
+        lora_kwargs: dict[str, Any] | None = None,
     ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         """
         Apply the HF processor on the full prompt text,
@@ -1905,6 +1911,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             hf_processor_mm_kwargs,
             tokenization_kwargs,
             mm_uuids=mm_uuids,
+            lora_kwargs=lora_kwargs,
         )

         mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
@@ -2115,6 +2122,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         tokenization_kwargs: Mapping[str, object] | None = None,
         *,
         mm_uuids: MultiModalUUIDDict | None = None,
+        lora_kwargs: dict[str, Any] | None = None,
     ) -> MultiModalInputs:
         """
         Process multi-modal inputs to be used in vLLM.
@@ -2144,6 +2152,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
             mm_uuids=mm_uuids,
+            lora_kwargs=lora_kwargs,
         )

         # NOTE: tokenization_kwargs are not required to init processor
@@ -2224,6 +2233,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
         tokenization_kwargs: Mapping[str, object] | None = None,
         *,
         mm_uuids: MultiModalUUIDDict | None = None,
+        lora_kwargs: dict[str, Any] | None = None,
     ) -> MultiModalEncDecInputs:
         """
         Process multi-modal inputs to be used in vLLM.
@@ -2239,6 +2249,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
             hf_processor_mm_kwargs,
             tokenization_kwargs,
             mm_uuids=mm_uuids,
+            lora_kwargs=lora_kwargs,
         )

         return self._get_enc_dec_inputs(

View File

@@ -5,6 +5,8 @@ import time
 from collections.abc import Mapping
 from typing import Any, Literal, cast

+import msgspec
+
 from vllm.config import VllmConfig
 from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
 from vllm.inputs.parse import split_enc_dec_inputs
@@ -458,6 +460,17 @@ class InputProcessor:
         else:
             mm_uuids = None

+        # When enable_tower_connector_lora is True, multi-modal embeddings
+        # vary depending on the LoRA request. Therefore, the mm_hash must be
+        # generated based on the LoRA request to prevent incorrect cache hits.
+        lora_config = self.lora_config
+        lora_kwargs = (
+            msgspec.structs.asdict(lora_request)
+            if lora_request and lora_config and lora_config.enable_tower_connector_lora
+            else {}
+        )
+        lora_kwargs = {k: v for k, v in lora_kwargs.items() if v is not None}
+
         # Process inputs, which includes:
         # 1. Tokenize text prompt, with LoRA request if one exists.
         # 2. For multimodal models with a merged preprocessor, preprocess
@@ -466,6 +479,7 @@
             prompt,
             tokenization_kwargs=tokenization_kwargs,
             mm_uuids=mm_uuids,
+            lora_kwargs=lora_kwargs,
         )

         from vllm.platforms import current_platform

View File

@@ -590,15 +590,6 @@ class GPUModelRunner(
             pin_memory=self.pin_memory,
         )

-        # Multimodal LoRA support
-        self.enable_tower_connector_lora = False
-        if self.supports_mm_inputs and self.lora_config:
-            self.mm_model_cls = self.mm_registry._get_model_cls(model_config)
-            self.enable_tower_connector_lora = (
-                hasattr(self.mm_model_cls, "get_num_mm_encoder_tokens")
-                and self.lora_config.enable_tower_connector_lora
-            )
-
         # Pre-allocated tensor for copying valid sampled token counts to CPU,
         # with dedicated stream for overlapping and event for coordination.
         self.valid_sampled_token_count_event: torch.Event | None = None
@@ -2169,12 +2160,15 @@
         # encoder outputs.
         model = cast(SupportsMultiModal, self.model)

-        if self.enable_tower_connector_lora:
+        if self.lora_manager.supports_tower_connector_lora():
             # Build LoRA mappings independently for encoder inputs
             # (encoder batch structure is different from main batch)
             prompt_lora_mapping = []
             token_lora_mapping = []
             lora_requests = set()
+            # This implementation is a bit hacky, but it's mainly to retrieve
+            # the get_num_mm_*_tokens helper functions from ProcessingInfo.
+            mm_processor_info = self.lora_manager._adapter_manager.mm_processor_info

             for req_id, (_, pos_info) in zip(encoder_req_ids, mm_hashes_pos):
                 req_idx = self.input_batch.req_id_to_index[req_id]
@@ -2183,7 +2177,7 @@
                 # Prefer pos_info.is_embed to count actual MM embedding tokens.
                 # pos_info.length may overcount (e.g., special tokens in Qwen-VL).
                 # Fall back to length if is_embed is None.
-                num_tokens = model.get_num_mm_encoder_tokens(  # type: ignore[attr-defined]
+                num_tokens = mm_processor_info.get_num_mm_encoder_tokens(  # type: ignore[attr-defined]
                     pos_info.get_num_embeds
                 )
                 prompt_lora_mapping.append(lora_id)
@@ -2202,13 +2196,13 @@
             )
             self.lora_manager.set_active_adapters(lora_requests, lora_mapping)

-            if hasattr(model, "get_num_mm_connector_tokens"):
+            if hasattr(mm_processor_info, "get_num_mm_connector_tokens"):
                 num_post_op_tokens = []
                 for _, pos_info in mm_hashes_pos:
-                    mm_token_count = model.get_num_mm_encoder_tokens(  # type: ignore[attr-defined]
+                    mm_token_count = mm_processor_info.get_num_mm_encoder_tokens(  # type: ignore[attr-defined]
                         pos_info.length
                     )
-                    post_op_count = model.get_num_mm_connector_tokens(  # type: ignore[attr-defined]
+                    post_op_count = mm_processor_info.get_num_mm_connector_tokens(  # type: ignore[attr-defined]
                         mm_token_count
                     )
                     num_post_op_tokens.append(post_op_count)
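
To make the caching rationale in the InputProcessor change concrete, here is a small, self-contained sketch (invented names such as FakeLoRARequest and mm_cache_key; this is not vLLM code). It shows why, once tower/connector LoRA is enabled, the serialized LoRA request has to become part of the multi-modal hash; otherwise two requests that use different adapters on the same image would collide in the processor cache and reuse embeddings computed under the wrong adapter.

# Illustrative sketch only, not vLLM's implementation. The real change above
# serializes the LoRA request with msgspec, drops None-valued fields, and
# threads the result into MultiModalHasher.hash_kwargs as extra kwargs. The
# point being illustrated: with tower/connector LoRA enabled, the same image
# must hash differently under different adapters.
import hashlib
import pickle
from dataclasses import asdict, dataclass


@dataclass
class FakeLoRARequest:  # stand-in for the real LoRARequest; fields assumed
    lora_name: str
    lora_int_id: int
    lora_path: str


def mm_cache_key(
    image_bytes: bytes,
    lora_request: FakeLoRARequest | None,
    enable_tower_connector_lora: bool,
) -> str:
    hasher = hashlib.sha256(image_bytes)
    if enable_tower_connector_lora and lora_request is not None:
        # Mirror the diff: serialize the request, drop None-valued fields,
        # and mix the result into the hash alongside the media item.
        lora_kwargs = {k: v for k, v in asdict(lora_request).items() if v is not None}
        hasher.update(pickle.dumps(sorted(lora_kwargs.items())))
    return hasher.hexdigest()


img = b"<image bytes>"
plain = mm_cache_key(img, None, enable_tower_connector_lora=True)
tuned = mm_cache_key(img, FakeLoRARequest("vision-adapter", 1, "/path/to/adapter"), True)
assert plain != tuned  # different adapters produce different cache entries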