[Bugfix] Fix deepseek-ocr multi-image inference and add merge_by_field_config=True with tensor schema support (#27361)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent b4fda58a2d
commit 2566dca2a9
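A minimal offline driver for the multi-image path this commit fixes — a sketch, not part of the commit, mirroring the load_deepseek_ocr example added below. The image file names and the "Free OCR." instruction are illustrative assumptions; engine and sampling settings are taken from the example.

# Hedged sketch: run DeepSeek-OCR on two images in one prompt, mirroring the
# example added in this commit. File names and the instruction are assumptions.
from PIL import Image

from vllm import LLM, SamplingParams
from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor

images = [Image.open(p).convert("RGB") for p in ["page1.png", "page2.png"]]

llm = LLM(
    model="deepseek-ai/DeepSeek-OCR",
    max_num_seqs=2,
    limit_mm_per_prompt={"image": len(images)},
    logits_processors=[NGramPerReqLogitsProcessor],
)

# One "<image>\n" placeholder per image, then the instruction.
prompt = "<image>\n" * len(images) + "Free OCR."

sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=8192,
    # Per-request args consumed by the n-gram logits processor.
    extra_args=dict(
        ngram_size=30,
        window_size=90,
        whitelist_token_ids={128821, 128822},  # <td>, </td>
    ),
    skip_special_tokens=False,
)

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": images}},
    sampling_params,
)
print(outputs[0].outputs[0].text)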
@@ -44,6 +44,7 @@ class ModelRequestData(NamedTuple):
     stop_token_ids: list[int] | None = None
     chat_template: str | None = None
     lora_requests: list[LoRARequest] | None = None
+    sampling_params: SamplingParams | None = None


 # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
@@ -201,6 +202,46 @@ def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
     )


+def load_deepseek_ocr(question: str, image_urls: list[str]) -> ModelRequestData:
+    from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor
+
+    model_name = "deepseek-ai/DeepSeek-OCR"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_num_seqs=2,
+        limit_mm_per_prompt={"image": len(image_urls)},
+        logits_processors=[NGramPerReqLogitsProcessor],
+    )
+
+    placeholder = "<image>\n" * len(image_urls)
+    prompt = placeholder + question
+
+    # The following sampling params config is taken from
+    # the official Deepseek-OCR inference example.
+    # (IMPORTANT) Use the custom logits processor and avoid skipping
+    # special tokens for this model for the optimal OCR performance.
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        max_tokens=8192,
+        # ngram logit processor args
+        extra_args=dict(
+            ngram_size=30,
+            window_size=90,
+            # whitelist: <td>, </td>
+            whitelist_token_ids={128821, 128822},
+        ),
+        skip_special_tokens=False,
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+        image_data=[fetch_image(url) for url in image_urls],
+        sampling_params=sampling_params,
+    )
+
+
 def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "google/gemma-3-4b-it"

@@ -1253,6 +1294,7 @@ model_example_map = {
     "bee": load_bee,
     "command_a_vision": load_command_a_vision,
     "deepseek_vl_v2": load_deepseek_vl2,
+    "deepseek_ocr": load_deepseek_ocr,
     "gemma3": load_gemma3,
     "h2ovl_chat": load_h2ovl,
     "hyperclovax_seed_vision": load_hyperclovax_seed_vision,
@@ -1325,8 +1367,12 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: int | None)
     engine_args = asdict(req_data.engine_args) | {"seed": seed}
     llm = LLM(**engine_args)

-    sampling_params = SamplingParams(
-        temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
+    sampling_params = (
+        SamplingParams(
+            temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
+        )
+        if req_data.sampling_params is None
+        else req_data.sampling_params
     )
     outputs = llm.chat(
         [
@@ -332,6 +332,7 @@ def _test_processing_correctness_one(
         "facebook/chameleon-7b",
         "CohereLabs/command-a-vision-07-2025",
         "deepseek-ai/deepseek-vl2-tiny",
+        "deepseek-ai/DeepSeek-OCR",
         "baidu/ERNIE-4.5-VL-28B-A3B-PT",
         "adept/fuyu-8b",
         "google/gemma-3-4b-it",
@@ -4,6 +4,7 @@

 import math
 from collections.abc import Iterable, Mapping, Sequence
+from typing import Annotated, Literal

 import torch
 import torch.nn as nn
@@ -53,6 +54,7 @@ from vllm.transformers_utils.processors.deepseek_ocr import (
     count_tiles,
 )
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from vllm.v1.sample.logits_processor import (
     AdapterLogitsProcessor,
     RequestLogitsProcessor,
@@ -65,6 +67,28 @@ from .deepseek_vl2 import MlpProjector
 _IMAGE_TOKEN = "<image>"


+class DeepseekOCRImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: Batch size
+        - n: Number of images
+        - p: Number of patches
+        - base_size: Base size of the processor
+        - image_size: Image size of the processor
+    """
+
+    type: Literal["pixel_values"]
+    data: Annotated[
+        torch.Tensor,
+        TensorShape("bn", 3, "base_size", "base_size", dynamic_dims={"bnp"}),
+    ]
+    images_crop: Annotated[
+        torch.Tensor,
+        TensorShape("bnp", 3, "image_size", "image_size", dynamic_dims={"bnp"}),
+    ]
+    images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
+
+
 class NoRepeatNGramLogitsProcessor:
     def __init__(
         self,
@@ -260,10 +284,15 @@ class DeepseekOCRMultiModalProcessor(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        images_spatial_crop = hf_inputs.get("images_spatial_crop", torch.empty((0, 2)))
+        is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
+        patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
             images_spatial_crop=MultiModalFieldConfig.batched("image"),
-            images_crop=MultiModalFieldConfig.batched("image"),
+            images_crop=MultiModalFieldConfig.flat_from_sizes(
+                "image", patches_per_image
+            ),
         )

     def _get_prompt_updates(
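Read together with the _pixel_values_to_embedding change further down, the switch to flat_from_sizes means images_crop travels as one flat (total_patches, 3, image_size, image_size) tensor per request, and untiled images contribute zero patches. A standalone toy sketch (assumed sizes, not taken from the commit) of how the per-image groups are recovered:

# Toy illustration of the flat_from_sizes layout; sizes are assumptions.
import torch

image_size = 64  # toy crop size; the real processor uses its configured value
# One (width_tiles, height_tiles) row per image; (1, 1) means "not tiled".
images_spatial_crop = torch.tensor([[2, 3], [1, 1], [2, 2]])

is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
# -> tensor([6, 0, 4]); the untiled image contributes no crops

images_crop = torch.zeros((int(patches_per_image.sum()), 3, image_size, image_size))

# Recover per-image groups exactly as the model code now does with .split().
per_image = images_crop.split(patches_per_image.tolist())
print([t.shape[0] for t in per_image])  # [6, 0, 4]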
@@ -302,35 +331,6 @@ class DeepseekOCRMultiModalProcessor(
             )
         ]

-    # TODO(Isotr0py): Check if we still need this workaround for
-    # deepseek-ocr processor.
-    # def _cached_apply_hf_processor(
-    #     self,
-    #     prompt: str | list[int],
-    #     mm_data_items: MultiModalDataItems,
-    #     hf_processor_mm_kwargs: Mapping[str, object],
-    #     tokenization_kwargs: Mapping[str, object],
-    #     mm_uuids: MultiModalUUIDDict | None = None,
-    # ) -> tuple[list[int], MultiModalKwargs, bool]:
-    #     # The processor logic is different for len(images) <= 2 vs > 2
-    #     # Since the processing cache assumes that the processor output is
-    #     # invariant of how many images are passed per prompt, we only
-    #     # perform caching for the most common case
-    #     if mm_data_items.get_count("image", strict=False) > 2:
-    #         # This code path corresponds to the cache being disabled
-    #         return self._apply_hf_processor_main(
-    #             prompt=prompt,
-    #             mm_items=mm_data_items,
-    #             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-    #             enable_hf_prompt_update=True,
-    #         )
-
-    #     return super()._cached_apply_hf_processor(
-    #         prompt=prompt,
-    #         mm_data_items=mm_data_items,
-    #         hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-    #     )
-

 @MULTIMODAL_REGISTRY.register_processor(
     DeepseekOCRMultiModalProcessor,
@@ -338,6 +338,8 @@ class DeepseekOCRMultiModalProcessor(
     dummy_inputs=DeepseekOCRDummyInputsBuilder,
 )
 class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             # map prefix for language backbone
@@ -389,6 +391,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         self.vision_model = DeepCLIPVisionTransformer(
             config=clip_vision_config,
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "vision_model"),
         )

         self.projector = MlpProjector(self.projector_config)
@@ -426,7 +429,9 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             self.language_model.make_empty_intermediate_tensors
         )

-    def _parse_and_validate_image_input(self, **kwargs: object):
+    def _parse_and_validate_image_input(
+        self, **kwargs: object
+    ) -> DeepseekOCRImagePixelInputs | None:
         pixel_values = kwargs.pop("pixel_values", None)
         images_spatial_crop = kwargs.pop("images_spatial_crop", None)
         images_crop = kwargs.pop("images_crop", None)
@@ -435,23 +440,16 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             return None

         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
-                )
-
-            if not isinstance(images_spatial_crop, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image sizes. "
-                    f"Got type: {type(images_spatial_crop)}"
-                )
-
-            if not isinstance(images_crop, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of image crop. Got type: {type(images_crop)}"
-                )
-
-            return [pixel_values, images_crop, images_spatial_crop]
+            base_size = self.vision_config.image_size
+            return DeepseekOCRImagePixelInputs(
+                type="pixel_values",
+                data=pixel_values,
+                images_crop=images_crop,
+                images_spatial_crop=images_spatial_crop,
+                resolve_bindings={
+                    "base_size": base_size,
+                },
+            )

         raise AssertionError("This line should be unreachable.")

@@ -518,10 +516,13 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> NestedTensors:
         images_in_this_batch = []

+        is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
+        patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
+        images_crop = images_crop.split(patches_per_image.tolist())
         for jdx in range(images_spatial_crop.size(0)):
-            patches = images_crop[jdx][0].to(torch.bfloat16)
-            image_ori = pixel_values[jdx]
-            crop_shape = images_spatial_crop[jdx][0]
+            patches = images_crop[jdx]
+            image_ori = pixel_values[[jdx]]
+            crop_shape = images_spatial_crop[jdx]

             global_features = self._encode_global_features(image_ori)
             local_features = self._encode_local_features(patches, crop_shape)
@@ -540,10 +541,12 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

         return images_in_this_batch

-    def _process_image_input(self, image_input) -> torch.Tensor:
-        pixel_values = image_input[0].to(torch.bfloat16)
-        images_crop = image_input[1]
-        images_spatial_crop = image_input[2].to(dtype=torch.long)
+    def _process_image_input(
+        self, image_input: DeepseekOCRImagePixelInputs
+    ) -> torch.Tensor:
+        pixel_values = image_input.data
+        images_crop = image_input.images_crop
+        images_spatial_crop = image_input.images_spatial_crop.to(dtype=torch.long)

         vision_features = self._pixel_values_to_embedding(
             pixel_values=pixel_values,
@@ -411,20 +411,16 @@ class DeepseekOCRProcessor(ProcessorMixin):
         images_seq_mask = images_seq_mask[:-1]

         if len(images_list) == 0:
-            pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
-            images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
-            images_crop = torch.zeros(
-                (1, 3, self.image_size, self.image_size)
-            ).unsqueeze(0)
+            pixel_values = torch.zeros((0, 3, self.base_size, self.base_size))
+            images_spatial_crop = torch.zeros((0, 2), dtype=torch.long)
+            images_crop = torch.zeros((0, 3, self.image_size, self.image_size))
         else:
             pixel_values = torch.stack(images_list, dim=0)
             images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
             if images_crop_list:
-                images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
+                images_crop = torch.stack(images_crop_list, dim=0)
             else:
-                images_crop = torch.zeros(
-                    (1, 3, self.image_size, self.image_size)
-                ).unsqueeze(0)
+                images_crop = torch.zeros((0, 3, self.image_size, self.image_size))

         input_ids = input_ids.unsqueeze(0)

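One way to read the shape change above: zero-size placeholders vanish under concatenation when per-request fields are merged, whereas the old (1, ...) dummies would survive as a phantom image. A toy check with assumed sizes, not taken from the commit:

# Toy check: empty placeholders contribute nothing when fields are concatenated.
import torch

base_size = 32  # assumed toy value; the real processor uses its configured base_size
no_image = torch.zeros((0, 3, base_size, base_size))  # new empty placeholder
one_image = torch.rand((1, 3, base_size, base_size))  # a request with one real image

merged = torch.cat([no_image, one_image], dim=0)
print(merged.shape)  # torch.Size([1, 3, 32, 32])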