From 2566dca2a9e4e24c941845905e0ebad62441a1fa Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Thu, 23 Oct 2025 08:15:38 +0800
Subject: [PATCH] [Bugfix] Fix deepseek-ocr multi-image inference and add
 `merge_by_field_config=True` with tensor schema support (#27361)

Signed-off-by: Isotr0py
---
 .../vision_language_multi_image.py          |  50 +++++++-
 .../multimodal/processing/test_common.py    |   1 +
 vllm/model_executor/models/deepseek_ocr.py  | 113 +++++++++--------
 .../processors/deepseek_ocr.py              |  14 +--
 4 files changed, 112 insertions(+), 66 deletions(-)

diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index bd7e1d6b0466b..b9115121a9463 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -44,6 +44,7 @@ class ModelRequestData(NamedTuple):
     stop_token_ids: list[int] | None = None
     chat_template: str | None = None
     lora_requests: list[LoRARequest] | None = None
+    sampling_params: SamplingParams | None = None


 # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
@@ -201,6 +202,46 @@ def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
     )


+def load_deepseek_ocr(question: str, image_urls: list[str]) -> ModelRequestData:
+    from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor
+
+    model_name = "deepseek-ai/DeepSeek-OCR"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_num_seqs=2,
+        limit_mm_per_prompt={"image": len(image_urls)},
+        logits_processors=[NGramPerReqLogitsProcessor],
+    )
+
+    placeholder = "<image>\n" * len(image_urls)
+    prompt = placeholder + question
+
+    # The following sampling params config is taken from
+    # the official DeepSeek-OCR inference example.
+    # (IMPORTANT) Use the custom logits processor and avoid skipping
+    # special tokens for this model for optimal OCR performance.
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        max_tokens=8192,
+        # ngram logit processor args
+        extra_args=dict(
+            ngram_size=30,
+            window_size=90,
+            # whitelist: <td>, </td>
+            whitelist_token_ids={128821, 128822},
+        ),
+        skip_special_tokens=False,
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+        image_data=[fetch_image(url) for url in image_urls],
+        sampling_params=sampling_params,
+    )
+
+
 def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "google/gemma-3-4b-it"

@@ -1253,6 +1294,7 @@ model_example_map = {
     "bee": load_bee,
     "command_a_vision": load_command_a_vision,
     "deepseek_vl_v2": load_deepseek_vl2,
+    "deepseek_ocr": load_deepseek_ocr,
     "gemma3": load_gemma3,
     "h2ovl_chat": load_h2ovl,
     "hyperclovax_seed_vision": load_hyperclovax_seed_vision,
@@ -1325,8 +1367,12 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: int | None)
     engine_args = asdict(req_data.engine_args) | {"seed": seed}
     llm = LLM(**engine_args)

