mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 19:07:53 +08:00
[Bugfix] Fix deepseek-ocr multi-image inference and add merge_by_field_config=True with tensor schema support (#27361)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
b4fda58a2d
commit
2566dca2a9
@ -44,6 +44,7 @@ class ModelRequestData(NamedTuple):
|
|||||||
stop_token_ids: list[int] | None = None
|
stop_token_ids: list[int] | None = None
|
||||||
chat_template: str | None = None
|
chat_template: str | None = None
|
||||||
lora_requests: list[LoRARequest] | None = None
|
lora_requests: list[LoRARequest] | None = None
|
||||||
|
sampling_params: SamplingParams | None = None
|
||||||
|
|
||||||
|
|
||||||
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
|
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
|
||||||
@ -201,6 +202,46 @@ def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_deepseek_ocr(question: str, image_urls: list[str]) -> ModelRequestData:
    """Assemble the multi-image request for ``deepseek-ai/DeepSeek-OCR``.

    Args:
        question: Text appended after the image placeholders in the prompt.
        image_urls: URLs of the images to fetch; one ``<image>`` placeholder
            line is emitted per URL.

    Returns:
        A ``ModelRequestData`` carrying the engine arguments, the prompt,
        the fetched images and model-specific sampling parameters.
    """
    # Function-scope import of the model-specific n-gram logits processor.
    from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor

    num_images = len(image_urls)

    ocr_engine_args = EngineArgs(
        model="deepseek-ai/DeepSeek-OCR",
        max_num_seqs=2,
        limit_mm_per_prompt={"image": num_images},
        logits_processors=[NGramPerReqLogitsProcessor],
    )

    # One "<image>" line per input image, then the user's question.
    prompt_text = "<image>\n" * num_images + question

    # Sampling configuration taken from the official Deepseek-OCR inference
    # example. (IMPORTANT) For optimal OCR performance with this model, keep
    # the custom logits processor enabled and do not skip special tokens.
    ngram_processor_args = {
        "ngram_size": 30,
        "window_size": 90,
        # whitelist: <td>, </td>
        "whitelist_token_ids": {128821, 128822},
    }
    ocr_sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=8192,
        extra_args=ngram_processor_args,
        skip_special_tokens=False,
    )

    return ModelRequestData(
        engine_args=ocr_engine_args,
        prompt=prompt_text,
        image_data=[fetch_image(image_url) for image_url in image_urls],
        sampling_params=ocr_sampling_params,
    )
