[Bugfix] Fix deepseek-ocr multi-image inference and add merge_by_field_config=True with tensor schema support (#27361)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-10-23 08:15:38 +08:00 committed by GitHub
parent b4fda58a2d
commit 2566dca2a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 112 additions and 66 deletions

View File

@ -44,6 +44,7 @@ class ModelRequestData(NamedTuple):
stop_token_ids: list[int] | None = None stop_token_ids: list[int] | None = None
chat_template: str | None = None chat_template: str | None = None
lora_requests: list[LoRARequest] | None = None lora_requests: list[LoRARequest] | None = None
sampling_params: SamplingParams | None = None
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
@ -201,6 +202,46 @@ def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
) )
def load_deepseek_ocr(question: str, image_urls: list[str]) -> ModelRequestData:
from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor
model_name = "deepseek-ai/DeepSeek-OCR"
engine_args = EngineArgs(
model=model_name,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)},
logits_processors=[NGramPerReqLogitsProcessor],
)
placeholder = "<image>\n" * len(image_urls)
prompt = placeholder + question
# The following sampling params config is taken from
# the official Deepseek-OCR inference example.
# (IMPORTANT) Use the custom logits processor and avoid skipping
# special tokens for this model for the optimal OCR performance.
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=8192,
# ngram logit processor args
extra_args=dict(
ngram_size=30,
window_size=90,
# whitelist: <td>, </td>
whitelist_token_ids={128821, 128822},
),
skip_special_tokens=False,
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
sampling_params=sampling_params,
)
def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData: def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "google/gemma-3-4b-it" model_name = "google/gemma-3-4b-it"
@ -1253,6 +1294,7 @@ model_example_map = {
"bee": load_bee, "bee": load_bee,
"command_a_vision": load_command_a_vision, "command_a_vision": load_command_a_vision,
"deepseek_vl_v2": load_deepseek_vl2, "deepseek_vl_v2": load_deepseek_vl2,
"deepseek_ocr": load_deepseek_ocr,
"gemma3": load_gemma3, "gemma3": load_gemma3,
"h2ovl_chat": load_h2ovl, "h2ovl_chat": load_h2ovl,
"hyperclovax_seed_vision": load_hyperclovax_seed_vision, "hyperclovax_seed_vision": load_hyperclovax_seed_vision,
@ -1325,8 +1367,12 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: int | None)
engine_args = asdict(req_data.engine_args) | {"seed": seed} engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args) llm = LLM(**engine_args)
sampling_params = SamplingParams( sampling_params = (
temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids SamplingParams(
temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
)
if req_data.sampling_params is None
else req_data.sampling_params
) )
outputs = llm.chat( outputs = llm.chat(
[ [

View File

@ -332,6 +332,7 @@ def _test_processing_correctness_one(
"facebook/chameleon-7b", "facebook/chameleon-7b",
"CohereLabs/command-a-vision-07-2025", "CohereLabs/command-a-vision-07-2025",
"deepseek-ai/deepseek-vl2-tiny", "deepseek-ai/deepseek-vl2-tiny",
"deepseek-ai/DeepSeek-OCR",
"baidu/ERNIE-4.5-VL-28B-A3B-PT", "baidu/ERNIE-4.5-VL-28B-A3B-PT",
"adept/fuyu-8b", "adept/fuyu-8b",
"google/gemma-3-4b-it", "google/gemma-3-4b-it",

View File

@ -4,6 +4,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -53,6 +54,7 @@ from vllm.transformers_utils.processors.deepseek_ocr import (
count_tiles, count_tiles,
) )
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.v1.sample.logits_processor import ( from vllm.v1.sample.logits_processor import (
AdapterLogitsProcessor, AdapterLogitsProcessor,
RequestLogitsProcessor, RequestLogitsProcessor,
@ -65,6 +67,28 @@ from .deepseek_vl2 import MlpProjector
_IMAGE_TOKEN = "<image>" _IMAGE_TOKEN = "<image>"
class DeepseekOCRImagePixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- n: Number of images
- p: Number of patches
- base_size: Base size of the processor
- image_size: Image size of the processor
"""
type: Literal["pixel_values"]
data: Annotated[
torch.Tensor,
TensorShape("bn", 3, "base_size", "base_size", dynamic_dims={"bnp"}),
]
images_crop: Annotated[
torch.Tensor,
TensorShape("bnp", 3, "image_size", "image_size", dynamic_dims={"bnp"}),
]
images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
class NoRepeatNGramLogitsProcessor: class NoRepeatNGramLogitsProcessor:
def __init__( def __init__(
self, self,
@ -260,10 +284,15 @@ class DeepseekOCRMultiModalProcessor(
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
images_spatial_crop = hf_inputs.get("images_spatial_crop", torch.empty((0, 2)))
is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
return dict( return dict(
pixel_values=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.batched("image"),
images_spatial_crop=MultiModalFieldConfig.batched("image"), images_spatial_crop=MultiModalFieldConfig.batched("image"),
images_crop=MultiModalFieldConfig.batched("image"), images_crop=MultiModalFieldConfig.flat_from_sizes(
"image", patches_per_image
),
) )
def _get_prompt_updates( def _get_prompt_updates(
@ -302,35 +331,6 @@ class DeepseekOCRMultiModalProcessor(
) )
] ]
# TODO(Isotr0py): Check if we still need this workaround for
# deepseek-ocr processor.
# def _cached_apply_hf_processor(
# self,
# prompt: str | list[int],
# mm_data_items: MultiModalDataItems,
# hf_processor_mm_kwargs: Mapping[str, object],
# tokenization_kwargs: Mapping[str, object],
# mm_uuids: MultiModalUUIDDict | None = None,
# ) -> tuple[list[int], MultiModalKwargs, bool]:
# # The processor logic is different for len(images) <= 2 vs > 2
# # Since the processing cache assumes that the processor output is
# # invariant of how many images are passed per prompt, we only
# # perform caching for the most common case
# if mm_data_items.get_count("image", strict=False) > 2:
# # This code path corresponds to the cache being disabled
# return self._apply_hf_processor_main(
# prompt=prompt,
# mm_items=mm_data_items,
# hf_processor_mm_kwargs=hf_processor_mm_kwargs,
# enable_hf_prompt_update=True,
# )
# return super()._cached_apply_hf_processor(
# prompt=prompt,
# mm_data_items=mm_data_items,
# hf_processor_mm_kwargs=hf_processor_mm_kwargs,
# )
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
DeepseekOCRMultiModalProcessor, DeepseekOCRMultiModalProcessor,
@ -338,6 +338,8 @@ class DeepseekOCRMultiModalProcessor(
dummy_inputs=DeepseekOCRDummyInputsBuilder, dummy_inputs=DeepseekOCRDummyInputsBuilder,
) )
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
# map prefix for language backbone # map prefix for language backbone
@ -389,6 +391,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.vision_model = DeepCLIPVisionTransformer( self.vision_model = DeepCLIPVisionTransformer(
config=clip_vision_config, config=clip_vision_config,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_model"),
) )
self.projector = MlpProjector(self.projector_config) self.projector = MlpProjector(self.projector_config)
@ -426,7 +429,9 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
) )
def _parse_and_validate_image_input(self, **kwargs: object): def _parse_and_validate_image_input(
self, **kwargs: object
) -> DeepseekOCRImagePixelInputs | None:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
images_spatial_crop = kwargs.pop("images_spatial_crop", None) images_spatial_crop = kwargs.pop("images_spatial_crop", None)
images_crop = kwargs.pop("images_crop", None) images_crop = kwargs.pop("images_crop", None)
@ -435,23 +440,16 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return None return None
if pixel_values is not None: if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)): base_size = self.vision_config.image_size
raise ValueError( return DeepseekOCRImagePixelInputs(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}" type="pixel_values",
) data=pixel_values,
images_crop=images_crop,
if not isinstance(images_spatial_crop, (torch.Tensor, list)): images_spatial_crop=images_spatial_crop,
raise ValueError( resolve_bindings={
"Incorrect type of image sizes. " "base_size": base_size,
f"Got type: {type(images_spatial_crop)}" },
) )
if not isinstance(images_crop, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of image crop. Got type: {type(images_crop)}"
)
return [pixel_values, images_crop, images_spatial_crop]
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
@ -518,10 +516,13 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
) -> NestedTensors: ) -> NestedTensors:
images_in_this_batch = [] images_in_this_batch = []
is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
images_crop = images_crop.split(patches_per_image.tolist())
for jdx in range(images_spatial_crop.size(0)): for jdx in range(images_spatial_crop.size(0)):
patches = images_crop[jdx][0].to(torch.bfloat16) patches = images_crop[jdx]
image_ori = pixel_values[jdx] image_ori = pixel_values[[jdx]]
crop_shape = images_spatial_crop[jdx][0] crop_shape = images_spatial_crop[jdx]
global_features = self._encode_global_features(image_ori) global_features = self._encode_global_features(image_ori)
local_features = self._encode_local_features(patches, crop_shape) local_features = self._encode_local_features(patches, crop_shape)
@ -540,10 +541,12 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return images_in_this_batch return images_in_this_batch
def _process_image_input(self, image_input) -> torch.Tensor: def _process_image_input(
pixel_values = image_input[0].to(torch.bfloat16) self, image_input: DeepseekOCRImagePixelInputs
images_crop = image_input[1] ) -> torch.Tensor:
images_spatial_crop = image_input[2].to(dtype=torch.long) pixel_values = image_input.data
images_crop = image_input.images_crop
images_spatial_crop = image_input.images_spatial_crop.to(dtype=torch.long)
vision_features = self._pixel_values_to_embedding( vision_features = self._pixel_values_to_embedding(
pixel_values=pixel_values, pixel_values=pixel_values,

View File

@ -411,20 +411,16 @@ class DeepseekOCRProcessor(ProcessorMixin):
images_seq_mask = images_seq_mask[:-1] images_seq_mask = images_seq_mask[:-1]
if len(images_list) == 0: if len(images_list) == 0:
pixel_values = torch.zeros((1, 3, self.base_size, self.base_size)) pixel_values = torch.zeros((0, 3, self.base_size, self.base_size))
images_spatial_crop = torch.zeros((1, 1), dtype=torch.long) images_spatial_crop = torch.zeros((0, 2), dtype=torch.long)
images_crop = torch.zeros( images_crop = torch.zeros((0, 3, self.image_size, self.image_size))
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
else: else:
pixel_values = torch.stack(images_list, dim=0) pixel_values = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
if images_crop_list: if images_crop_list:
images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0) images_crop = torch.stack(images_crop_list, dim=0)
else: else:
images_crop = torch.zeros( images_crop = torch.zeros((0, 3, self.image_size, self.image_size))
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
input_ids = input_ids.unsqueeze(0) input_ids = input_ids.unsqueeze(0)