-    sampling_params = SamplingParams(
-        temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
+    sampling_params = (
+        SamplingParams(
+            temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
+        )
+        if req_data.sampling_params is None
+        else req_data.sampling_params
     )
     outputs = llm.chat(
         [
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index a7308244523e0..d0f730630d7f2 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -332,6 +332,7 @@ def _test_processing_correctness_one(
         "facebook/chameleon-7b",
         "CohereLabs/command-a-vision-07-2025",
         "deepseek-ai/deepseek-vl2-tiny",
+        "deepseek-ai/DeepSeek-OCR",
         "baidu/ERNIE-4.5-VL-28B-A3B-PT",
         "adept/fuyu-8b",
         "google/gemma-3-4b-it",
diff --git a/vllm/model_executor/models/deepseek_ocr.py b/vllm/model_executor/models/deepseek_ocr.py
index c9064dabc0ab3..fa24db456af4d 100644
--- a/vllm/model_executor/models/deepseek_ocr.py
+++ b/vllm/model_executor/models/deepseek_ocr.py
@@ -4,6 +4,7 @@

 import math
 from collections.abc import Iterable, Mapping, Sequence
+from typing import Annotated, Literal

 import torch
 import torch.nn as nn
@@ -53,6 +54,7 @@ from vllm.transformers_utils.processors.deepseek_ocr import (
     count_tiles,
 )
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from vllm.v1.sample.logits_processor import (
     AdapterLogitsProcessor,
     RequestLogitsProcessor,
@@ -65,6 +67,28 @@ from .deepseek_vl2 import MlpProjector

 _IMAGE_TOKEN = "<image>"

+
+class DeepseekOCRImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: Batch size
+        - n: Number of images
+        - p: Number of patches
+        - base_size: Base size of the processor
+        - image_size: Image size of the processor
+    """
+
+    type: Literal["pixel_values"]
+    data: Annotated[
+        torch.Tensor,
+        TensorShape("bn", 3, "base_size", "base_size"),
+    ]
+    images_crop: Annotated[
+        torch.Tensor,
+        TensorShape("bnp", 3, "image_size", "image_size", dynamic_dims={"bnp"}),
+    ]
+    images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
+

 class NoRepeatNGramLogitsProcessor:
     def __init__(
         self,
@@ -260,10 +284,15 @@ class DeepseekOCRMultiModalProcessor(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        images_spatial_crop = hf_inputs.get("images_spatial_crop", torch.empty((0, 2)))
+        is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
+        patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
             images_spatial_crop=MultiModalFieldConfig.batched("image"),
-            images_crop=MultiModalFieldConfig.batched("image"),
+            images_crop=MultiModalFieldConfig.flat_from_sizes(
+                "image", patches_per_image
+            ),
         )

     def _get_prompt_updates(
@@ -302,35 +331,6 @@ class DeepseekOCRMultiModalProcessor(
             )
         ]

-        # TODO(Isotr0py): Check if we still need this workaround for
-        # deepseek-ocr processor.
-        # def _cached_apply_hf_processor(
-        #     self,
-        #     prompt: str | list[int],
-        #     mm_data_items: MultiModalDataItems,
-        #     hf_processor_mm_kwargs: Mapping[str, object],
-        #     tokenization_kwargs: Mapping[str, object],
-        #     mm_uuids: MultiModalUUIDDict | None = None,
-        # ) -> tuple[list[int], MultiModalKwargs, bool]:
-        #     # The processor logic is different for len(images) <= 2 vs > 2
-        #     # Since the processing cache assumes that the processor output is
-        #     # invariant of how many images are passed per prompt, we only
-        #     # perform caching for the most common case
-        #     if mm_data_items.get_count("image", strict=False) > 2:
-        #         # This code path corresponds to the cache being disabled
-        #         return self._apply_hf_processor_main(
-        #             prompt=prompt,
-        #             mm_items=mm_data_items,
-        #             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-        #             enable_hf_prompt_update=True,
-        #         )
-
-        #     return super()._cached_apply_hf_processor(
-        #         prompt=prompt,
-        #         mm_data_items=mm_data_items,
-        #         hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-        #     )
-

 @MULTIMODAL_REGISTRY.register_processor(
     DeepseekOCRMultiModalProcessor,
@@ -338,6 +338,8 @@ class DeepseekOCRMultiModalProcessor(
     dummy_inputs=DeepseekOCRDummyInputsBuilder,
 )
 class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             # map prefix for language backbone
@@ -389,6 +391,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         self.vision_model = DeepCLIPVisionTransformer(
             config=clip_vision_config,
             quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "vision_model"),
         )

         self.projector = MlpProjector(self.projector_config)
@@ -426,7 +429,9 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             self.language_model.make_empty_intermediate_tensors
         )

-    def _parse_and_validate_image_input(self, **kwargs: object):
+    def _parse_and_validate_image_input(
+        self, **kwargs: object
+    ) -> DeepseekOCRImagePixelInputs | None:
         pixel_values = kwargs.pop("pixel_values", None)
         images_spatial_crop = kwargs.pop("images_spatial_crop", None)
         images_crop = kwargs.pop("images_crop", None)
@@ -435,23 +440,16 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             return None

         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
-                )
-
-            if not isinstance(images_spatial_crop, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image sizes. "
-                    f"Got type: {type(images_spatial_crop)}"
-                )
-
-            if not isinstance(images_crop, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of image crop. Got type: {type(images_crop)}"
Got type: {type(images_crop)}" - ) - - return [pixel_values, images_crop, images_spatial_crop] + base_size = self.vision_config.image_size + return DeepseekOCRImagePixelInputs( + type="pixel_values", + data=pixel_values, + images_crop=images_crop, + images_spatial_crop=images_spatial_crop, + resolve_bindings={ + "base_size": base_size, + }, + ) raise AssertionError("This line should be unreachable.") @@ -518,10 +516,13 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ) -> NestedTensors: images_in_this_batch = [] + is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1) + patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0) + images_crop = images_crop.split(patches_per_image.tolist()) for jdx in range(images_spatial_crop.size(0)): - patches = images_crop[jdx][0].to(torch.bfloat16) - image_ori = pixel_values[jdx] - crop_shape = images_spatial_crop[jdx][0] + patches = images_crop[jdx] + image_ori = pixel_values[[jdx]] + crop_shape = images_spatial_crop[jdx] global_features = self._encode_global_features(image_ori) local_features = self._encode_local_features(patches, crop_shape) @@ -540,10 +541,12 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): return images_in_this_batch - def _process_image_input(self, image_input) -> torch.Tensor: - pixel_values = image_input[0].to(torch.bfloat16) - images_crop = image_input[1] - images_spatial_crop = image_input[2].to(dtype=torch.long) + def _process_image_input( + self, image_input: DeepseekOCRImagePixelInputs + ) -> torch.Tensor: + pixel_values = image_input.data + images_crop = image_input.images_crop + images_spatial_crop = image_input.images_spatial_crop.to(dtype=torch.long) vision_features = self._pixel_values_to_embedding( pixel_values=pixel_values, diff --git a/vllm/transformers_utils/processors/deepseek_ocr.py b/vllm/transformers_utils/processors/deepseek_ocr.py index 99f2df3342e02..bb7aa0c174867 100644 --- a/vllm/transformers_utils/processors/deepseek_ocr.py +++ b/vllm/transformers_utils/processors/deepseek_ocr.py @@ -411,20 +411,16 @@ class DeepseekOCRProcessor(ProcessorMixin): images_seq_mask = images_seq_mask[:-1] if len(images_list) == 0: - pixel_values = torch.zeros((1, 3, self.base_size, self.base_size)) - images_spatial_crop = torch.zeros((1, 1), dtype=torch.long) - images_crop = torch.zeros( - (1, 3, self.image_size, self.image_size) - ).unsqueeze(0) + pixel_values = torch.zeros((0, 3, self.base_size, self.base_size)) + images_spatial_crop = torch.zeros((0, 2), dtype=torch.long) + images_crop = torch.zeros((0, 3, self.image_size, self.image_size)) else: pixel_values = torch.stack(images_list, dim=0) images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) if images_crop_list: - images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0) + images_crop = torch.stack(images_crop_list, dim=0) else: - images_crop = torch.zeros( - (1, 3, self.image_size, self.image_size) - ).unsqueeze(0) + images_crop = torch.zeros((0, 3, self.image_size, self.image_size)) input_ids = input_ids.unsqueeze(0)