From 675aa2ec64b2d8ab45948f45cef80f74ebfadbbb Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 22 Oct 2025 22:59:15 +0800 Subject: [PATCH] [Model] Upstream Deepseek-OCR model (#27247) Signed-off-by: Isotr0py Signed-off-by: Roger Wang Co-authored-by: Roger Wang --- docs/models/supported_models.md | 1 + examples/offline_inference/vision_language.py | 89 ++- tests/models/registry.py | 3 + vllm/model_executor/models/deepencoder.py | 673 ++++++++++++++++++ vllm/model_executor/models/deepseek_ocr.py | 594 ++++++++++++++++ vllm/model_executor/models/deepseek_vl2.py | 43 +- vllm/model_executor/models/registry.py | 1 + .../chat_templates/registry.py | 1 + .../template_deepseek_ocr.jinja | 14 + .../processors/deepseek_ocr.py | 442 ++++++++++++ 10 files changed, 1821 insertions(+), 40 deletions(-) create mode 100644 vllm/model_executor/models/deepencoder.py create mode 100644 vllm/model_executor/models/deepseek_ocr.py create mode 100644 vllm/transformers_utils/chat_templates/template_deepseek_ocr.jinja create mode 100644 vllm/transformers_utils/processors/deepseek_ocr.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 001a5b96174ac..79892ac757b5e 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -639,6 +639,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I+ | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | | `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | +| `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I+ | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ | | `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I+/ V+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | | `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. 
| | | diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 35311a0ca7e1a..c5711ca9d0bce 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -30,6 +30,7 @@ class ModelRequestData(NamedTuple): prompts: list[str] stop_token_ids: list[int] | None = None lora_requests: list[LoRARequest] | None = None + sampling_params: list[SamplingParams] | None = None # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on @@ -153,23 +154,6 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData: ) -# Dots-OCR -def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData: - assert modality == "image" - - prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions] - engine_args = EngineArgs( - model="rednote-hilab/dots.ocr", - limit_mm_per_prompt={modality: 1}, - trust_remote_code=True, - ) - - return ModelRequestData( - engine_args=engine_args, - prompts=prompts, - ) - - def run_command_a_vision(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -217,6 +201,66 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: ) +def run_deepseek_ocr(questions: list[str], modality: str) -> ModelRequestData: + from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor + + assert modality == "image" + + model_name = "deepseek-ai/DeepSeek-OCR" + + engine_args = EngineArgs( + model=model_name, + limit_mm_per_prompt={modality: 1}, + logits_processors=[NGramPerReqLogitsProcessor], + ) + + # deepseek-ocr use plain prompt template + prompts = [f"\n{question}" for question in questions] + + # The following sampling params config is taken from + # the official Deepseek-OCR inference example. + # (IMPORTANT) Use the custom logits processor and avoid skipping + # special tokens for this model for the optimal OCR performance. + sampling_params = [ + SamplingParams( + temperature=0.0, + max_tokens=8192, + # ngram logit processor args + extra_args=dict( + ngram_size=30, + window_size=90, + # whitelist: , + whitelist_token_ids={128821, 128822}, + ), + skip_special_tokens=False, + ) + for _ in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + sampling_params=sampling_params, + ) + + +# Dots-OCR +def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions] + engine_args = EngineArgs( + model="rednote-hilab/dots.ocr", + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Ernie4.5-VL def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData: model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT" @@ -1738,9 +1782,10 @@ model_example_map = { "bee": run_bee, "blip-2": run_blip2, "chameleon": run_chameleon, - "dots_ocr": run_dots_ocr, "command_a_vision": run_command_a_vision, "deepseek_vl_v2": run_deepseek_vl2, + "deepseek_ocr": run_deepseek_ocr, + "dots_ocr": run_dots_ocr, "ernie45_vl": run_ernie45_vl, "fuyu": run_fuyu, "gemma3": run_gemma3, @@ -2003,8 +2048,12 @@ def main(args): # We set temperature to 0.2 so that outputs can be different # even when all prompts are identical when running batch inference. 
- sampling_params = SamplingParams( - temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids + sampling_params = ( + SamplingParams( + temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids + ) + if req_data.sampling_params is None + else req_data.sampling_params ) assert args.num_prompts > 0 diff --git a/tests/models/registry.py b/tests/models/registry.py index 7345d2e07dc7b..bd5a4650081f4 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -585,6 +585,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { transformers_version_reason="HF model is not compatible.", hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}, ), + "DeepseekOCRForCausalLM": _HfExamplesInfo( + "deepseek-ai/DeepSeek-OCR", + ), "DotsOCRForCausalLM": _HfExamplesInfo( "rednote-hilab/dots.ocr", trust_remote_code=True ), diff --git a/vllm/model_executor/models/deepencoder.py b/vllm/model_executor/models/deepencoder.py new file mode 100644 index 0000000000000..e62a57eccc953 --- /dev/null +++ b/vllm/model_executor/models/deepencoder.py @@ -0,0 +1,673 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# adapted from +# https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek-OCR-master/DeepSeek-OCR-vllm/deepencoder/sam_vary_sdpa.py + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import math +from collections.abc import Iterable +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import CLIPVisionConfig + +from vllm.attention.layer import MultiHeadAttention +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from .clip import CLIPEncoder, CLIPVisionEmbeddings + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, 
+ qkv_bias: bool = True, + norm_layer: type[nn.Module] = nn.LayerNorm, + act_layer: type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ # noqa: E501 + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: nn.Parameter | None = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros( + 1, img_size // patch_size, img_size // patch_size, embed_dim + ) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False) + self.net_3 = nn.Conv2d( + 512, 1024, kernel_size=3, stride=2, padding=1, bias=False + ) + + def get_abs_pos(self, abs_pos: torch.Tensor, tgt_size: int): + dtype = abs_pos.dtype + + src_size = abs_pos.size(1) + + if src_size != tgt_size: + old_pos_embed = abs_pos.permute(0, 3, 1, 2) + old_pos_embed = old_pos_embed.to(torch.float32) + new_pos_embed = F.interpolate( + old_pos_embed, + size=(tgt_size, tgt_size), + mode="bicubic", + antialias=True, + align_corners=False, + ).to(dtype) + new_pos_embed = new_pos_embed.permute(0, 2, 3, 1) + return new_pos_embed + else: + return abs_pos + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.get_abs_pos(self.pos_embed, x.size(1)) + + for blk in self.blocks: + x = blk(x) + + neck_output = self.neck(x.permute(0, 3, 1, 2)) + conv2_output = self.net_2(neck_output) + conv3_output = self.net_3(conv2_output) + + return conv3_output + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation + blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: 
float = 4.0, + qkv_bias: bool = True, + norm_layer: type[nn.Module] = nn.LayerNorm, + act_layer: type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: tuple[int, int] | None = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ # noqa: E501 + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = RelPosAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock( + embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer + ) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class RelPosAttention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: tuple[int, int] | None = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ # noqa: E501 + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert input_size is not None, ( + "Input size must be provided if using relative positional encoding." 
+ ) + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = ( + self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + rel_h, rel_w = None, None + if self.use_rel_pos: + rel_h, rel_w = add_decomposed_rel_pos( + q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W) + ) + + q = q.view(B, self.num_heads, H * W, -1) + k = k.view(B, self.num_heads, H * W, -1) + v = v.view(B, self.num_heads, H * W, -1) + + if self.use_rel_pos: + rel_h = rel_h.view( + B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3) + ) + rel_w = rel_w.view( + B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3) + ) + attn_bias = (rel_h + rel_w).view( + B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4) + ) + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_bias + ) + else: + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + + x = ( + x.view(B, self.num_heads, H, W, -1) + .permute(0, 2, 3, 1, 4) + .reshape(B, H, W, -1) + ) + + x = self.proj(x) + + return x + + +def window_partition( + x: torch.Tensor, window_size: int +) -> tuple[torch.Tensor, tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ # noqa: E501 + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, + window_size: int, + pad_hw: tuple[int, int], + hw: tuple[int, int], +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ # noqa: E501 + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). 
+ + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + dtype = rel_pos.dtype + rel_pos = rel_pos.to(torch.float32) + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ).to(dtype) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max( + k_size / q_size, 1.0 + ) + k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max( + q_size / k_size, 1.0 + ) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: tuple[int, int], + k_size: tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + Args: + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ # noqa: E501 + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + rel_h = rel_h.unsqueeze(-1) + rel_w = rel_w.unsqueeze(-2) + rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1) + rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w) + + return rel_h, rel_w + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: tuple[int, int] = (16, 16), + stride: tuple[int, int] = (16, 16), + padding: tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. 
+ """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x + + +# TODO(Isotr0py): use vision_config to build sam model +def build_sam_vit_b(): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + ) + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_encoder = ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ) + return image_encoder + + +class DeepCLIPVisionEmbeddings(CLIPVisionEmbeddings): + def get_abs_pos(self, abs_pos: torch.Tensor, tgt_size: int): + # abs_pos: L, C + # tgt_size: M + # return: M, C + dim = abs_pos.size(-1) + abs_pos_new = abs_pos.squeeze(0) + cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:] + + src_size = int(math.sqrt(abs_pos_new.shape[0] - 1)) + tgt_size = int(math.sqrt(tgt_size)) + dtype = abs_pos.dtype + + if src_size != tgt_size: + old_pos_embed = ( + old_pos_embed.view(1, src_size, src_size, dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + old_pos_embed = old_pos_embed.to(torch.float32) + new_pos_embed = F.interpolate( + old_pos_embed, + size=(tgt_size, tgt_size), + mode="bicubic", + antialias=True, + align_corners=False, + ).to(dtype) + new_pos_embed = new_pos_embed.permute(0, 2, 3, 1) + new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim) + vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0) + vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim) + return vision_pos_embed + else: + return abs_pos + + def forward( + self, pixel_values: torch.Tensor, patch_embeds: torch.Tensor | None = None + ) -> torch.Tensor: + batch_size = pixel_values.shape[0] + if patch_embeds is not None: + patch_embeds = patch_embeds + else: + patch_embeds = self.patch_embedding(pixel_values) + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.get_abs_pos( + self.position_embedding(self.position_ids), embeddings.size(1) + ) + return embeddings + + +class DeepCLIPVisionTransformer(nn.Module): + def __init__( + self, + config: CLIPVisionConfig, + quant_config: QuantizationConfig | None = None, + *, + num_hidden_layers_override: int | None = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + embed_dim = config.hidden_size + + self.embeddings = DeepCLIPVisionEmbeddings(config) + + # NOTE: This typo of "layrnorm" is not fixed on purpose to match + # the original transformers code and name of the model weights. 
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.transformer = CLIPEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + prefix=f"{prefix}.encoder", + attn_cls=MultiHeadAttention, + ) + + num_hidden_layers = config.num_hidden_layers + if len(self.transformer.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {num_hidden_layers} " + f"layers, but you requested {len(self.transformer.layers)} layers." + ) + + @property + def dtype(self): + return next(self.parameters()).dtype + + @property + def device(self): + return next(self.parameters()).device + + def forward( + self, + pixel_values: torch.Tensor, + patch_embeds: torch.Tensor | None = None, + *, + select_layers: list[int] | None = None, + ) -> torch.Tensor: + hidden_states = self.embeddings(pixel_values, patch_embeds) + hidden_states = self.pre_layrnorm(hidden_states) + + # Produces either the last layer output or all of the hidden states, + # depending on if we have select_layers or not + encoder_outputs = self.transformer( + inputs_embeds=hidden_states, + return_all_hidden_states=select_layers is not None, + ) + return encoder_outputs + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/deepseek_ocr.py b/vllm/model_executor/models/deepseek_ocr.py new file mode 100644 index 0000000000000..c9064dabc0ab3 --- /dev/null +++ b/vllm/model_executor/models/deepseek_ocr.py @@ -0,0 +1,594 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Deepseek-OCR model compatible with HuggingFace weights.""" + +import math +from collections.abc import Iterable, Mapping, Sequence + +import torch +import torch.nn as nn +from transformers import BatchFeature, CLIPVisionConfig + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, + SupportsMultiModal, + SupportsPP, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargs, + NestedTensors, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sampling_params import SamplingParams +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config +from vllm.transformers_utils.processors.deepseek_ocr import ( + BASE_SIZE, + CROP_MODE, + IMAGE_SIZE, + DeepseekOCRProcessor, + count_tiles, +) +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.v1.sample.logits_processor import ( + AdapterLogitsProcessor, + RequestLogitsProcessor, +) + 
+from .deepencoder import DeepCLIPVisionTransformer, build_sam_vit_b
+from .deepseek_vl2 import MlpProjector
+
+# The image token id may be various
+_IMAGE_TOKEN = "<image>"
+
+
+class NoRepeatNGramLogitsProcessor:
+    def __init__(
+        self,
+        ngram_size: int,
+        window_size: int,
+        whitelist_token_ids: set[int] | None = None,
+    ):
+        self.ngram_size = ngram_size
+        self.window_size = window_size
+        self.whitelist_token_ids = whitelist_token_ids or set()
+
+    def __call__(
+        self,
+        output_ids: list[int],
+        logits: torch.Tensor,
+    ) -> torch.Tensor:
+        if len(output_ids) < self.ngram_size:
+            return logits
+
+        current_prefix = tuple(output_ids[-(self.ngram_size - 1) :])
+
+        search_start = max(0, len(output_ids) - self.window_size)
+        search_end = len(output_ids) - self.ngram_size + 1
+
+        banned_tokens = set()
+        for i in range(search_start, search_end):
+            ngram = tuple(output_ids[i : i + self.ngram_size])
+            if ngram[:-1] == current_prefix:
+                banned_tokens.add(ngram[-1])
+
+        banned_tokens = banned_tokens - self.whitelist_token_ids
+
+        if banned_tokens:
+            logits[list(banned_tokens)] = -float("inf")
+
+        return logits
+
+
+class NGramPerReqLogitsProcessor(AdapterLogitsProcessor):
+    """Example of overriding the wrapper class `__init__()` in order to utilize
+    info about the device type"""
+
+    def __init__(
+        self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
+    ):
+        super().__init__(vllm_config, device, is_pin_memory)
+
+    def is_argmax_invariant(self) -> bool:
+        return True
+
+    def new_req_logits_processor(
+        self,
+        params: SamplingParams,
+    ) -> RequestLogitsProcessor | None:
+        ngram_size = params.extra_args and params.extra_args.get("ngram_size")
+        window_size = params.extra_args and params.extra_args.get("window_size", 100)
+        whitelist_token_ids = params.extra_args and params.extra_args.get(
+            "whitelist_token_ids", None
+        )
+        if ngram_size is None:
+            return None
+        if not isinstance(ngram_size, int) or ngram_size <= 0:
+            raise ValueError(
+                f"`ngram_size` has to be a strictly positive integer, got {ngram_size}."
+            )
+        if not isinstance(window_size, int) or window_size <= 0:
+            raise ValueError(
+                "`window_size` has to be a strictly positive integer, "
+                f"got {window_size}."
+            )
+        if whitelist_token_ids is not None and not isinstance(
+            whitelist_token_ids, Iterable
+        ):
+            raise ValueError(
+                "`whitelist_token_ids` has to be a set of integers, "
+                f"got {whitelist_token_ids}."
+ ) + else: + whitelist_token_ids = ( + set(whitelist_token_ids) if whitelist_token_ids else None + ) + return NoRepeatNGramLogitsProcessor( + ngram_size=ngram_size, + window_size=window_size, + whitelist_token_ids=whitelist_token_ids, + ) + + +class DeepseekOCRProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(DeepseekVLV2Config) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(DeepseekOCRProcessor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": None} + + def get_num_image_tokens( + self, *, image_width: int, image_height: int, cropping: bool = True + ) -> int: + image_size = IMAGE_SIZE + base_size = BASE_SIZE + patch_size = 16 + downsample_ratio = 4 + + if CROP_MODE: + if image_width <= 640 and image_height <= 640: + crop_ratio = [1, 1] + else: + # find the closest aspect ratio to the target + crop_ratio = count_tiles( + image_width, image_height, image_size=IMAGE_SIZE + ) + + num_width_tiles, num_height_tiles = crop_ratio + else: + num_width_tiles = num_height_tiles = 1 + + h = w = math.ceil((base_size // patch_size) / downsample_ratio) + + h2 = w2 = math.ceil((image_size // patch_size) / downsample_ratio) + + global_views_tokens = h * (w + 1) + if num_width_tiles > 1 or num_height_tiles > 1: + local_views_tokens = (num_height_tiles * h2) * (num_width_tiles * w2 + 1) + else: + local_views_tokens = 0 + + return global_views_tokens + local_views_tokens + 1 + + def get_image_size_with_most_features(self) -> ImageSize: + if IMAGE_SIZE == 1024 and BASE_SIZE == 1280: + return ImageSize(width=1024 * 2, height=1024 * 2) + return ImageSize(width=640 * 2, height=640 * 2) + + +class DeepseekOCRDummyInputsBuilder(BaseDummyInputsBuilder[DeepseekOCRProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + max_image_size = self.info.get_image_size_with_most_features() + + return { + "image": self._get_dummy_images( + width=max_image_size.width, + height=max_image_size.height, + num_images=num_images, + ) + } + + +class DeepseekOCRMultiModalProcessor( + BaseMultiModalProcessor[DeepseekOCRProcessingInfo] +): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + processed_outputs = self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(prompt=prompt, **mm_data), + mm_kwargs, + ) + + else: + tokenizer = self.info.get_tokenizer() + processed_outputs = tokenizer( + prompt, add_special_tokens=True, return_tensors="pt" + ) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + images_spatial_crop=MultiModalFieldConfig.batched("image"), + images_crop=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: 
MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + image_token_id = hf_processor.image_token_id + assert isinstance(image_token_id, int) + + def get_replacement_deepseek_vl2(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + size = images.get_image_size(item_idx) + + num_image_tokens = self.info.get_num_image_tokens( + image_width=size.width, + image_height=size.height, + cropping=CROP_MODE, + ) + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement_deepseek_vl2, + ) + ] + + # TODO(Isotr0py): Check if we still need this workaround for + # deepseek-ocr processor. + # def _cached_apply_hf_processor( + # self, + # prompt: str | list[int], + # mm_data_items: MultiModalDataItems, + # hf_processor_mm_kwargs: Mapping[str, object], + # tokenization_kwargs: Mapping[str, object], + # mm_uuids: MultiModalUUIDDict | None = None, + # ) -> tuple[list[int], MultiModalKwargs, bool]: + # # The processor logic is different for len(images) <= 2 vs > 2 + # # Since the processing cache assumes that the processor output is + # # invariant of how many images are passed per prompt, we only + # # perform caching for the most common case + # if mm_data_items.get_count("image", strict=False) > 2: + # # This code path corresponds to the cache being disabled + # return self._apply_hf_processor_main( + # prompt=prompt, + # mm_items=mm_data_items, + # hf_processor_mm_kwargs=hf_processor_mm_kwargs, + # enable_hf_prompt_update=True, + # ) + + # return super()._cached_apply_hf_processor( + # prompt=prompt, + # mm_data_items=mm_data_items, + # hf_processor_mm_kwargs=hf_processor_mm_kwargs, + # ) + + +@MULTIMODAL_REGISTRY.register_processor( + DeepseekOCRMultiModalProcessor, + info=DeepseekOCRProcessingInfo, + dummy_inputs=DeepseekOCRDummyInputsBuilder, +) +class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # map prefix for language backbone + "model.embed_tokens.": "language_model.model.embed_tokens.", + "model.layers.": "language_model.model.layers.", + "model.norm.": "language_model.model.norm.", + "lm_head.": "language_model.lm_head.", + # remove "model." 
prefix for other components + "model.": "", + } + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return "" + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: DeepseekVLV2Config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.vision_config = config.vision_config + self.projector_config = config.projector_config + self.text_config = config.text_config + + model_config = vllm_config.model_config + tokenizer = cached_tokenizer_from_config(model_config) + self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN] + + self.sam_model = build_sam_vit_b() + clip_vision_config = CLIPVisionConfig( + hidden_size=1024, + intermediate_size=4096, + num_attention_heads=16, + num_hidden_layers=24, + image_size=224, + patch_size=14, + projection_dim=512, + layer_norm_eps=1e-5, + ) + self.vision_model = DeepCLIPVisionTransformer( + config=clip_vision_config, + quant_config=quant_config, + ) + + self.projector = MlpProjector(self.projector_config) + self.tile_tag = config.tile_tag + self.global_view_pos = config.global_view_pos + + # special token for image token sequence format + n_embed = self.projector_config.n_embed + embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32)) + if self.tile_tag == "2D": + # <|view_separator|>, <|\n|> + self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std) + # This is a typo in original implementation + self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std) + else: + raise ValueError( + f"Only 2D tile_tag is supported currently, got: {self.tile_tag}" + ) + + if self.text_config.topk_method == "noaux_tc": + architectures = ["DeepseekV3ForCausalLM"] + elif not self.text_config.use_mla: + architectures = ["DeepseekForCausalLM"] + else: + architectures = ["DeepseekV2ForCausalLM"] + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=self.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=architectures, + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + def _parse_and_validate_image_input(self, **kwargs: object): + pixel_values = kwargs.pop("pixel_values", None) + images_spatial_crop = kwargs.pop("images_spatial_crop", None) + images_crop = kwargs.pop("images_crop", None) + + if pixel_values is None or torch.sum(pixel_values).item() == 0: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError( + f"Incorrect type of pixel values. Got type: {type(pixel_values)}" + ) + + if not isinstance(images_spatial_crop, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of image sizes. " + f"Got type: {type(images_spatial_crop)}" + ) + + if not isinstance(images_crop, (torch.Tensor, list)): + raise ValueError( + f"Incorrect type of image crop. 
Got type: {type(images_crop)}" + ) + + return [pixel_values, images_crop, images_spatial_crop] + + raise AssertionError("This line should be unreachable.") + + def _encode_global_features(self, image_tensor: torch.Tensor) -> torch.Tensor: + global_features_1 = self.sam_model(image_tensor) + global_features_2 = self.vision_model(image_tensor, global_features_1) + features = torch.cat( + ( + global_features_2[:, 1:], + global_features_1.flatten(2).permute(0, 2, 1), + ), + dim=-1, + ) + features = self.projector(features) + + _, hw, dim = features.shape + side = int(hw**0.5) + + features = features.view(side, side, dim) + newline = self.image_newline[None, None, :].expand(side, 1, dim) + features = torch.cat([features, newline], dim=1) + return features.view(-1, dim) + + def _encode_local_features( + self, patches: torch.Tensor, crop_shape: torch.Tensor + ) -> torch.Tensor | None: + if torch.sum(patches).item() == 0: + return None + + local_features_1 = self.sam_model(patches) + local_features_2 = self.vision_model(patches, local_features_1) + features = torch.cat( + ( + local_features_2[:, 1:], + local_features_1.flatten(2).permute(0, 2, 1), + ), + dim=-1, + ) + features = self.projector(features) + + _, hw, dim = features.shape + patch_side = int(hw**0.5) + + width_tiles = int(crop_shape[0].item()) + height_tiles = int(crop_shape[1].item()) + + features = ( + features.view(height_tiles, width_tiles, patch_side, patch_side, dim) + .permute(0, 2, 1, 3, 4) + .reshape(height_tiles * patch_side, width_tiles * patch_side, dim) + ) + newline = self.image_newline[None, None, :].expand( + height_tiles * patch_side, 1, dim + ) + features = torch.cat([features, newline], dim=1) + + return features.view(-1, dim) + + def _pixel_values_to_embedding( + self, + pixel_values: torch.Tensor, + images_crop: torch.Tensor, + images_spatial_crop: torch.Tensor, + ) -> NestedTensors: + images_in_this_batch = [] + + for jdx in range(images_spatial_crop.size(0)): + patches = images_crop[jdx][0].to(torch.bfloat16) + image_ori = pixel_values[jdx] + crop_shape = images_spatial_crop[jdx][0] + + global_features = self._encode_global_features(image_ori) + local_features = self._encode_local_features(patches, crop_shape) + + if local_features is not None: + combined = torch.cat( + [local_features, global_features, self.view_seperator[None, :]], + dim=0, + ) + else: + combined = torch.cat( + [global_features, self.view_seperator[None, :]], dim=0 + ) + + images_in_this_batch.append(combined) + + return images_in_this_batch + + def _process_image_input(self, image_input) -> torch.Tensor: + pixel_values = image_input[0].to(torch.bfloat16) + images_crop = image_input[1] + images_spatial_crop = image_input[2].to(dtype=torch.long) + + vision_features = self._pixel_values_to_embedding( + pixel_values=pixel_values, + images_crop=images_crop, + images_spatial_crop=images_spatial_crop, + ) + + return vision_features + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings( + self, **kwargs: object + ) -> MultiModalEmbeddings | None: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ): + if intermediate_tensors is not None: + 
inputs_embeds = None + + hidden_states = self.language_model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + return autoloaded_weights diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 3fc8187278c83..ea10245a84ee1 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -101,9 +101,10 @@ class MlpProjector(nn.Module): super().__init__() self.cfg = cfg + self.projector_type = cfg.projector_type assert not cfg.token_pooling, "Token pooling is not supported currently." - if cfg.projector_type == "downsample_mlp_gelu": + if self.projector_type == "downsample_mlp_gelu": mlp_depth = cfg.depth mlp_ratio = cfg.mlp_ratio modules = [ @@ -120,7 +121,8 @@ class MlpProjector(nn.Module): modules.append(nn.GELU()) modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed)) modules = nn.Sequential(*modules) - + elif self.projector_type == "linear": + modules = nn.Linear(cfg.input_dim, cfg.n_embed) else: raise NotImplementedError( f"Unsupported projector type: {cfg.projector_type}" @@ -130,24 +132,25 @@ class MlpProjector(nn.Module): def forward(self, x): bs, hw, input_dim = x.shape - h = w = int((hw) ** 0.5) - """compute padding""" - if h % self.cfg.downsample_ratio: - pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio - else: - pad = 0 - x = x.reshape(bs, h, w, input_dim) - if pad > 0: - x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0) - """4 to 1 concat""" - x = x.permute(0, 3, 1, 2) # B, C, H, W - x = F.unfold( - x, - kernel_size=self.cfg.downsample_ratio, - stride=self.cfg.downsample_ratio, - padding=0, - ) # B, C*4, HW // 4 - x = x.permute(0, 2, 1) + if self.projector_type == "downsample_mlp_gelu": + h = w = int((hw) ** 0.5) + """compute padding""" + if h % self.cfg.downsample_ratio: + pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio + else: + pad = 0 + x = x.reshape(bs, h, w, input_dim) + if pad > 0: + x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0) + """4 to 1 concat""" + x = x.permute(0, 3, 1, 2) # B, C, H, W + x = F.unfold( + x, + kernel_size=self.cfg.downsample_ratio, + stride=self.cfg.downsample_ratio, + padding=0, + ) # B, C*4, HW // 4 + x = x.permute(0, 2, 1) return self.layers(x) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index da1606a7568dd..617854c8548fc 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -258,6 +258,7 @@ _MULTIMODAL_MODELS = { "Cohere2VisionForConditionalGeneration", ), "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), + "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"), "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"), "Ernie4_5_VLMoeForConditionalGeneration": ( "ernie45_vl", diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index afeac2335dc77..dbb4ffb675b8b 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -34,6 +34,7 @@ 
_MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "clip": CHAT_TEMPLATES_DIR / "template_basic.jinja", "chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja", "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", + "deepseek_ocr": CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja", "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", "minicpmv": _get_minicpmv_chat_template_fallback, "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja", diff --git a/vllm/transformers_utils/chat_templates/template_deepseek_ocr.jinja b/vllm/transformers_utils/chat_templates/template_deepseek_ocr.jinja new file mode 100644 index 0000000000000..287abe3586425 --- /dev/null +++ b/vllm/transformers_utils/chat_templates/template_deepseek_ocr.jinja @@ -0,0 +1,14 @@ +{%- if messages[0]['role'] == 'system' -%} + {%- set system_message = messages[0]['content'] -%} + {%- set messages = messages[1:] -%} +{%- else -%} + {% set system_message = '' -%} +{%- endif -%} + +{{ bos_token + system_message }} +{%- for message in messages -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {%- endif -%} + {{ message['content'] }} +{%- endfor -%} diff --git a/vllm/transformers_utils/processors/deepseek_ocr.py b/vllm/transformers_utils/processors/deepseek_ocr.py new file mode 100644 index 0000000000000..99f2df3342e02 --- /dev/null +++ b/vllm/transformers_utils/processors/deepseek_ocr.py @@ -0,0 +1,442 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# adapted from https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek-OCR-master/DeepSeek-OCR-vllm/process/image_process.py +import math + +import torch +import torchvision.transforms as T +from PIL import Image, ImageOps +from transformers import AutoProcessor, BatchFeature, LlamaTokenizerFast +from transformers.processing_utils import ProcessorMixin + +# TODO(Isotr0py): change modes for variants +# see: https://github.com/deepseek-ai/DeepSeek-OCR/blob/8cf003d38821fa1b19c73da3bd1b0dc262ea8136/DeepSeek-OCR-master/DeepSeek-OCR-vllm/config.py#L1-L6 +# Tiny: base_size = 512, image_size = 512, crop_mode = False +# Small: base_size = 640, image_size = 640, crop_mode = False +# Base: base_size = 1024, image_size = 1024, crop_mode = False +# Large: base_size = 1280, image_size = 1280, crop_mode = False +# Gundam: base_size = 1024, image_size = 640, crop_mode = True +BASE_SIZE = 1024 +IMAGE_SIZE = 640 +CROP_MODE = True + +# TODO(Isotr0py): Expose as mm_kwargs +MIN_CROPS = 2 +MAX_CROPS = 6 # max:9; If your GPU memory is small, it is recommended to set it to 6. 
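# --- Editor's worked example (illustrative, not part of the upstream patch) ---
# With the "Gundam" defaults above (BASE_SIZE = 1024, IMAGE_SIZE = 640,
# CROP_MODE = True), the per-image placeholder count produced by
# DeepseekOCRProcessingInfo.get_num_image_tokens() works out as follows,
# using the patch_size = 16 and downsample_ratio = 4 hard-coded there:
#
#   global view : ceil((1024 // 16) / 4) = 16 queries per side
#                 -> 16 * (16 + 1) = 272 tokens (one extra token per row)
#   local tiles : a 1024 x 1024 input exceeds 640 x 640, so it is split into a
#                 2 x 2 tile grid; each tile contributes
#                 ceil((640 // 16) / 4) = 10 queries per side, giving
#                 (2 * 10) * (2 * 10 + 1) = 420 tokens
#   separator   : + 1 view-separator token
#
#   total: 272 + 420 + 1 = 693 image tokens for a 1024 x 1024 image; an image
#   no larger than 640 x 640 skips the local tiles and uses 272 + 1 = 273.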
+ + +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def calculate_aspect_ratios( + min_num: int = MIN_CROPS, max_num: int = MAX_CROPS +) -> list[tuple[int, int]]: + target_ratios: set[tuple[int, int]] = set( + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + ) + sorted_target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + return sorted_target_ratios + + +def count_tiles( + orig_width, + orig_height, + min_num=MIN_CROPS, + max_num=MAX_CROPS, + image_size=640, + use_thumbnail=False, +): + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = calculate_aspect_ratios(min_num, max_num) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + + return target_aspect_ratio + + +def dynamic_preprocess( + image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False +): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = calculate_aspect_ratios(min_num, max_num) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images, target_aspect_ratio + + +class ImageTransform: + def __init__( + self, + mean: tuple[float, float, float] = (0.5, 0.5, 0.5), + std: tuple[float, float, float] = (0.5, 0.5, 0.5), + normalize: bool = True, + ): + self.mean = mean + self.std = std + self.normalize = normalize + + transform_pipelines = [T.ToTensor()] + + if normalize: + transform_pipelines.append(T.Normalize(mean, std)) + + self.transform = T.Compose(transform_pipelines) + + def __call__(self, pil_img: Image.Image): + x = self.transform(pil_img) + return x + + +class DeepseekOCRProcessor(ProcessorMixin): + tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + attributes = ["tokenizer"] + + def __init__( + self, + tokenizer: LlamaTokenizerFast, + patch_size: int = 16, + 
downsample_ratio: int = 4, + image_mean: tuple[float, float, float] = (0.5, 0.5, 0.5), + image_std: tuple[float, float, float] = (0.5, 0.5, 0.5), + normalize: bool = True, + image_token: str = "", + pad_token: str = "<|▁pad▁|>", + add_special_token: bool = False, + sft_format: str = "deepseek", + mask_prompt: bool = True, + ignore_id: int = -100, + **kwargs, + ): + self.image_size = IMAGE_SIZE + self.base_size = BASE_SIZE + self.patch_size = 16 + self.image_mean = image_mean + self.image_std = image_std + self.normalize = normalize + self.downsample_ratio = 4 + + self.image_transform = ImageTransform( + mean=image_mean, std=image_std, normalize=normalize + ) + + self.tokenizer = tokenizer + self.tokenizer.padding_side = "left" # must set this,padding side with make a difference in batch inference # noqa: E501 + + # add the pad_token as special token to use 'tokenizer.pad_token' + # and 'tokenizer.pad_token_id' + if self.tokenizer.pad_token is None: + self.tokenizer.add_special_tokens({"pad_token": pad_token}) + + # add image token + self.image_token_id = self.tokenizer.vocab.get(image_token) + self.image_token = image_token + self.pad_token = pad_token + self.add_special_token = add_special_token + self.sft_format = sft_format + self.mask_prompt = mask_prompt + self.ignore_id = ignore_id + + super().__init__( + tokenizer, + **kwargs, + ) + + @property + def bos_id(self): + return self.tokenizer.bos_token_id + + @property + def eos_id(self): + return self.tokenizer.eos_token_id + + @property + def pad_id(self): + return self.tokenizer.pad_token_id + + def encode(self, text: str, bos: bool = True, eos: bool = False): + t = self.tokenizer.encode(text, add_special_tokens=False) + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: list[int], **kwargs) -> str: + return self.tokenizer.decode(t, **kwargs) + + def process_one( + self, + prompt: str, + images: list[Image.Image], + crop_mode: bool = CROP_MODE, + ): + """ + + Args: + prompt (str): the formatted prompt; + images (List[ImageType]): the list of images; + crop_mode (bool): if True, then crop the image; + + Returns: + outputs (BaseProcessorOutput): the output of the processor, + - input_ids (torch.LongTensor): [N + image tokens] + - target_ids (torch.LongTensor): [N + image tokens] + - pixel_values (torch.FloatTensor): [n_patches, 3, H, W] + - image_id (int): the id of the image token + - num_image_tokens (List[int]): the number of image tokens + """ + + assert prompt is not None and images is not None, ( + "prompt and images must be used at the same time." 
+ ) + + sft_format = prompt + + ( + input_ids, + pixel_values, + images_crop, + images_seq_mask, + images_spatial_crop, + num_image_tokens, + _, + ) = self.tokenize_with_images( + conversation=sft_format, + images=images, + bos=True, + eos=True, + cropping=crop_mode, + ) + + prepare = BatchFeature( + data=dict( + input_ids=input_ids, + pixel_values=pixel_values, + images_crop=images_crop, + images_seq_mask=images_seq_mask, + images_spatial_crop=images_spatial_crop, + num_image_tokens=num_image_tokens, + ), + tensor_type="pt", + ) + return prepare + + def __call__( + self, + *, + prompt: str, + images: list[Image.Image], + crop_mode: bool = CROP_MODE, + **kwargs, + ): + prepare = self.process_one( + prompt=prompt, + images=images, + crop_mode=crop_mode, + ) + + return prepare + + def tokenize_with_images( + self, + conversation: str, + images: list[Image.Image], + bos: bool = True, + eos: bool = True, + cropping: bool = True, + ): + """Tokenize text with tags.""" + + assert conversation.count(self.image_token) == len(images) + text_splits = conversation.split(self.image_token) + images_list, images_crop_list, images_seq_mask, images_spatial_crop = ( + [], + [], + [], + [], + ) + image_shapes = [] + num_image_tokens = [] + tokenized_str = [] + for text_sep, image in zip(text_splits, images): + tokenized_sep = self.encode(text_sep, bos=False, eos=False) + tokenized_str += tokenized_sep + images_seq_mask += [False] * len(tokenized_sep) + + image_shapes.append(image.size) + + images_crop_raw = [] + if image.size[0] <= 640 and image.size[1] <= 640: + crop_ratio = [1, 1] + elif cropping: + images_crop_raw, crop_ratio = dynamic_preprocess( + image, image_size=IMAGE_SIZE + ) + else: + crop_ratio = [1, 1] + + if self.image_size <= 640 and not cropping: + image = image.resize((self.image_size, self.image_size)) + + global_view = ImageOps.pad( + image, + (self.base_size, self.base_size), + color=tuple(int(x * 255) for x in self.image_transform.mean), + ) + images_list.append(self.image_transform(global_view)) + + num_width_tiles, num_height_tiles = crop_ratio + images_spatial_crop.append([num_width_tiles, num_height_tiles]) + + if num_width_tiles > 1 or num_height_tiles > 1: + for cropped_image in images_crop_raw: + images_crop_list.append(self.image_transform(cropped_image)) + + num_queries = math.ceil( + (self.image_size // self.patch_size) / self.downsample_ratio + ) + num_queries_base = math.ceil( + (self.base_size // self.patch_size) / self.downsample_ratio + ) + + tokenized_image = ( + [self.image_token_id] * num_queries_base + [self.image_token_id] + ) * num_queries_base + tokenized_image += [self.image_token_id] + if num_width_tiles > 1 or num_height_tiles > 1: + local_row = [self.image_token_id] * (num_queries * num_width_tiles + 1) + tokenized_image += local_row * (num_queries * num_height_tiles) + tokenized_str += tokenized_image + images_seq_mask += [True] * len(tokenized_image) + num_image_tokens.append(len(tokenized_image)) + + """process the last text split""" + tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False) + tokenized_str += tokenized_sep + images_seq_mask += [False] * len(tokenized_sep) + + """add the bos and eos tokens""" + if bos: + tokenized_str = [self.bos_id] + tokenized_str + images_seq_mask = [False] + images_seq_mask + if eos: + tokenized_str = tokenized_str + [self.eos_id] + images_seq_mask = images_seq_mask + [False] + + assert len(tokenized_str) == len(images_seq_mask), ( + f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} " + 
f"is not equal to images_seq_mask's length {len(images_seq_mask)}." + ) + + masked_tokenized_str = [] + for token_index in tokenized_str: + if token_index != self.image_token_id: + masked_tokenized_str.append(token_index) + else: + masked_tokenized_str.append(self.ignore_id) + + assert ( + len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str) + ), ( + f"tokenized_str's length {len(tokenized_str)}, " + f"input_ids' length {len(masked_tokenized_str)}, " + f"images_seq_mask's length {len(images_seq_mask)}, are not equal." + ) + + input_ids = torch.LongTensor(tokenized_str) + target_ids = torch.LongTensor(masked_tokenized_str) + images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) + + # set input_ids < 0 | input_ids == self.image_token_id as ignore_id + target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = ( + self.ignore_id + ) + input_ids[input_ids < 0] = self.pad_id + + # Remove the ending eos token + assert input_ids[-1] == self.eos_id + input_ids = input_ids[:-1] + target_ids = target_ids[:-1] + images_seq_mask = images_seq_mask[:-1] + + if len(images_list) == 0: + pixel_values = torch.zeros((1, 3, self.base_size, self.base_size)) + images_spatial_crop = torch.zeros((1, 1), dtype=torch.long) + images_crop = torch.zeros( + (1, 3, self.image_size, self.image_size) + ).unsqueeze(0) + else: + pixel_values = torch.stack(images_list, dim=0) + images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) + if images_crop_list: + images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0) + else: + images_crop = torch.zeros( + (1, 3, self.image_size, self.image_size) + ).unsqueeze(0) + + input_ids = input_ids.unsqueeze(0) + + return ( + input_ids, + pixel_values, + images_crop, + images_seq_mask, + images_spatial_crop, + num_image_tokens, + image_shapes, + ) + + +AutoProcessor.register("DeepseekOCRProcessor", DeepseekOCRProcessor)