|
||||||
|
|
||||||
|
|
||||||
def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
|
def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||||
model_name = "google/gemma-3-4b-it"
|
model_name = "google/gemma-3-4b-it"
|
||||||
|
|
||||||
@ -1253,6 +1294,7 @@ model_example_map = {
|
|||||||
"bee": load_bee,
|
"bee": load_bee,
|
||||||
"command_a_vision": load_command_a_vision,
|
"command_a_vision": load_command_a_vision,
|
||||||
"deepseek_vl_v2": load_deepseek_vl2,
|
"deepseek_vl_v2": load_deepseek_vl2,
|
||||||
|
"deepseek_ocr": load_deepseek_ocr,
|
||||||
"gemma3": load_gemma3,
|
"gemma3": load_gemma3,
|
||||||
"h2ovl_chat": load_h2ovl,
|
"h2ovl_chat": load_h2ovl,
|
||||||
"hyperclovax_seed_vision": load_hyperclovax_seed_vision,
|
"hyperclovax_seed_vision": load_hyperclovax_seed_vision,
|
||||||
@ -1325,8 +1367,12 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: int | None)
|
|||||||
engine_args = asdict(req_data.engine_args) | {"seed": seed}
|
engine_args = asdict(req_data.engine_args) | {"seed": seed}
|
||||||
llm = LLM(**engine_args)
|
llm = LLM(**engine_args)
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
sampling_params = (
|
||||||
temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
|
SamplingParams(
|
||||||
|
temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
|
||||||
|
)
|
||||||
|
if req_data.sampling_params is None
|
||||||
|
else req_data.sampling_params
|
||||||
)
|
)
|
||||||
outputs = llm.chat(
|
outputs = llm.chat(
|
||||||
[
|
[
|
||||||
|
|||||||
@ -332,6 +332,7 @@ def _test_processing_correctness_one(
|
|||||||
"facebook/chameleon-7b",
|
"facebook/chameleon-7b",
|
||||||
"CohereLabs/command-a-vision-07-2025",
|
"CohereLabs/command-a-vision-07-2025",
|
||||||
"deepseek-ai/deepseek-vl2-tiny",
|
"deepseek-ai/deepseek-vl2-tiny",
|
||||||
|
"deepseek-ai/DeepSeek-OCR",
|
||||||
"baidu/ERNIE-4.5-VL-28B-A3B-PT",
|
"baidu/ERNIE-4.5-VL-28B-A3B-PT",
|
||||||
"adept/fuyu-8b",
|
"adept/fuyu-8b",
|
||||||
"google/gemma-3-4b-it",
|
"google/gemma-3-4b-it",
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -53,6 +54,7 @@ from vllm.transformers_utils.processors.deepseek_ocr import (
|
|||||||
count_tiles,
|
count_tiles,
|
||||||
)
|
)
|
||||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||||
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
from vllm.v1.sample.logits_processor import (
|
from vllm.v1.sample.logits_processor import (
|
||||||
AdapterLogitsProcessor,
|
AdapterLogitsProcessor,
|
||||||
RequestLogitsProcessor,
|
RequestLogitsProcessor,
|
||||||
@ -65,6 +67,28 @@ from .deepseek_vl2 import MlpProjector
|
|||||||
_IMAGE_TOKEN = "<image>"
|
_IMAGE_TOKEN = "<image>"
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekOCRImagePixelInputs(TensorSchema):
    """
    Schema describing the image tensors passed to DeepSeek-OCR.

    Dimensions:
        - b: Batch size
        - n: Number of images
        - p: Number of patches
        - base_size: Base size of the processor
        - image_size: Image size of the processor
    """

    # Discriminator tag; the only accepted value is "pixel_values".
    type: Literal["pixel_values"]

    # One (3, base_size, base_size) entry per image, flattened over (b, n).
    # NOTE(review): `dynamic_dims` names "bnp", which is not one of this
    # shape's dims ("bn") - confirm whether "bn" was intended or whether the
    # argument can simply be dropped here.
    data: Annotated[
        torch.Tensor,
        TensorShape("bn", 3, "base_size", "base_size", dynamic_dims={"bnp"}),
    ]

    # One (3, image_size, image_size) entry per patch, flattened over
    # (b, n, p); the leading dim is declared dynamic.
    images_crop: Annotated[
        torch.Tensor,
        TensorShape("bnp", 3, "image_size", "image_size", dynamic_dims={"bnp"}),
    ]

    # Two integers per image.
    # Presumably the crop grid's (width, height) tile counts - TODO confirm
    # the axis order against the processor.
    images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
|
||||||
|
|
||||||
|
|
||||||
class NoRepeatNGramLogitsProcessor:
|
class NoRepeatNGramLogitsProcessor:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -260,10 +284,15 @@ class DeepseekOCRMultiModalProcessor(
|
|||||||
hf_inputs: BatchFeature,
|
hf_inputs: BatchFeature,
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
) -> Mapping[str, MultiModalFieldConfig]:
|
) -> Mapping[str, MultiModalFieldConfig]:
|
||||||
|
images_spatial_crop = hf_inputs.get("images_spatial_crop", torch.empty((0, 2)))
|
||||||
|
is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
|
||||||
|
patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
|
||||||
return dict(
|
return dict(
|
||||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||||
images_spatial_crop=MultiModalFieldConfig.batched("image"),
|
images_spatial_crop=MultiModalFieldConfig.batched("image"),
|
||||||
images_crop=MultiModalFieldConfig.batched("image"),
|
images_crop=MultiModalFieldConfig.flat_from_sizes(
|
||||||
|
"image", patches_per_image
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_prompt_updates(
|
def _get_prompt_updates(
|
||||||
@ -302,35 +331,6 @@ class DeepseekOCRMultiModalProcessor(
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
# TODO(Isotr0py): Check if we still need this workaround for
|
|
||||||
# deepseek-ocr processor.
|
|
||||||
# def _cached_apply_hf_processor(
|
|
||||||
# self,
|
|
||||||
# prompt: str | list[int],
|
|
||||||
# mm_data_items: MultiModalDataItems,
|
|
||||||
# hf_processor_mm_kwargs: Mapping[str, object],
|
|
||||||
# tokenization_kwargs: Mapping[str, object],
|
|
||||||
# mm_uuids: MultiModalUUIDDict | None = None,
|
|
||||||
# ) -> tuple[list[int], MultiModalKwargs, bool]:
|
|
||||||
# # The processor logic is different for len(images) <= 2 vs > 2
|
|
||||||
# # Since the processing cache assumes that the processor output is
|
|
||||||
# # invariant of how many images are passed per prompt, we only
|
|
||||||
# # perform caching for the most common case
|
|
||||||
# if mm_data_items.get_count("image", strict=False) > 2:
|
|
||||||
# # This code path corresponds to the cache being disabled
|
|
||||||
# return self._apply_hf_processor_main(
|
|
||||||
# prompt=prompt,
|
|
||||||
# mm_items=mm_data_items,
|
|
||||||
# hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
|
||||||
# enable_hf_prompt_update=True,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# return super()._cached_apply_hf_processor(
|
|
||||||
# prompt=prompt,
|
|
||||||
# mm_data_items=mm_data_items,
|
|
||||||
# hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_processor(
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
DeepseekOCRMultiModalProcessor,
|
DeepseekOCRMultiModalProcessor,
|
||||||
@ -338,6 +338,8 @@ class DeepseekOCRMultiModalProcessor(
|
|||||||
dummy_inputs=DeepseekOCRDummyInputsBuilder,
|
dummy_inputs=DeepseekOCRDummyInputsBuilder,
|
||||||
)
|
)
|
||||||
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||||
|
merge_by_field_config = True
|
||||||
|
|
||||||
hf_to_vllm_mapper = WeightsMapper(
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
orig_to_new_prefix={
|
orig_to_new_prefix={
|
||||||
# map prefix for language backbone
|
# map prefix for language backbone
|
||||||
@ -389,6 +391,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
self.vision_model = DeepCLIPVisionTransformer(
|
self.vision_model = DeepCLIPVisionTransformer(
|
||||||
config=clip_vision_config,
|
config=clip_vision_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=maybe_prefix(prefix, "vision_model"),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.projector = MlpProjector(self.projector_config)
|
self.projector = MlpProjector(self.projector_config)
|
||||||
@ -426,7 +429,9 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
self.language_model.make_empty_intermediate_tensors
|
self.language_model.make_empty_intermediate_tensors
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_and_validate_image_input(self, **kwargs: object):
|
def _parse_and_validate_image_input(
|
||||||
|
self, **kwargs: object
|
||||||
|
) -> DeepseekOCRImagePixelInputs | None:
|
||||||
pixel_values = kwargs.pop("pixel_values", None)
|
pixel_values = kwargs.pop("pixel_values", None)
|
||||||
images_spatial_crop = kwargs.pop("images_spatial_crop", None)
|
images_spatial_crop = kwargs.pop("images_spatial_crop", None)
|
||||||
images_crop = kwargs.pop("images_crop", None)
|
images_crop = kwargs.pop("images_crop", None)
|
||||||
@ -435,23 +440,16 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if pixel_values is not None:
|
if pixel_values is not None:
|
||||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
base_size = self.vision_config.image_size
|
||||||
raise ValueError(
|
return DeepseekOCRImagePixelInputs(
|
||||||
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
|
type="pixel_values",
|
||||||
)
|
data=pixel_values,
|
||||||
|
images_crop=images_crop,
|
||||||
if not isinstance(images_spatial_crop, (torch.Tensor, list)):
|
images_spatial_crop=images_spatial_crop,
|
||||||
raise ValueError(
|
resolve_bindings={
|
||||||
"Incorrect type of image sizes. "
|
"base_size": base_size,
|
||||||
f"Got type: {type(images_spatial_crop)}"
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if not isinstance(images_crop, (torch.Tensor, list)):
|
|
||||||
raise ValueError(
|
|
||||||
f"Incorrect type of image crop. Got type: {type(images_crop)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return [pixel_values, images_crop, images_spatial_crop]
|
|
||||||
|
|
||||||
raise AssertionError("This line should be unreachable.")
|
raise AssertionError("This line should be unreachable.")
|
||||||
|
|
||||||
@ -518,10 +516,13 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
) -> NestedTensors:
|
) -> NestedTensors:
|
||||||
images_in_this_batch = []
|
images_in_this_batch = []
|
||||||
|
|
||||||
|
is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
|
||||||
|
patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
|
||||||
|
images_crop = images_crop.split(patches_per_image.tolist())
|
||||||
for jdx in range(images_spatial_crop.size(0)):
|
for jdx in range(images_spatial_crop.size(0)):
|
||||||
patches = images_crop[jdx][0].to(torch.bfloat16)
|
patches = images_crop[jdx]
|
||||||
image_ori = pixel_values[jdx]
|
image_ori = pixel_values[[jdx]]
|
||||||
crop_shape = images_spatial_crop[jdx][0]
|
crop_shape = images_spatial_crop[jdx]
|
||||||
|
|
||||||
global_features = self._encode_global_features(image_ori)
|
global_features = self._encode_global_features(image_ori)
|
||||||
local_features = self._encode_local_features(patches, crop_shape)
|
local_features = self._encode_local_features(patches, crop_shape)
|
||||||
@ -540,10 +541,12 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
|
|
||||||
return images_in_this_batch
|
return images_in_this_batch
|
||||||
|
|
||||||
def _process_image_input(self, image_input) -> torch.Tensor:
|
def _process_image_input(
|
||||||
pixel_values = image_input[0].to(torch.bfloat16)
|
self, image_input: DeepseekOCRImagePixelInputs
|
||||||
images_crop = image_input[1]
|
) -> torch.Tensor:
|
||||||
images_spatial_crop = image_input[2].to(dtype=torch.long)
|
pixel_values = image_input.data
|
||||||
|
images_crop = image_input.images_crop
|
||||||
|
images_spatial_crop = image_input.images_spatial_crop.to(dtype=torch.long)
|
||||||
|
|
||||||
vision_features = self._pixel_values_to_embedding(
|
vision_features = self._pixel_values_to_embedding(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
|
|||||||
@ -411,20 +411,16 @@ class DeepseekOCRProcessor(ProcessorMixin):
|
|||||||
images_seq_mask = images_seq_mask[:-1]
|
images_seq_mask = images_seq_mask[:-1]
|
||||||
|
|
||||||
if len(images_list) == 0:
|
if len(images_list) == 0:
|
||||||
pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
|
pixel_values = torch.zeros((0, 3, self.base_size, self.base_size))
|
||||||
images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
|
images_spatial_crop = torch.zeros((0, 2), dtype=torch.long)
|
||||||
images_crop = torch.zeros(
|
images_crop = torch.zeros((0, 3, self.image_size, self.image_size))
|
||||||
(1, 3, self.image_size, self.image_size)
|
|
||||||
).unsqueeze(0)
|
|
||||||
else:
|
else:
|
||||||
pixel_values = torch.stack(images_list, dim=0)
|
pixel_values = torch.stack(images_list, dim=0)
|
||||||
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
|
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
|
||||||
if images_crop_list:
|
if images_crop_list:
|
||||||
images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
|
images_crop = torch.stack(images_crop_list, dim=0)
|
||||||
else:
|
else:
|
||||||
images_crop = torch.zeros(
|
images_crop = torch.zeros((0, 3, self.image_size, self.image_size))
|
||||||
(1, 3, self.image_size, self.image_size)
|
|
||||||
).unsqueeze(0)
|
|
||||||
|
|
||||||
input_ids = input_ids.unsqueeze(0)
|
input_ids = input_ids.unsqueeze(0)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user