From 1405f0c7bab17b10682d2de337bd544b3e9f2f92 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Wed, 1 Oct 2025 16:31:03 +0800
Subject: [PATCH] [Misc] Factor out common `_apply_feature_select_strategy`
 (#26003)

Signed-off-by: DarkLight1337
---
 vllm/model_executor/models/llava.py      | 19 +++-----------
 vllm/model_executor/models/llava_next.py |  5 ++--
 vllm/model_executor/models/tarsier.py    | 23 +++++------------
 vllm/model_executor/models/vision.py     | 32 +++++++++++++++++++++---
 4 files changed, 40 insertions(+), 39 deletions(-)

diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index d823e5cb58d26..78c413b770516 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -41,7 +41,7 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix)
-from .vision import get_vision_encoder_info
+from .vision import get_num_selected_vision_tokens, get_vision_encoder_info
 
 
 class LlavaImagePixelInputs(TensorSchema):
@@ -147,19 +147,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
 
-    def _apply_feature_select_strategy(
-        self,
-        strategy: str,
-        encoder_num_image_tokens: int,
-    ) -> int:
-        if strategy == "default":
-            return encoder_num_image_tokens - 1
-        if strategy == "full":
-            return encoder_num_image_tokens
-
-        msg = f"Unexpected feature select strategy: {strategy!r}"
-        raise NotImplementedError(msg)
-
     def get_num_image_tokens(
         self,
         *,
@@ -169,12 +156,12 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
         hf_config = self.get_hf_config()
         vision_encoder_info = self.get_vision_encoder_info()
 
-        return self._apply_feature_select_strategy(
-            hf_config.vision_feature_select_strategy,
+        return get_num_selected_vision_tokens(
             vision_encoder_info.get_num_image_tokens(
                 image_width=image_width,
                 image_height=image_height,
             ),
+            hf_config.vision_feature_select_strategy,
         )
 
     def get_image_size_with_most_features(self) -> ImageSize:
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index 3f7e39c020617..70fd0b2e5efbb 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -27,6 +27,7 @@ from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix)
+from .vision import get_num_selected_vision_tokens
 
 
 class LlavaNextImagePixelInputs(TensorSchema):
@@ -95,12 +96,12 @@ class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
         hf_config = self.get_hf_config()
         vision_encoder_info = self.get_vision_encoder_info()
 
-        base_feature_size = self._apply_feature_select_strategy(
-            hf_config.vision_feature_select_strategy,
+        base_feature_size = get_num_selected_vision_tokens(
             vision_encoder_info.get_num_image_tokens(
                 image_width=image_width,
                 image_height=image_height,
             ),
+            hf_config.vision_feature_select_strategy,
         )
 
         num_patch_height, num_patch_width = get_anyres_image_grid_shape(
diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py
index ed02fe2c389f4..8759c4ea4a64c 100644
--- a/vllm/model_executor/models/tarsier.py
+++ b/vllm/model_executor/models/tarsier.py
@@ -40,7 +40,8 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix)
-from .vision import VisionEncoderInfo, get_vision_encoder_info
+from .vision import (VisionEncoderInfo, get_num_selected_vision_tokens,
+                     get_vision_encoder_info)
 
 
 class TarsierImagePixelInputs(TensorSchema):
@@ -201,18 +202,6 @@ class TarsierProcessingInfo(BaseProcessingInfo):
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
 
-    def _apply_feature_select_strategy(
-        self,
-        strategy: str,
-        encoder_num_image_tokens: int,
-    ) -> int:
-        if strategy == "default":
-            return encoder_num_image_tokens - 1
-        if strategy == "full":
-            return encoder_num_image_tokens
-        msg = f"Unexpected feature select strategy: {strategy!r}"
-        raise NotImplementedError(msg)
-
     def get_num_image_tokens(
         self,
         *,
@@ -221,21 +210,21 @@
     ) -> int:
         hf_config = self.get_hf_config()
         vision_encoder_info = self.get_vision_encoder_info()
-        num_projected_patches = self._apply_feature_select_strategy(
-            hf_config.vision_feature_select_strategy,
+        num_projected_patches = get_num_selected_vision_tokens(
             vision_encoder_info.get_num_image_tokens(
                 image_width=image_width,
                 image_height=image_height,
             ),
+            hf_config.vision_feature_select_strategy,
         )
         if num_projected_patches <= 0:
             default_size = self.get_image_size_with_most_features()
-            num_projected_patches_default = self._apply_feature_select_strategy(
-                hf_config.vision_feature_select_strategy,
+            num_projected_patches_default = get_num_selected_vision_tokens(
                 vision_encoder_info.get_num_image_tokens(
                     image_width=default_size.width,
                     image_height=default_size.height,
                 ),
+                hf_config.vision_feature_select_strategy,
             )
             if num_projected_patches_default <= 0:
                 raise ValueError(
diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
index e077691fcec21..3d16d71e1764a 100644
--- a/vllm/model_executor/models/vision.py
+++ b/vllm/model_executor/models/vision.py
@@ -9,7 +9,6 @@ from typing import (Callable, Final, Generic, Literal, Optional, Protocol,
 
 import torch
 from transformers import PretrainedConfig
-from typing_extensions import assert_never
 
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -22,9 +21,13 @@ logger = init_logger(__name__)
 _C = TypeVar("_C", bound=PretrainedConfig)
 
 
+class _RootConfig(Protocol[_C]):
+    vision_config: _C
+
+
 class VisionEncoderInfo(ABC, Generic[_C]):
 
-    def __init__(self, hf_config: _C) -> None:
+    def __init__(self, hf_config: _RootConfig[_C]) -> None:
         super().__init__()
         self.hf_config = hf_config
 
@@ -95,7 +98,7 @@ VisionFeatureSelectStrategy = Union[
 
 
 def _get_vision_feature_selector(
-    strategy: VisionFeatureSelectStrategy,
+    strategy: Union[VisionFeatureSelectStrategy, str],
 ) -> Callable[[torch.Tensor], torch.Tensor]:
     if callable(strategy):
         return strategy
@@ -111,7 +114,28 @@
     if strategy == "full":
         return lambda feats: feats
 
-    assert_never(strategy)
+    raise ValueError(f"Unexpected feature select strategy: {strategy!r}")
+
+
+def get_num_selected_vision_tokens(
+    num_vision_tokens: int,
+    strategy: Union[VisionFeatureSelectStrategy, str],
+) -> int:
+    if callable(strategy):
+        dummy_features = torch.empty(1, num_vision_tokens, 64)  # [B, L, D]
+        dummy_selected_features = strategy(dummy_features)
+        return dummy_selected_features.shape[1]
+
+    if strategy == "class":
+        return 1
+
+    if strategy == "default":
+        return num_vision_tokens - 1
+
+    if strategy == "full":
+        return num_vision_tokens
+
+    raise ValueError(f"Unexpected feature select strategy: {strategy!r}")
 
 
 def resolve_visual_encoder_outputs(
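
Usage note (not part of the patch): a minimal sketch of how the new helper
behaves, assuming a vLLM tree with this change applied; `keep_even_tokens`
is a hypothetical callable strategy used only for illustration.

    import torch

    from vllm.model_executor.models.vision import (
        get_num_selected_vision_tokens)

    # String strategies mirror HF's `vision_feature_select_strategy`:
    # "default" drops the leading CLS token, "full" keeps every token,
    # and "class" keeps only the CLS token.
    assert get_num_selected_vision_tokens(577, "default") == 576
    assert get_num_selected_vision_tokens(577, "full") == 577
    assert get_num_selected_vision_tokens(577, "class") == 1

    # A callable strategy is probed with a dummy [B, L, D] tensor; the
    # selected token count is read off the output's sequence dimension.
    def keep_even_tokens(feats: torch.Tensor) -> torch.Tensor:
        return feats[:, ::2]  # keep every other vision token

    assert get_num_selected_vision_tokens(10, keep_even_tokens) == 5

For the "default" and "full" cases this matches the per-model
`_apply_feature_select_strategy` helpers removed above, while also covering
"class" and callable strategies in one shared place.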