[Bugfix] Fix deepseek-vl2 inference with more than 2 images (#13818)

This commit is contained in:
Isotr0py 2025-02-25 22:03:02 +08:00 committed by GitHub
parent fa82074167
commit 6ff518626c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 10 deletions

View File

@ -25,7 +25,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, ProcessingCache,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
@ -138,18 +139,24 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_num_image_tokens(self, *, image_width: int,
image_height: int) -> int:
def get_num_image_tokens(self,
*,
image_width: int,
image_height: int,
cropping: bool = True) -> int:
hf_processor = self.get_hf_processor()
image_size = hf_processor.image_size
patch_size = hf_processor.patch_size
downsample_ratio = hf_processor.downsample_ratio
best_width, best_height = hf_processor.select_best_resolution(
(image_width, image_height))
if cropping:
best_width, best_height = hf_processor.select_best_resolution(
(image_width, image_height))
num_width_tiles, num_height_tiles = (best_width // image_size,
best_height // image_size)
else:
num_width_tiles = num_height_tiles = 1
num_width_tiles, num_height_tiles = (best_width // image_size,
best_height // image_size)
h = w = math.ceil((image_size // patch_size) / downsample_ratio)
global_views_tokens = h * (w + 1)
@ -169,10 +176,12 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
num_images = mm_counts.get("image", 0)
max_image_size = self.get_image_size_with_most_features()
max_image_tokens = self.get_num_image_tokens(
image_height=max_image_size.height,
image_width=max_image_size.width)
image_width=max_image_size.width,
cropping=num_images <= 2)
return {"image": max_image_tokens}
@ -207,6 +216,30 @@ class DeepseekVL2DummyInputsBuilder(
class DeepseekVL2MultiModalProcessor(
BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]):
def __init__(
self,
info: DeepseekVL2ProcessingInfo,
dummy_inputs: "BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]",
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__(
info,
dummy_inputs,
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
if self.cache is not None and mm_limit["image"] > 2:
# The processor output depends on the number of images passed,
# making it incompatible with processing cache which is supposed
# to be invariant of how many images are passed per prompt
self.cache = None
logger.warning_once(
f"{type(self).__name__} does not support processing cache with "
"image limit larger than 2.")
def _call_hf_processor(
self,
prompt: str,
@ -271,6 +304,7 @@ class DeepseekVL2MultiModalProcessor(
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
cropping=len(images) <= 2,
)
return [image_token_id] * num_image_tokens

View File

@ -477,13 +477,15 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
enable_sanity_checks=enable_sanity_checks,
)
if self.cache is not None:
mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
if self.cache is not None and mm_limit["image"] >= 2:
# The processor output depends on the number of images passed,
# making it incompatible with processing cache which is supposed
# to be invariant of how many images are passed per prompt
self.cache = None
logger.warning_once(
f"{type(self).__name__} does not support processing cache.")
f"{type(self).__name__} does not support processing cache with "
"multi-image support enabled.")
def _get_prompt_replacements(
self,