From 1d6f767dc491a281bf853d9e61b54f8e990499a8 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Thu, 25 Sep 2025 07:58:08 +0800
Subject: [PATCH] [Model] Improve DotsOCRForCausalLM (#25466)

Signed-off-by: Jee Jee Li
Signed-off-by: yewentao256
---
 vllm/model_executor/models/dots_ocr.py | 237 +++++++++++++++----------
 1 file changed, 143 insertions(+), 94 deletions(-)

diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py
index 04fa5584199a3..2db350c892ae7 100644
--- a/vllm/model_executor/models/dots_ocr.py
+++ b/vllm/model_executor/models/dots_ocr.py
@@ -7,11 +7,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import LayerNorm
-from transformers.modeling_utils import PreTrainedModel
 from transformers.models.qwen2_vl import Qwen2VLProcessor
 
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
+from vllm.distributed import utils as dist_utils
+from vllm.distributed.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -19,10 +21,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (MultiModalEmbeddings,
+                                                   SupportsLoRA,
                                                    SupportsMultiModal,
                                                    SupportsPP)
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
+from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention
 from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder,
                                                  Qwen2VLMultiModalProcessor,
                                                  Qwen2VLProcessingInfo)
@@ -38,6 +44,8 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
                                                      DotsVisionConfig)
 
+from .vision import run_dp_sharded_mrope_vision_model
+
 IMAGE_TOKEN = "<|imgpad|>"
 
 
@@ -181,6 +189,8 @@ class PatchMerger(nn.Module):
         context_dim: int,
         spatial_merge_size: int = 2,
         pre_norm="layernorm",
+        prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -189,21 +199,21 @@ class PatchMerger(nn.Module):
             self.ln_q = LayerNorm(context_dim, eps=1e-6)
         elif self.pre_norm == "rmsnorm":
             self.ln_q = RMSNorm(context_dim, eps=1e-6)
-        else:
-            print("no norm in patch merger")
 
         self.mlp = nn.Sequential(
             ColumnParallelLinear(self.hidden_size,
                                  self.hidden_size,
                                  bias=True,
                                  return_bias=False,
-                                 disable_tp=True),
+                                 prefix=f"{prefix}.0",
+                                 disable_tp=use_data_parallel),
             nn.GELU(),
             RowParallelLinear(self.hidden_size,
                               dim,
                               bias=True,
                               return_bias=False,
-                              disable_tp=True),
+                              prefix=f"{prefix}.2",
+                              disable_tp=use_data_parallel),
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -223,38 +233,36 @@ class DotsVisionAttention(nn.Module):
                  bias: bool = True,
                  *,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "") -> None:
+                 prefix: str = "",
+                 use_data_parallel: bool = False) -> None:
         super().__init__()
-        from vllm.distributed import (parallel_state,
-                                      tensor_model_parallel_all_gather)
-        from vllm.distributed import utils as dist_utils
         self.embed_dim = dim
-        self.num_heads = num_heads
-        self.head_dim = dim // num_heads
-        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
-        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
-        self.num_heads_per_partition = dist_utils.divide(
+        self.tp_size = (1 if use_data_parallel else
+                        get_tensor_model_parallel_world_size())
+        self.tp_rank = (0 if use_data_parallel else
+                        get_tensor_model_parallel_rank())
+        self.hidden_size_per_attention_head = dist_utils.divide(dim, num_heads)
+        self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)
 
-        # qkv/proj follow Qwen2-VL style; bias controlled by arg
-        self.qkv = QKVParallelLinear(hidden_size=dim,
-                                     head_size=dim // num_heads,
-                                     total_num_heads=num_heads,
-                                     bias=bias,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.qkv")
+        self.qkv = QKVParallelLinear(
+            hidden_size=dim,
+            head_size=self.hidden_size_per_attention_head,
+            total_num_heads=num_heads,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
+            disable_tp=use_data_parallel)
         self.proj = RowParallelLinear(input_size=dim,
                                       output_size=dim,
                                       bias=bias,
                                       quant_config=quant_config,
-                                      prefix=f"{prefix}.proj")
-        self._all_gather = tensor_model_parallel_all_gather
-        self._split_last = dist_utils.split_tensor_along_last_dim
-
+                                      prefix=f"{prefix}.proj",
+                                      disable_tp=use_data_parallel)
         # Select attention backend
-        self.attn_backend = get_vit_attn_backend(self.head_dim,
-                                                 torch.get_default_dtype())
+        self.attn_backend = get_vit_attn_backend(
+            self.hidden_size_per_attention_head, torch.get_default_dtype())
         self.use_upstream_fa = False
         if self.attn_backend != _Backend.FLASH_ATTN and \
             check_upstream_fa_availability(torch.get_default_dtype()):
@@ -270,19 +278,6 @@ class DotsVisionAttention(nn.Module):
             _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
         }
 
-    def _split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
-        # qkv: [S, B, 3*dim]
-        seq_len, bs, _ = qkv.shape
-        if self.tp_size > 1:
-            qkv = self._all_gather(qkv)
-        q, k, v = qkv.chunk(3, dim=2)
-        if self.tp_size > 1:
-            q = self._split_last(q, num_partitions=self.tp_size)[self.tp_rank]
-            k = self._split_last(k, num_partitions=self.tp_size)[self.tp_rank]
-            v = self._split_last(v, num_partitions=self.tp_size)[self.tp_rank]
-        new_shape = (seq_len, bs, self.num_heads_per_partition, self.head_dim)
-        return (q.view(*new_shape), k.view(*new_shape), v.view(*new_shape))
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -295,7 +290,7 @@ class DotsVisionAttention(nn.Module):
         # [S, C] -> [S, B=1, C]
         x = hidden_states.unsqueeze(1)
         x, _ = self.qkv(x)
-        q, k, v = self._split_qkv(x)
+        q, k, v = Qwen2_5_VisionAttention.split_qkv(self, x)
         bs = q.shape[1]
         # [S,B,H,D] -> [B,S,H,D]
         q = q.permute(1, 0, 2, 3).contiguous()
@@ -327,8 +322,9 @@ class DotsVisionAttention(nn.Module):
                 max_seqlen_k=max_seqlen,
                 dropout_p=0.0,
                 causal=False)
-            context_layer = output.view(bs, -1, self.num_heads_per_partition,
-                                        self.head_dim)
+            context_layer = output.view(bs, -1,
+                                        self.num_attention_heads_per_partition,
+                                        self.hidden_size_per_attention_head)
         elif self.attn_backend == _Backend.TORCH_SDPA:
             outputs = []
             for i in range(1, len(cu_seqlens)):
@@ -368,7 +364,8 @@ class DotsSwiGLUFFN(nn.Module):
                  config,
                  *,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         hidden_features = config.intermediate_size
         in_features = config.embed_dim
@@ -380,13 +377,13 @@ class DotsSwiGLUFFN(nn.Module):
                                                bias=bias,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.fc13",
-                                               disable_tp=True)
+                                               disable_tp=use_data_parallel)
         self.fc2 = RowParallelLinear(hidden_features,
                                      in_features,
                                      bias=bias,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.fc2",
-                                     disable_tp=True)
+                                     disable_tp=use_data_parallel)
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -397,28 +394,36 @@ class DotsSwiGLUFFN(nn.Module):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params = dict(self.named_parameters())
-        loaded: set[str] = set()
-        for name, w in weights:
-            # Map fc1 -> fc13 (shard 0)
-            if name.startswith("fc1."):
-                tgt = name.replace("fc1.", "fc13.")
-                if tgt in params:
-                    params[tgt].weight_loader(params[tgt], w, 0)
-                    loaded.add(tgt)
-                continue
-            # Map fc3 -> fc13 (shard 1)
-            if name.startswith("fc3."):
-                tgt = name.replace("fc3.", "fc13.")
-                if tgt in params:
-                    params[tgt].weight_loader(params[tgt], w, 1)
-                    loaded.add(tgt)
-                continue
-            # Pass-through for fc2 and others
-            if name in params:
-                params[name].weight_loader(params[name], w)
-                loaded.add(name)
-        return loaded
+        stacked_params_mapping = [
+            ("fc13", "fc1", 0),
+            ("fc13", "fc3", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 class DotsPatchEmbed(nn.Module):
@@ -463,25 +468,28 @@ class DotsViTPreprocessor(nn.Module):
 
 class DotsVisionBlock(nn.Module):
 
-    def __init__(self,
-                 config,
-                 *,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        config,
+        *,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
         super().__init__()
-        self.attn = DotsVisionAttention(
-            config,
-            config.embed_dim,
-            num_heads=config.num_attention_heads,
-            bias=config.use_bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.attn",
-        )
+        self.attn = DotsVisionAttention(config,
+                                        config.embed_dim,
+                                        num_heads=config.num_attention_heads,
+                                        bias=config.use_bias,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.attn",
+                                        use_data_parallel=use_data_parallel)
         self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
         self.mlp = DotsSwiGLUFFN(config,
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.mlp")
+                                 prefix=f"{prefix}.mlp",
+                                 use_data_parallel=use_data_parallel)
         self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
 
     def forward(self,
@@ -502,7 +510,7 @@ class DotsVisionBlock(nn.Module):
         return hidden_states
 
 
-class DotsVisionTransformer(PreTrainedModel):
+class DotsVisionTransformer(nn.Module):
 
     def __init__(
         self,
@@ -512,8 +520,9 @@ class DotsVisionTransformer(PreTrainedModel):
         num_hidden_layers_override: Optional[int] = None,
         require_post_norm: Optional[bool] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
-        super().__init__(config)
+        super().__init__()
         self.config = config
         self.spatial_merge_size = config.spatial_merge_size
@@ -526,14 +535,15 @@ class DotsVisionTransformer(PreTrainedModel):
         if self.attn_backend != _Backend.FLASH_ATTN and \
             check_upstream_fa_availability(torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
-
+        self.out_hidden_size = config.hidden_size
         # Keep blocks for compatibility with other vision towers
         num_layers = (config.num_hidden_layers if num_hidden_layers_override
                       is None else num_hidden_layers_override)
         self.blocks = nn.ModuleList([
             DotsVisionBlock(config,
                             quant_config=quant_config,
-                            prefix=f"{prefix}.blocks.{i}")
+                            prefix=f"{prefix}.blocks.{i}",
+                            use_data_parallel=use_data_parallel)
             for i in range(num_layers)
         ])
         if require_post_norm is None:
@@ -548,6 +558,7 @@ class DotsVisionTransformer(PreTrainedModel):
             dim=config.hidden_size,
             context_dim=config.embed_dim,
             spatial_merge_size=config.spatial_merge_size,
+            use_data_parallel=use_data_parallel,
         )
 
     @property
@@ -604,7 +615,11 @@ class DotsVisionTransformer(PreTrainedModel):
         return max_seqlen, seqlens
 
     def forward(self, hidden_states: torch.Tensor,
-                grid_thw: torch.Tensor) -> torch.Tensor:
+                grid_thw: list[list[int]]) -> torch.Tensor:
+        # Convert grid_thw to tensor (always expecting list format now)
+        grid_thw = torch.tensor(grid_thw,
+                                device=hidden_states.device,
+                                dtype=torch.long)
         hidden_states = hidden_states.to(self.dtype)
         hidden_states = self.patch_embed(hidden_states, grid_thw)
 
@@ -638,7 +653,8 @@ class DotsVisionTransformer(PreTrainedModel):
     info=DotsOCRProcessingInfo,
     dummy_inputs=DotsOCRDummyInputsBuilder,
 )
-class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
+                         SupportsLoRA):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             ".attn.qkv_proj.": ".attn.qkv.",
@@ -650,6 +666,21 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         },
     )
 
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+        ".attn.qkv": [".attn.qkv"],
+        "fc13": ["fc1", "fc3"],
+    }
+    supports_encoder_tp_data = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
@@ -660,19 +691,18 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
         self.config: DotsOCRConfig = vllm_config.model_config.hf_config
         self.quant_config = vllm_config.quant_config
-        self.multimodal_config = vllm_config.model_config.multimodal_config
-
+        multimodal_config = vllm_config.model_config.multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
         if isinstance(self.config.vision_config, dict):
             vision_config = DotsVisionConfig(**self.config.vision_config)
             self.config.vision_config = vision_config
         else:
             vision_config = self.config.vision_config
-
         self.vision_tower = DotsVisionTransformer(
             vision_config,
             quant_config=self.quant_config,
             prefix=maybe_prefix(prefix, "vision_tower"),
-        )
+            use_data_parallel=self.use_data_parallel)
         self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=self.config,
@@ -744,8 +774,17 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         else:
             pixel_values = image_input["pixel_values"].type(
                 self.vision_tower.dtype)
-            image_embeds = self.vision_tower(
-                pixel_values, grid_thw)[:, :self.config.hidden_size]
+
+        if self.use_data_parallel:
+            return run_dp_sharded_mrope_vision_model(
+                self.vision_tower,
+                pixel_values,
+                grid_thw_list,
+                rope_type="rope_3d",
+            )
+        else:
+            image_embeds = self.vision_tower(
+                pixel_values, grid_thw)[:, :self.config.hidden_size]
 
         # Split concatenated embeddings for each image item.
         merge_size = self.vision_tower.spatial_merge_size
@@ -822,3 +861,13 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
                                torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="vision_tower.merger",
+            tower_model="vision_tower.",
+        )
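
A minimal usage sketch of the data-parallel vision-encoder path this patch adds. The checkpoint name "rednote-hilab/dots.ocr" and the mm_encoder_tp_mode engine argument are assumptions inferred from this diff and current vLLM engine arguments, not something the patch itself guarantees; adjust to your vLLM version.

# Sketch (assumptions noted above): run dots.ocr with the ViT replicated and
# sharded over ranks by data parallelism instead of tensor parallelism.
from PIL import Image

from vllm import LLM, SamplingParams

image = Image.new("RGB", (1024, 1024), "white")  # placeholder page image

llm = LLM(
    model="rednote-hilab/dots.ocr",  # assumed checkpoint name
    tensor_parallel_size=2,
    mm_encoder_tp_mode="data",  # assumed engine arg; with it, get_multimodal_embeddings
                                # goes through run_dp_sharded_mrope_vision_model
)
outputs = llm.generate(
    {
        # IMAGE_TOKEN placeholder from the diff; the processor expands it per image
        "prompt": "<|imgpad|>Parse the text in this page.",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)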