mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-24 22:39:08 +08:00
Merge branch 'mlm-full-lora-support' of https://github.com/jeejeelee/vllm into mlm-full-lora-support
This commit is contained in:
commit
2b03137fca
@ -246,6 +246,7 @@ class InputPreprocessor:
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
lora_kwargs: dict[str, Any] | None = None,
|
||||
) -> MultiModalInputs:
|
||||
"""
|
||||
Apply the model's multi-modal processor to a multi-modal prompt,
|
||||
@ -262,6 +263,7 @@ class InputPreprocessor:
|
||||
hf_processor_mm_kwargs=mm_processor_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
lora_kwargs=lora_kwargs,
|
||||
)
|
||||
mm_hashes = mm_input["mm_hashes"]
|
||||
|
||||
@ -359,6 +361,7 @@ class InputPreprocessor:
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
lora_kwargs: dict[str, Any] | None = None,
|
||||
) -> TokenInputs | MultiModalInputs:
|
||||
prompt_text = parsed_content["prompt"]
|
||||
|
||||
@ -370,6 +373,7 @@ class InputPreprocessor:
|
||||
parsed_content.get("mm_processor_kwargs") or {},
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
lora_kwargs=lora_kwargs,
|
||||
)
|
||||
else:
|
||||
prompt_token_ids = self._tokenize_prompt(
|
||||
@ -389,6 +393,7 @@ class InputPreprocessor:
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
lora_kwargs: dict[str, Any] | None = None,
|
||||
) -> SingletonInputs:
|
||||
"""
|
||||
Extract the singleton inputs from a prompt.
|
||||
@ -415,6 +420,7 @@ class InputPreprocessor:
|
||||
parsed["content"],
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
lora_kwargs=lora_kwargs,
|
||||
)
|
||||
if parsed["type"] == "str":
|
||||
return self._process_text(
|
||||
@ -626,6 +632,7 @@ class InputPreprocessor:
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
lora_kwargs: dict[str, Any] | None = None,
|
||||
) -> DecoderOnlyInputs:
|
||||
"""
|
||||
For decoder-only models:
|
||||
@ -645,6 +652,7 @@ class InputPreprocessor:
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
lora_kwargs=lora_kwargs,
|
||||
)
|
||||
|
||||
return self._build_decoder_only_llm_inputs(prompt_comps)
|
||||
@ -655,6 +663,7 @@ class InputPreprocessor:
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
lora_kwargs: dict[str, Any] | None = None,
|
||||
) -> ProcessorInputs:
|
||||
if self.model_config.is_encoder_decoder:
|
||||
# Encoder-decoder model requires special mapping of
|
||||
@ -676,6 +685,7 @@ class InputPreprocessor:
|
||||
cast(SingletonPrompt, prompt),
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
lora_kwargs=lora_kwargs,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
@ -684,12 +694,14 @@ class InputPreprocessor:
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
lora_kwargs: dict[str, Any] | None = None,
|
||||
) -> ProcessorInputs:
|
||||
"""Preprocess the input prompt."""
|
||||
res = self._preprocess(
|
||||
prompt,
|
||||
tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
lora_kwargs=lora_kwargs,
|
||||
)
|
||||
|
||||
if self.mm_processor_cache and self.mm_cache_stats is not None:
|
||||
|
||||
@ -154,8 +154,11 @@ class LoRAModelManager:
|
||||
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
|
||||
|
||||
if self.lora_config.enable_tower_connector_lora:
|
||||
self.mm_processor_info = MULTIMODAL_REGISTRY.create_processor(
|
||||
model_config
|
||||
).info
|
||||
self.supports_tower_connector_lora = self.supports_mm and hasattr(
|
||||
self.model, "get_num_mm_encoder_tokens"
|
||||
self.mm_processor_info, "get_num_mm_encoder_tokens"
|
||||
)
|
||||
if not self.supports_tower_connector_lora:
|
||||
return
|
||||
@ -171,8 +174,10 @@ class LoRAModelManager:
|
||||
vllm_config.scheduler_config,
|
||||
MULTIMODAL_REGISTRY,
|
||||
)
|
||||
limit_per_prompt: int = max(self.model.get_allowed_mm_limits().values())
|
||||
num_encoder_tokens = self.model.get_num_mm_encoder_tokens(
|
||||
limit_per_prompt: int = max(
|
||||
self.mm_processor_info.get_allowed_mm_limits().values()
|
||||
)
|
||||
num_encoder_tokens = self.mm_processor_info.get_num_mm_encoder_tokens(
|
||||
mm_budget.get_encoder_budget()
|
||||
)
|
||||
|
||||
@ -188,8 +193,8 @@ class LoRAModelManager:
|
||||
|
||||
# Use wrapper for connector if present.
|
||||
if self.mm_mapping.connector:
|
||||
if hasattr(self.model, "get_num_mm_connector_tokens"):
|
||||
connector_tokens = self.model.get_num_mm_connector_tokens(
|
||||
if hasattr(self.mm_processor_info, "get_num_mm_connector_tokens"):
|
||||
connector_tokens = self.mm_processor_info.get_num_mm_connector_tokens(
|
||||
num_encoder_tokens
|
||||
)
|
||||
connector_punica_wrapper = get_punica_wrapper(
|
||||
|
||||
@ -163,6 +163,12 @@ class WorkerLoRAManager:
|
||||
if mapping is not None:
|
||||
self._adapter_manager.set_adapter_mapping(mapping)
|
||||
|
||||
def supports_tower_connector_lora(self) -> bool:
|
||||
return (
|
||||
self._adapter_manager.supports_mm
|
||||
and self._adapter_manager.supports_tower_connector_lora
|
||||
)
|
||||
|
||||
def _apply_adapters(self, adapter_requests: set[Any]) -> None:
|
||||
existing_adapters = self.list_adapters()
|
||||
models_map = {
|
||||
|
||||
@ -1672,6 +1672,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
lora_kwargs: dict[str, Any] | None = None,
|
||||
) -> MultiModalHashes:
|
||||
"""Create MM hashes to be returned.
|
||||
|
||||
@ -1683,6 +1684,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
|
||||
hashes: MultiModalHashes = {}
|
||||
mm_uuids = mm_uuids or {}
|
||||
lora_kwargs = lora_kwargs or {}
|
||||
|
||||
for modality, items in mm_items.items():
|
||||
if modality in mm_uuids:
|
||||
@ -1703,6 +1705,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
item_uuid is None
|
||||
or hf_processor_mm_kwargs
|
||||
or tokenization_kwargs
|
||||
or lora_kwargs
|
||||
):
|
||||
# NOTE: use provided hash string to hash with kwargs
|
||||
# if available for better performance.
|
||||
@ -1713,6 +1716,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
**{modality: item},
|
||||
**hf_processor_mm_kwargs,
|
||||
**tokenization_kwargs,
|
||||
**lora_kwargs,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -1725,6 +1729,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
**{modality: item},
|
||||
**hf_processor_mm_kwargs,
|
||||
**tokenization_kwargs,
|
||||
**lora_kwargs,
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
@ -1883,6 +1888,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
lora_kwargs: dict[str, Any] | None = None,
|
||||
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
|
||||
"""
|
||||
Apply the HF processor on the full prompt text,
|
||||
@ -1905,6 +1911,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
hf_processor_mm_kwargs,
|
||||
tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
lora_kwargs=lora_kwargs,
|
||||
)
|
||||
|
||||
mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
|
||||
@ -2115,6 +2122,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
tokenization_kwargs: Mapping[str, object] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
lora_kwargs: dict[str, Any] | None = None,
|
||||
) -> MultiModalInputs:
|
||||
"""
|
||||
Process multi-modal inputs to be used in vLLM.
|
||||
@ -2144,6 +2152,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
lora_kwargs=lora_kwargs,
|
||||
)
|
||||
|
||||
# NOTE: tokenization_kwargs are not required to init processor
|
||||
@ -2224,6 +2233,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
tokenization_kwargs: Mapping[str, object] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
lora_kwargs: dict[str, Any] | None = None,
|
||||
) -> MultiModalEncDecInputs:
|
||||
"""
|
||||
Process multi-modal inputs to be used in vLLM.
|
||||
@ -2239,6 +2249,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
hf_processor_mm_kwargs,
|
||||
tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
lora_kwargs=lora_kwargs,
|
||||
)
|
||||
|
||||
return self._get_enc_dec_inputs(
|
||||
|
||||
@ -5,6 +5,8 @@ import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import msgspec
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
|
||||
from vllm.inputs.parse import split_enc_dec_inputs
|
||||
@ -458,6 +460,17 @@ class InputProcessor:
|
||||
else:
|
||||
mm_uuids = None
|
||||
|
||||
# When enable_tower_connector_lora is True, multi-modal embeddings
|
||||
# vary depending on the LoRA request. Therefore, the mm_hash must be
|
||||
# generated based on the LoRA request to prevent incorrect cache hits.
|
||||
lora_config = self.lora_config
|
||||
lora_kwargs = (
|
||||
msgspec.structs.asdict(lora_request)
|
||||
if lora_request and lora_config and lora_config.enable_tower_connector_lora
|
||||
else {}
|
||||
)
|
||||
lora_kwargs = {k: v for k, v in lora_kwargs.items() if v is not None}
|
||||
|
||||
# Process inputs, which includes:
|
||||
# 1. Tokenize text prompt, with LoRA request if one exists.
|
||||
# 2. For multimodal models with a merged preprocessor, preprocess
|
||||
@ -466,6 +479,7 @@ class InputProcessor:
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
lora_kwargs=lora_kwargs,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@ -590,15 +590,6 @@ class GPUModelRunner(
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
|
||||
# Multimodal LoRA support
|
||||
self.enable_tower_connector_lora = False
|
||||
if self.supports_mm_inputs and self.lora_config:
|
||||
self.mm_model_cls = self.mm_registry._get_model_cls(model_config)
|
||||
self.enable_tower_connector_lora = (
|
||||
hasattr(self.mm_model_cls, "get_num_mm_encoder_tokens")
|
||||
and self.lora_config.enable_tower_connector_lora
|
||||
)
|
||||
|
||||
# Pre-allocated tensor for copying valid sampled token counts to CPU,
|
||||
# with dedicated stream for overlapping and event for coordination.
|
||||
self.valid_sampled_token_count_event: torch.Event | None = None
|
||||
@ -2169,12 +2160,15 @@ class GPUModelRunner(
|
||||
# encoder outputs.
|
||||
model = cast(SupportsMultiModal, self.model)
|
||||
|
||||
if self.enable_tower_connector_lora:
|
||||
if self.lora_manager.supports_tower_connector_lora():
|
||||
# Build LoRA mappings independently for encoder inputs
|
||||
# (encoder batch structure is different from main batch)
|
||||
prompt_lora_mapping = []
|
||||
token_lora_mapping = []
|
||||
lora_requests = set()
|
||||
# This implementation is a bit hacky, but it's mainly to retrieve
|
||||
# the get_num_mm_*_tokens helper functions from ProcessingInfo.
|
||||
mm_processor_info = self.lora_manager._adapter_manager.mm_processor_info
|
||||
|
||||
for req_id, (_, pos_info) in zip(encoder_req_ids, mm_hashes_pos):
|
||||
req_idx = self.input_batch.req_id_to_index[req_id]
|
||||
@ -2183,7 +2177,7 @@ class GPUModelRunner(
|
||||
# Prefer pos_info.is_embed to count actual MM embedding tokens.
|
||||
# pos_info.length may overcount (e.g., special tokens in Qwen-VL).
|
||||
# Fall back to length if is_embed is None.
|
||||
num_tokens = model.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
|
||||
num_tokens = mm_processor_info.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
|
||||
pos_info.get_num_embeds
|
||||
)
|
||||
prompt_lora_mapping.append(lora_id)
|
||||
@ -2202,13 +2196,13 @@ class GPUModelRunner(
|
||||
)
|
||||
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
|
||||
|
||||
if hasattr(model, "get_num_mm_connector_tokens"):
|
||||
if hasattr(mm_processor_info, "get_num_mm_connector_tokens"):
|
||||
num_post_op_tokens = []
|
||||
for _, pos_info in mm_hashes_pos:
|
||||
mm_token_count = model.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
|
||||
mm_token_count = mm_processor_info.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
|
||||
pos_info.length
|
||||
)
|
||||
post_op_count = model.get_num_mm_connector_tokens( # type: ignore[attr-defined]
|
||||
post_op_count = mm_processor_info.get_num_mm_connector_tokens( # type: ignore[attr-defined]
|
||||
mm_token_count
|
||||
)
|
||||
num_post_op_tokens.append(post_op_count)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user