diff --git a/examples/online_serving/disaggregated_encoder/README.md b/examples/online_serving/disaggregated_encoder/README.md index b2c3bb974dfab..2a59f86d15fb7 100644 --- a/examples/online_serving/disaggregated_encoder/README.md +++ b/examples/online_serving/disaggregated_encoder/README.md @@ -38,6 +38,8 @@ Encoder engines should be launched with the following flags: - `--max-num-batched-tokens=` **(default: 2048)** – This flag controls the token scheduling budget per decoding step and is irrelevant to encoder-only instances. **Set it to a very high value (effectively unlimited) to bypass scheduler limitations.** The actual token budget is managed by the encoder cache manager. +- `--convert "mm_encoder_only"` **(Optional)** - The language model is skipped during initialization to reduce device memory usage. **Models using this option must implement the `get_language_model_spec` interface.** + ## Local media inputs To support local image inputs (from your ```MEDIA_PATH``` directory), add the following flag to the encoder instance: diff --git a/vllm/config/model.py b/vllm/config/model.py index c796e300ab155..dd2b7b9d7a786 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -71,7 +71,7 @@ else: logger = init_logger(__name__) RunnerOption = Literal["auto", RunnerType] -ConvertType = Literal["none", "embed", "classify", "reward"] +ConvertType = Literal["none", "embed", "classify", "reward", "mm_encoder_only"] ConvertOption = Literal["auto", ConvertType] TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 74b02e4c62583..08d7a851ac9ab 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -189,7 +189,9 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], ) convert_type = model_config.convert_type - if convert_type != "none" and supports_multimodal(model_cls): + if convert_type not in ["none", "mm_encoder_only"] and supports_multimodal( + model_cls + ): logger.debug_once("Detected conversion of Multi Modal model.") converted = try_create_mm_pooling_model_cls(model_cls) if converted is not None: @@ -200,6 +202,11 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], if convert_type == "none": pass + elif convert_type == "mm_encoder_only": + logger.debug_once("Converting to mm encoder only model.") + from vllm.model_executor.models.adapters import as_mm_encoder_only_model + + model_cls = as_mm_encoder_only_model(model_cls) elif convert_type == "embed": logger.debug_once("Converting to embedding model.") model_cls = as_embedding_model(model_cls) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 504de9fe10871..acf1e57a59a97 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -520,3 +520,64 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]): method = getattr(text_config, "method", None) assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported" return SEQ_CLS_LOAD_METHODS[method](model, weights) + + +def as_mm_encoder_only_model(cls: _T) -> _T: + """ + Subclass an existing vLLM vl model to support mm encoder only for + EPD encoder instances. + """ + if not hasattr(cls, "embed_multimodal"): + # Submodel case: return the original class. + return cls + + if not hasattr(cls, "get_language_model_spec"): + raise TypeError(f"{cls} need to implement `get_language_model_spec` method.") + + lm_model_cls, lm_attr = cls.get_language_model_spec() + + if lm_model_cls is None or lm_attr is None: + raise TypeError( + f"{cls}.get_language_model_spec() must return (lm_model_cls, lm_attr)" + ) + + class DummyLM(nn.Module): + def __init__(self, *args, **kwargs): + self.make_empty_intermediate_tensors = None + + class ModelForMMEncoderOnly(cls): + def __init__( + self, + *, + vllm_config: "VllmConfig", + prefix: str = "", + **kwargs: Any, + ) -> None: + self.is_mm_encoder_only_model = True + origin_init = lm_model_cls.__init__ + try: + lm_model_cls.__init__ = DummyLM.__init__ + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + + if hasattr(self, lm_attr): + delattr(self, lm_attr) + finally: + lm_model_cls.__init__ = origin_init + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + from .utils import AutoWeightsLoader + + origin_init_ = AutoWeightsLoader.__init__ + + def _new_init_(self, *args, **kwargs): + origin_init_(self, *args, **kwargs) + self.skip_prefixes = (self.skip_prefixes or []) + [f"{lm_attr}."] + + try: + AutoWeightsLoader.__init__ = _new_init_ + result = super().load_weights(weights) + finally: + AutoWeightsLoader.__init__ = origin_init_ + return result + + return ModelForMMEncoderOnly # type: ignore diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index cb99d57e8b8c7..67c65a44dcf7f 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -141,6 +141,14 @@ class SupportsMultiModal(Protocol): """ ... + @classmethod + def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]: + """ + Return the language model spec: + (language model class, language model attr) + """ + return None, None + @overload def embed_input_ids(self, input_ids: Tensor) -> Tensor: ... @@ -302,6 +310,10 @@ def supports_multimodal_encoder_tp_data(model: type[object] | object) -> bool: return getattr(model, "supports_encoder_tp_data", False) +def supports_mm_encoder_only(model: type[object] | object) -> bool: + return getattr(model, "is_mm_encoder_only_model", False) + + @overload def supports_multimodal_pruning( model: type[object], diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index b730ac0315893..0b44ff622f05b 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -34,7 +34,7 @@ import einops import torch import torch.nn as nn import torch.nn.functional as F -from transformers import BatchFeature +from transformers import BatchFeature, Qwen2ForCausalLM from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, @@ -1567,3 +1567,11 @@ class Qwen2_5_VLForConditionalGeneration( connector="visual.merger.", tower_model="visual.", ) + + @classmethod + def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]: + """ + Return the language model spec: + (language model class, language model attr) + """ + return Qwen2ForCausalLM, "language_model" diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 4838f68e06f70..fea73557f1e82 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -2090,3 +2090,11 @@ class Qwen3VLForConditionalGeneration( connector="visual.merger", tower_model="visual.", ) + + @classmethod + def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]: + """ + Return the language model spec: + (language model class, language model attr) + """ + return Qwen3LLMForCausalLM, "language_model" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 92822d829a881..0a17923e89989 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -66,6 +66,7 @@ from vllm.model_executor.models.interfaces import ( SupportsXDRoPE, is_mixture_of_experts, supports_eagle3, + supports_mm_encoder_only, supports_mrope, supports_multimodal_pruning, supports_transcription, @@ -4067,6 +4068,11 @@ class GPUModelRunner( remove_lora: If False, dummy LoRAs are not destroyed after the run activate_lora: If False, dummy_run is performed without LoRAs. """ + if supports_mm_encoder_only(self.model): + # The current dummy run only covers LM execution, so we can skip it. + # mm encoder dummy run may need to add in the future. + return torch.tensor([]), torch.tensor([]) + assert ( cudagraph_runtime_mode is None or cudagraph_runtime_mode.valid_runtime_modes() @@ -4344,6 +4350,11 @@ class GPUModelRunner( # The dummy hidden states may contain special values, # like `inf` or `nan`. # To avoid breaking the sampler, we use a random tensor here instead. + + if supports_mm_encoder_only(self.model): + # MM Encoder only model no need to run sampler. + return torch.tensor([]) + hidden_states = torch.rand_like(hidden_states) logits = self.model.compute_logits(hidden_states) @@ -4472,6 +4483,10 @@ class GPUModelRunner( self, hidden_states: torch.Tensor, ) -> PoolerOutput: + if supports_mm_encoder_only(self.model): + # MM Encoder only model not need to run pooler. + return torch.tensor([]) + # Find the task that has the largest output for subsequent steps supported_pooling_tasks = self.get_supported_pooling_tasks()