mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-27 18:17:28 +08:00
[BugFix] skip language model in Encoder (#30242)
Signed-off-by: dengyunyang <584797741@qq.com>
This commit is contained in:
parent
2cf91c2ea4
commit
8f8f469b1b
@ -38,6 +38,8 @@ Encoder engines should be launched with the following flags:
|
||||
|
||||
- `--max-num-batched-tokens=<large value>` **(default: 2048)** – This flag controls the token scheduling budget per decoding step and is irrelevant to encoder-only instances. **Set it to a very high value (effectively unlimited) to bypass scheduler limitations.** The actual token budget is managed by the encoder cache manager.
|
||||
|
||||
- `--convert "mm_encoder_only"` **(Optional)** - The language model is skipped during initialization to reduce device memory usage. **Models using this option must implement the `get_language_model_spec` interface.**
|
||||
|
||||
## Local media inputs
|
||||
|
||||
To support local image inputs (from your ```MEDIA_PATH``` directory), add the following flag to the encoder instance:
|
||||
|
||||
@ -71,7 +71,7 @@ else:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
RunnerOption = Literal["auto", RunnerType]
|
||||
ConvertType = Literal["none", "embed", "classify", "reward"]
|
||||
ConvertType = Literal["none", "embed", "classify", "reward", "mm_encoder_only"]
|
||||
ConvertOption = Literal["auto", ConvertType]
|
||||
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"]
|
||||
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
|
||||
|
||||
@ -189,7 +189,9 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
|
||||
)
|
||||
|
||||
convert_type = model_config.convert_type
|
||||
if convert_type != "none" and supports_multimodal(model_cls):
|
||||
if convert_type not in ["none", "mm_encoder_only"] and supports_multimodal(
|
||||
model_cls
|
||||
):
|
||||
logger.debug_once("Detected conversion of Multi Modal model.")
|
||||
converted = try_create_mm_pooling_model_cls(model_cls)
|
||||
if converted is not None:
|
||||
@ -200,6 +202,11 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
|
||||
|
||||
if convert_type == "none":
|
||||
pass
|
||||
elif convert_type == "mm_encoder_only":
|
||||
logger.debug_once("Converting to mm encoder only model.")
|
||||
from vllm.model_executor.models.adapters import as_mm_encoder_only_model
|
||||
|
||||
model_cls = as_mm_encoder_only_model(model_cls)
|
||||
elif convert_type == "embed":
|
||||
logger.debug_once("Converting to embedding model.")
|
||||
model_cls = as_embedding_model(model_cls)
|
||||
|
||||
@ -520,3 +520,64 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
method = getattr(text_config, "method", None)
|
||||
assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
|
||||
return SEQ_CLS_LOAD_METHODS[method](model, weights)
|
||||
|
||||
|
||||
def as_mm_encoder_only_model(cls: _T) -> _T:
|
||||
"""
|
||||
Subclass an existing vLLM vl model to support mm encoder only for
|
||||
EPD encoder instances.
|
||||
"""
|
||||
if not hasattr(cls, "embed_multimodal"):
|
||||
# Submodel case: return the original class.
|
||||
return cls
|
||||
|
||||
if not hasattr(cls, "get_language_model_spec"):
|
||||
raise TypeError(f"{cls} need to implement `get_language_model_spec` method.")
|
||||
|
||||
lm_model_cls, lm_attr = cls.get_language_model_spec()
|
||||
|
||||
if lm_model_cls is None or lm_attr is None:
|
||||
raise TypeError(
|
||||
f"{cls}.get_language_model_spec() must return (lm_model_cls, lm_attr)"
|
||||
)
|
||||
|
||||
class DummyLM(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.make_empty_intermediate_tensors = None
|
||||
|
||||
class ModelForMMEncoderOnly(cls):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: "VllmConfig",
|
||||
prefix: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.is_mm_encoder_only_model = True
|
||||
origin_init = lm_model_cls.__init__
|
||||
try:
|
||||
lm_model_cls.__init__ = DummyLM.__init__
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
|
||||
if hasattr(self, lm_attr):
|
||||
delattr(self, lm_attr)
|
||||
finally:
|
||||
lm_model_cls.__init__ = origin_init
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
from .utils import AutoWeightsLoader
|
||||
|
||||
origin_init_ = AutoWeightsLoader.__init__
|
||||
|
||||
def _new_init_(self, *args, **kwargs):
|
||||
origin_init_(self, *args, **kwargs)
|
||||
self.skip_prefixes = (self.skip_prefixes or []) + [f"{lm_attr}."]
|
||||
|
||||
try:
|
||||
AutoWeightsLoader.__init__ = _new_init_
|
||||
result = super().load_weights(weights)
|
||||
finally:
|
||||
AutoWeightsLoader.__init__ = origin_init_
|
||||
return result
|
||||
|
||||
return ModelForMMEncoderOnly # type: ignore
|
||||
|
||||
@ -141,6 +141,14 @@ class SupportsMultiModal(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
|
||||
"""
|
||||
Return the language model spec:
|
||||
(language model class, language model attr)
|
||||
"""
|
||||
return None, None
|
||||
|
||||
@overload
|
||||
def embed_input_ids(self, input_ids: Tensor) -> Tensor: ...
|
||||
|
||||
@ -302,6 +310,10 @@ def supports_multimodal_encoder_tp_data(model: type[object] | object) -> bool:
|
||||
return getattr(model, "supports_encoder_tp_data", False)
|
||||
|
||||
|
||||
def supports_mm_encoder_only(model: type[object] | object) -> bool:
|
||||
return getattr(model, "is_mm_encoder_only_model", False)
|
||||
|
||||
|
||||
@overload
|
||||
def supports_multimodal_pruning(
|
||||
model: type[object],
|
||||
|
||||
@ -34,7 +34,7 @@ import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import BatchFeature
|
||||
from transformers import BatchFeature, Qwen2ForCausalLM
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig,
|
||||
@ -1567,3 +1567,11 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
connector="visual.merger.",
|
||||
tower_model="visual.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
|
||||
"""
|
||||
Return the language model spec:
|
||||
(language model class, language model attr)
|
||||
"""
|
||||
return Qwen2ForCausalLM, "language_model"
|
||||
|
||||
@ -2090,3 +2090,11 @@ class Qwen3VLForConditionalGeneration(
|
||||
connector="visual.merger",
|
||||
tower_model="visual.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
|
||||
"""
|
||||
Return the language model spec:
|
||||
(language model class, language model attr)
|
||||
"""
|
||||
return Qwen3LLMForCausalLM, "language_model"
|
||||
|
||||
@ -66,6 +66,7 @@ from vllm.model_executor.models.interfaces import (
|
||||
SupportsXDRoPE,
|
||||
is_mixture_of_experts,
|
||||
supports_eagle3,
|
||||
supports_mm_encoder_only,
|
||||
supports_mrope,
|
||||
supports_multimodal_pruning,
|
||||
supports_transcription,
|
||||
@ -4067,6 +4068,11 @@ class GPUModelRunner(
|
||||
remove_lora: If False, dummy LoRAs are not destroyed after the run
|
||||
activate_lora: If False, dummy_run is performed without LoRAs.
|
||||
"""
|
||||
if supports_mm_encoder_only(self.model):
|
||||
# The current dummy run only covers LM execution, so we can skip it.
|
||||
# mm encoder dummy run may need to add in the future.
|
||||
return torch.tensor([]), torch.tensor([])
|
||||
|
||||
assert (
|
||||
cudagraph_runtime_mode is None
|
||||
or cudagraph_runtime_mode.valid_runtime_modes()
|
||||
@ -4344,6 +4350,11 @@ class GPUModelRunner(
|
||||
# The dummy hidden states may contain special values,
|
||||
# like `inf` or `nan`.
|
||||
# To avoid breaking the sampler, we use a random tensor here instead.
|
||||
|
||||
if supports_mm_encoder_only(self.model):
|
||||
# MM Encoder only model no need to run sampler.
|
||||
return torch.tensor([])
|
||||
|
||||
hidden_states = torch.rand_like(hidden_states)
|
||||
|
||||
logits = self.model.compute_logits(hidden_states)
|
||||
@ -4472,6 +4483,10 @@ class GPUModelRunner(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> PoolerOutput:
|
||||
if supports_mm_encoder_only(self.model):
|
||||
# MM Encoder only model not need to run pooler.
|
||||
return torch.tensor([])
|
||||
|
||||
# Find the task that has the largest output for subsequent steps
|
||||
supported_pooling_tasks = self.get_supported_pooling_tasks()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user