[Bugfix] Fix deepseek-ocr multi-image inference and add merge_by_field_config=True with tensor schema support (#27361)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-10-23 08:15:38 +08:00 committed by GitHub
parent b4fda58a2d
commit 2566dca2a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 112 additions and 66 deletions

View File

@ -44,6 +44,7 @@ class ModelRequestData(NamedTuple):
stop_token_ids: list[int] | None = None
chat_template: str | None = None
lora_requests: list[LoRARequest] | None = None
sampling_params: SamplingParams | None = None
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
@ -201,6 +202,46 @@ def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
)
def load_deepseek_ocr(question: str, image_urls: list[str]) -> ModelRequestData:
from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor
model_name = "deepseek-ai/DeepSeek-OCR"
engine_args = EngineArgs(
model=model_name,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)},
logits_processors=[NGramPerReqLogitsProcessor],
)
placeholder = "<image>\n" * len(image_urls)
prompt = placeholder + question
# The following sampling params config is taken from
# the official Deepseek-OCR inference example.
# (IMPORTANT) Use the custom logits processor and avoid skipping
# special tokens for this model for the optimal OCR performance.
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=8192,
# ngram logit processor args
extra_args=dict(
ngram_size=30,
window_size=90,
# whitelist: <td>, </td>
whitelist_token_ids={128821, 128822},
),
skip_special_tokens=False,
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
sampling_params=sampling_params,
)
def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "google/gemma-3-4b-it"
@ -1253,6 +1294,7 @@ model_example_map = {
"bee": load_bee,
"command_a_vision": load_command_a_vision,
"deepseek_vl_v2": load_deepseek_vl2,
"deepseek_ocr": load_deepseek_ocr,
"gemma3": load_gemma3,
"h2ovl_chat": load_h2ovl,
"hyperclovax_seed_vision": load_hyperclovax_seed_vision,
@ -1325,8 +1367,12 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: int | None)
engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args)
sampling_params = SamplingParams(
temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
sampling_params = (
SamplingParams(
temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
)
if req_data.sampling_params is None
else req_data.sampling_params
)
outputs = llm.chat(
[

View File

@ -332,6 +332,7 @@ def _test_processing_correctness_one(
"facebook/chameleon-7b",
"CohereLabs/command-a-vision-07-2025",
"deepseek-ai/deepseek-vl2-tiny",
"deepseek-ai/DeepSeek-OCR",
"baidu/ERNIE-4.5-VL-28B-A3B-PT",
"adept/fuyu-8b",
"google/gemma-3-4b-it",

View File

@ -4,6 +4,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal
import torch
import torch.nn as nn
@ -53,6 +54,7 @@ from vllm.transformers_utils.processors.deepseek_ocr import (
count_tiles,
)
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.v1.sample.logits_processor import (
AdapterLogitsProcessor,
RequestLogitsProcessor,
@ -65,6 +67,28 @@ from .deepseek_vl2 import MlpProjector
_IMAGE_TOKEN = "<image>"
class DeepseekOCRImagePixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- n: Number of images
- p: Number of patches
- base_size: Base size of the processor
- image_size: Image size of the processor
"""
type: Literal["pixel_values"]
data: Annotated[
torch.Tensor,
TensorShape("bn", 3, "base_size", "base_size", dynamic_dims={"bnp"}),
]
images_crop: Annotated[
torch.Tensor,
TensorShape("bnp", 3, "image_size", "image_size", dynamic_dims={"bnp"}),
]
images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
class NoRepeatNGramLogitsProcessor:
def __init__(
self,
@ -260,10 +284,15 @@ class DeepseekOCRMultiModalProcessor(
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
images_spatial_crop = hf_inputs.get("images_spatial_crop", torch.empty((0, 2)))
is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
images_spatial_crop=MultiModalFieldConfig.batched("image"),
images_crop=MultiModalFieldConfig.batched("image"),
images_crop=MultiModalFieldConfig.flat_from_sizes(
"image", patches_per_image
),
)
def _get_prompt_updates(
@ -302,35 +331,6 @@ class DeepseekOCRMultiModalProcessor(
)
]
# TODO(Isotr0py): Check if we still need this workaround for
# deepseek-ocr processor.
# def _cached_apply_hf_processor(
# self,
# prompt: str | list[int],
# mm_data_items: MultiModalDataItems,
# hf_processor_mm_kwargs: Mapping[str, object],
# tokenization_kwargs: Mapping[str, object],
# mm_uuids: MultiModalUUIDDict | None = None,
# ) -> tuple[list[int], MultiModalKwargs, bool]:
# # The processor logic is different for len(images) <= 2 vs > 2
# # Since the processing cache assumes that the processor output is
# # invariant of how many images are passed per prompt, we only
# # perform caching for the most common case
# if mm_data_items.get_count("image", strict=False) > 2:
# # This code path corresponds to the cache being disabled
# return self._apply_hf_processor_main(
# prompt=prompt,
# mm_items=mm_data_items,
# hf_processor_mm_kwargs=hf_processor_mm_kwargs,
# enable_hf_prompt_update=True,
# )
# return super()._cached_apply_hf_processor(
# prompt=prompt,
# mm_data_items=mm_data_items,
# hf_processor_mm_kwargs=hf_processor_mm_kwargs,
# )
@MULTIMODAL_REGISTRY.register_processor(
DeepseekOCRMultiModalProcessor,
@ -338,6 +338,8 @@ class DeepseekOCRMultiModalProcessor(
dummy_inputs=DeepseekOCRDummyInputsBuilder,
)
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# map prefix for language backbone
@ -389,6 +391,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.vision_model = DeepCLIPVisionTransformer(
config=clip_vision_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.projector = MlpProjector(self.projector_config)
@ -426,7 +429,9 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model.make_empty_intermediate_tensors
)
def _parse_and_validate_image_input(self, **kwargs: object):
def _parse_and_validate_image_input(
self, **kwargs: object
) -> DeepseekOCRImagePixelInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
images_spatial_crop = kwargs.pop("images_spatial_crop", None)
images_crop = kwargs.pop("images_crop", None)
@ -435,23 +440,16 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)
if not isinstance(images_spatial_crop, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image sizes. "
f"Got type: {type(images_spatial_crop)}"
)
if not isinstance(images_crop, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of image crop. Got type: {type(images_crop)}"
)
return [pixel_values, images_crop, images_spatial_crop]
base_size = self.vision_config.image_size
return DeepseekOCRImagePixelInputs(
type="pixel_values",
data=pixel_values,
images_crop=images_crop,
images_spatial_crop=images_spatial_crop,
resolve_bindings={
"base_size": base_size,
},
)
raise AssertionError("This line should be unreachable.")
@ -518,10 +516,13 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
) -> NestedTensors:
images_in_this_batch = []
is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
images_crop = images_crop.split(patches_per_image.tolist())
for jdx in range(images_spatial_crop.size(0)):
patches = images_crop[jdx][0].to(torch.bfloat16)
image_ori = pixel_values[jdx]
crop_shape = images_spatial_crop[jdx][0]
patches = images_crop[jdx]
image_ori = pixel_values[[jdx]]
crop_shape = images_spatial_crop[jdx]
global_features = self._encode_global_features(image_ori)
local_features = self._encode_local_features(patches, crop_shape)
@ -540,10 +541,12 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return images_in_this_batch
def _process_image_input(self, image_input) -> torch.Tensor:
pixel_values = image_input[0].to(torch.bfloat16)
images_crop = image_input[1]
images_spatial_crop = image_input[2].to(dtype=torch.long)
def _process_image_input(
self, image_input: DeepseekOCRImagePixelInputs
) -> torch.Tensor:
pixel_values = image_input.data
images_crop = image_input.images_crop
images_spatial_crop = image_input.images_spatial_crop.to(dtype=torch.long)
vision_features = self._pixel_values_to_embedding(
pixel_values=pixel_values,

View File

@ -411,20 +411,16 @@ class DeepseekOCRProcessor(ProcessorMixin):
images_seq_mask = images_seq_mask[:-1]
if len(images_list) == 0:
pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
images_crop = torch.zeros(
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
pixel_values = torch.zeros((0, 3, self.base_size, self.base_size))
images_spatial_crop = torch.zeros((0, 2), dtype=torch.long)
images_crop = torch.zeros((0, 3, self.image_size, self.image_size))
else:
pixel_values = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
if images_crop_list:
images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
images_crop = torch.stack(images_crop_list, dim=0)
else:
images_crop = torch.zeros(
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
images_crop = torch.zeros((0, 3, self.image_size, self.image_size))
input_ids = input_ids.unsqueeze(0)