[Model] Improve DotsOCRForCausalLM (#25466)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Jee Jee Li 2025-09-25 07:58:08 +08:00 committed by yewentao256
parent b95429c920
commit 1d6f767dc4


@@ -7,11 +7,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import LayerNorm
-from transformers.modeling_utils import PreTrainedModel
 from transformers.models.qwen2_vl import Qwen2VLProcessor

 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
+from vllm.distributed import utils as dist_utils
+from vllm.distributed.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -19,10 +21,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (MultiModalEmbeddings,
+                                                    SupportsLoRA,
                                                     SupportsMultiModal,
                                                     SupportsPP)
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
+from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention
 from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder,
                                                  Qwen2VLMultiModalProcessor,
                                                  Qwen2VLProcessingInfo)
@@ -38,6 +44,8 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
                                                      DotsVisionConfig)

+from .vision import run_dp_sharded_mrope_vision_model
+
 IMAGE_TOKEN = "<|imgpad|>"
@@ -181,6 +189,8 @@ class PatchMerger(nn.Module):
         context_dim: int,
         spatial_merge_size: int = 2,
         pre_norm="layernorm",
+        prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
@ -189,21 +199,21 @@ class PatchMerger(nn.Module):
self.ln_q = LayerNorm(context_dim, eps=1e-6) self.ln_q = LayerNorm(context_dim, eps=1e-6)
elif self.pre_norm == "rmsnorm": elif self.pre_norm == "rmsnorm":
self.ln_q = RMSNorm(context_dim, eps=1e-6) self.ln_q = RMSNorm(context_dim, eps=1e-6)
else:
print("no norm in patch merger")
self.mlp = nn.Sequential( self.mlp = nn.Sequential(
ColumnParallelLinear(self.hidden_size, ColumnParallelLinear(self.hidden_size,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
return_bias=False, return_bias=False,
disable_tp=True), prefix=f"{prefix}.0",
disable_tp=use_data_parallel),
nn.GELU(), nn.GELU(),
RowParallelLinear(self.hidden_size, RowParallelLinear(self.hidden_size,
dim, dim,
bias=True, bias=True,
return_bias=False, return_bias=False,
disable_tp=True), prefix=f"{prefix}.2",
disable_tp=use_data_parallel),
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
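
The disable_tp flip above is the heart of this change: in tensor-parallel mode each linear layer is sharded across ranks, while with use_data_parallel=True every rank keeps a full copy of the vision tower and processes its own slice of images. A minimal sketch of the toggle with a plain nn.Linear stand-in (ToyColumnParallelLinear below only mirrors the vLLM signature and is not part of the diff):

    import torch.nn as nn

    # Hypothetical stand-in: with disable_tp=True the layer keeps the full
    # output width on every rank; with disable_tp=False it would keep only
    # its 1/tp_size shard of the output dimension.
    class ToyColumnParallelLinear(nn.Linear):

        def __init__(self, in_features: int, out_features: int, *,
                     tp_size: int = 1, disable_tp: bool = False) -> None:
            shard = out_features if disable_tp else out_features // tp_size
            super().__init__(in_features, shard, bias=True)

    # use_data_parallel=True -> disable_tp=True -> full-width layer per rank.
    layer = ToyColumnParallelLinear(1024, 4096, tp_size=4, disable_tp=True)
    print(layer.weight.shape)  # torch.Size([4096, 1024])
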
@@ -223,38 +233,36 @@ class DotsVisionAttention(nn.Module):
                  bias: bool = True,
                  *,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "") -> None:
+                 prefix: str = "",
+                 use_data_parallel: bool = False) -> None:
         super().__init__()
-        from vllm.distributed import (parallel_state,
-                                      tensor_model_parallel_all_gather)
-        from vllm.distributed import utils as dist_utils
         self.embed_dim = dim
-        self.num_heads = num_heads
-        self.head_dim = dim // num_heads
-        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
-        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
-        self.num_heads_per_partition = dist_utils.divide(
+        self.tp_size = (1 if use_data_parallel else
+                        get_tensor_model_parallel_world_size())
+        self.tp_rank = (0 if use_data_parallel else
+                        get_tensor_model_parallel_rank())
+        self.hidden_size_per_attention_head = dist_utils.divide(dim, num_heads)
+        self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)

         # qkv/proj follow Qwen2-VL style; bias controlled by arg
-        self.qkv = QKVParallelLinear(hidden_size=dim,
-                                     head_size=dim // num_heads,
-                                     total_num_heads=num_heads,
-                                     bias=bias,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.qkv")
+        self.qkv = QKVParallelLinear(
+            hidden_size=dim,
+            head_size=self.hidden_size_per_attention_head,
+            total_num_heads=num_heads,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
+            disable_tp=use_data_parallel)
         self.proj = RowParallelLinear(input_size=dim,
                                       output_size=dim,
                                       bias=bias,
                                       quant_config=quant_config,
-                                      prefix=f"{prefix}.proj")
-        self._all_gather = tensor_model_parallel_all_gather
-        self._split_last = dist_utils.split_tensor_along_last_dim
+                                      prefix=f"{prefix}.proj",
+                                      disable_tp=use_data_parallel)

         # Select attention backend
-        self.attn_backend = get_vit_attn_backend(self.head_dim,
-                                                 torch.get_default_dtype())
+        self.attn_backend = get_vit_attn_backend(
+            self.hidden_size_per_attention_head, torch.get_default_dtype())
         self.use_upstream_fa = False
         if self.attn_backend != _Backend.FLASH_ATTN and \
                 check_upstream_fa_availability(torch.get_default_dtype()):
@@ -270,19 +278,6 @@ class DotsVisionAttention(nn.Module):
             _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
         }

-    def _split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
-        # qkv: [S, B, 3*dim]
-        seq_len, bs, _ = qkv.shape
-        if self.tp_size > 1:
-            qkv = self._all_gather(qkv)
-        q, k, v = qkv.chunk(3, dim=2)
-        if self.tp_size > 1:
-            q = self._split_last(q, num_partitions=self.tp_size)[self.tp_rank]
-            k = self._split_last(k, num_partitions=self.tp_size)[self.tp_rank]
-            v = self._split_last(v, num_partitions=self.tp_size)[self.tp_rank]
-        new_shape = (seq_len, bs, self.num_heads_per_partition, self.head_dim)
-        return (q.view(*new_shape), k.view(*new_shape), v.view(*new_shape))
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -295,7 +290,7 @@ class DotsVisionAttention(nn.Module):
         # [S, C] -> [S, B=1, C]
         x = hidden_states.unsqueeze(1)
         x, _ = self.qkv(x)
-        q, k, v = self._split_qkv(x)
+        q, k, v = Qwen2_5_VisionAttention.split_qkv(self, x)
         bs = q.shape[1]
         # [S,B,H,D] -> [B,S,H,D]
         q = q.permute(1, 0, 2, 3).contiguous()
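
The removed _split_qkv helper is replaced by the shared Qwen2_5_VisionAttention.split_qkv, which follows the same recipe: gather the fused qkv output when tensor parallelism is active, chunk it into q, k and v, and keep only this rank's slice of heads. A standalone single-rank sketch of that reshaping (no collectives), assuming the [S, B, 3 * dim] layout used above:

    import torch

    def split_qkv_single_rank(qkv: torch.Tensor, num_heads: int,
                              head_dim: int) -> tuple[torch.Tensor, ...]:
        # qkv: [S, B, 3 * num_heads * head_dim]; with tp_size == 1 there is
        # nothing to all-gather or re-shard, so only chunk and reshape.
        seq_len, bs, _ = qkv.shape
        q, k, v = qkv.chunk(3, dim=2)
        new_shape = (seq_len, bs, num_heads, head_dim)
        return q.view(*new_shape), k.view(*new_shape), v.view(*new_shape)

    qkv = torch.randn(16, 1, 3 * 8 * 64)
    q, k, v = split_qkv_single_rank(qkv, num_heads=8, head_dim=64)
    print(q.shape)  # torch.Size([16, 1, 8, 64])
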
@@ -327,8 +322,9 @@ class DotsVisionAttention(nn.Module):
                                        max_seqlen_k=max_seqlen,
                                        dropout_p=0.0,
                                        causal=False)
-            context_layer = output.view(bs, -1, self.num_heads_per_partition,
-                                        self.head_dim)
+            context_layer = output.view(bs, -1,
+                                        self.num_attention_heads_per_partition,
+                                        self.hidden_size_per_attention_head)
         elif self.attn_backend == _Backend.TORCH_SDPA:
             outputs = []
             for i in range(1, len(cu_seqlens)):
@@ -368,7 +364,8 @@ class DotsSwiGLUFFN(nn.Module):
                  config,
                  *,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         hidden_features = config.intermediate_size
         in_features = config.embed_dim
@@ -380,13 +377,13 @@ class DotsSwiGLUFFN(nn.Module):
                                           bias=bias,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.fc13",
-                                          disable_tp=True)
+                                          disable_tp=use_data_parallel)
         self.fc2 = RowParallelLinear(hidden_features,
                                      in_features,
                                      bias=bias,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.fc2",
-                                     disable_tp=True)
+                                     disable_tp=use_data_parallel)
         self.act_fn = SiluAndMul()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -397,28 +394,36 @@ class DotsSwiGLUFFN(nn.Module):

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params = dict(self.named_parameters())
-        loaded: set[str] = set()
-        for name, w in weights:
-            # Map fc1 -> fc13 (shard 0)
-            if name.startswith("fc1."):
-                tgt = name.replace("fc1.", "fc13.")
-                if tgt in params:
-                    params[tgt].weight_loader(params[tgt], w, 0)
-                    loaded.add(tgt)
-                continue
-            # Map fc3 -> fc13 (shard 1)
-            if name.startswith("fc3."):
-                tgt = name.replace("fc3.", "fc13.")
-                if tgt in params:
-                    params[tgt].weight_loader(params[tgt], w, 1)
-                    loaded.add(tgt)
-                continue
-            # Pass-through for fc2 and others
-            if name in params:
-                params[name].weight_loader(params[name], w)
-                loaded.add(name)
-        return loaded
+        stacked_params_mapping = [
+            ("fc13", "fc1", 0),
+            ("fc13", "fc3", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params


 class DotsPatchEmbed(nn.Module):
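
The rewritten load_weights uses the stacked-parameter pattern common in vLLM models: checkpoint tensors named fc1 and fc3 are routed into shards 0 and 1 of the fused fc13 parameter, and everything else (fc2, biases) falls through to a default loader. A self-contained sketch of just the name mapping and shard placement, with toy tensors standing in for vLLM parameters:

    import torch

    # (fused param name, checkpoint shard name, shard index), as in the diff.
    stacked_params_mapping = [
        ("fc13", "fc1", 0),
        ("fc13", "fc3", 1),
    ]

    checkpoint = {
        "fc1.weight": torch.ones(4, 8),
        "fc3.weight": torch.full((4, 8), 2.0),
        "fc2.weight": torch.zeros(8, 4),
    }
    fused = {"fc13.weight": torch.empty(8, 8), "fc2.weight": torch.empty(8, 4)}

    for name, w in checkpoint.items():
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            tgt = name.replace(weight_name, param_name)
            rows = fused[tgt].shape[0] // len(stacked_params_mapping)
            fused[tgt][shard_id * rows:(shard_id + 1) * rows] = w
            break
        else:
            fused[name].copy_(w)

    print(fused["fc13.weight"][0, 0], fused["fc13.weight"][4, 0])  # 1.0 2.0
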
@@ -463,25 +468,28 @@ class DotsViTPreprocessor(nn.Module):
 class DotsVisionBlock(nn.Module):

-    def __init__(self,
-                 config,
-                 *,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        config,
+        *,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
         super().__init__()
-        self.attn = DotsVisionAttention(
-            config,
-            config.embed_dim,
-            num_heads=config.num_attention_heads,
-            bias=config.use_bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.attn",
-        )
+        self.attn = DotsVisionAttention(config,
+                                        config.embed_dim,
+                                        num_heads=config.num_attention_heads,
+                                        bias=config.use_bias,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.attn",
+                                        use_data_parallel=use_data_parallel)
         self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
         self.mlp = DotsSwiGLUFFN(config,
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.mlp")
+                                 prefix=f"{prefix}.mlp",
+                                 use_data_parallel=use_data_parallel)
         self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

     def forward(self,
@@ -502,7 +510,7 @@ class DotsVisionBlock(nn.Module):
         return hidden_states


-class DotsVisionTransformer(PreTrainedModel):
+class DotsVisionTransformer(nn.Module):

     def __init__(
         self,
@@ -512,8 +520,9 @@ class DotsVisionTransformer(PreTrainedModel):
         num_hidden_layers_override: Optional[int] = None,
         require_post_norm: Optional[bool] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
-        super().__init__(config)
+        super().__init__()
         self.config = config
         self.spatial_merge_size = config.spatial_merge_size
@ -526,14 +535,15 @@ class DotsVisionTransformer(PreTrainedModel):
if self.attn_backend != _Backend.FLASH_ATTN and \ if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(torch.get_default_dtype()): check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN self.attn_backend = _Backend.FLASH_ATTN
self.out_hidden_size = config.hidden_size
# Keep blocks for compatibility with other vision towers # Keep blocks for compatibility with other vision towers
num_layers = (config.num_hidden_layers if num_hidden_layers_override num_layers = (config.num_hidden_layers if num_hidden_layers_override
is None else num_hidden_layers_override) is None else num_hidden_layers_override)
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([
DotsVisionBlock(config, DotsVisionBlock(config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.blocks.{i}") prefix=f"{prefix}.blocks.{i}",
use_data_parallel=use_data_parallel)
for i in range(num_layers) for i in range(num_layers)
]) ])
if require_post_norm is None: if require_post_norm is None:
@@ -548,6 +558,7 @@ class DotsVisionTransformer(PreTrainedModel):
             dim=config.hidden_size,
             context_dim=config.embed_dim,
             spatial_merge_size=config.spatial_merge_size,
+            use_data_parallel=use_data_parallel,
         )

     @property
@@ -604,7 +615,11 @@ class DotsVisionTransformer(PreTrainedModel):
         return max_seqlen, seqlens

     def forward(self, hidden_states: torch.Tensor,
-                grid_thw: torch.Tensor) -> torch.Tensor:
+                grid_thw: list[list[int]]) -> torch.Tensor:
+        # Convert grid_thw to tensor (always expecting list format now)
+        grid_thw = torch.tensor(grid_thw,
+                                device=hidden_states.device,
+                                dtype=torch.long)
         hidden_states = hidden_states.to(self.dtype)
         hidden_states = self.patch_embed(hidden_states, grid_thw)
@@ -638,7 +653,8 @@
     info=DotsOCRProcessingInfo,
     dummy_inputs=DotsOCRDummyInputsBuilder,
 )
-class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
+                         SupportsLoRA):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             ".attn.qkv_proj.": ".attn.qkv.",
@@ -650,6 +666,21 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         },
     )

+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+        ".attn.qkv": [".attn.qkv"],
+        "fc13": ["fc1", "fc3"],
+    }
+
+    supports_encoder_tp_data = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
@@ -660,19 +691,18 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         self.config: DotsOCRConfig = vllm_config.model_config.hf_config
         self.quant_config = vllm_config.quant_config
-        self.multimodal_config = vllm_config.model_config.multimodal_config
+        multimodal_config = vllm_config.model_config.multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

         if isinstance(self.config.vision_config, dict):
             vision_config = DotsVisionConfig(**self.config.vision_config)
             self.config.vision_config = vision_config
         else:
             vision_config = self.config.vision_config

         self.vision_tower = DotsVisionTransformer(
             vision_config,
             quant_config=self.quant_config,
             prefix=maybe_prefix(prefix, "vision_tower"),
-        )
-
+            use_data_parallel=self.use_data_parallel)
         self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=self.config,
@@ -744,8 +774,17 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         else:
             pixel_values = image_input["pixel_values"].type(
                 self.vision_tower.dtype)
-            image_embeds = self.vision_tower(
-                pixel_values, grid_thw)[:, :self.config.hidden_size]
+
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(
+                    self.vision_tower,
+                    pixel_values,
+                    grid_thw_list,
+                    rope_type="rope_3d",
+                )
+            else:
+                image_embeds = self.vision_tower(
+                    pixel_values, grid_thw)[:, :self.config.hidden_size]

         # Split concatenated embeddings for each image item.
         merge_size = self.vision_tower.spatial_merge_size
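
When self.use_data_parallel is set, the whole vision tower is handed to run_dp_sharded_mrope_vision_model, which splits the batch of images across ranks instead of splitting each layer's weights. The flag comes from multimodal_config.mm_encoder_tp_mode == "data"; a rough sketch of requesting that mode from the offline API, where the exact engine-argument spelling is an assumption:

    from vllm import LLM

    # Assumed engine argument that sets multimodal_config.mm_encoder_tp_mode
    # to "data": each of the tensor-parallel ranks then runs a full copy of
    # the vision tower on its own subset of the input images.
    llm = LLM(
        model="rednote-hilab/dots.ocr",  # placeholder model id
        tensor_parallel_size=2,
        mm_encoder_tp_mode="data",
    )
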
@@ -822,3 +861,13 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
                                torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="vision_tower.merger",
+            tower_model="vision_tower.",
+        )
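
get_mm_mapping tells the LoRA machinery which parameter prefixes belong to the language model, the connector (the patch merger), and the vision tower, so adapter weights can be scoped per component. A tiny illustration of how such a prefix mapping classifies parameter names (the classify helper is hypothetical, not a vLLM API):

    # Prefixes mirroring the MultiModelKeys returned by get_mm_mapping.
    MAPPING = {
        "language_model": "language_model",
        "connector": "vision_tower.merger",
        "tower_model": "vision_tower.",
    }

    def classify(param_name: str) -> str:
        # The longest matching prefix wins, so the merger is reported as the
        # connector rather than as part of the vision tower.
        return max((k for k, p in MAPPING.items() if param_name.startswith(p)),
                   key=lambda k: len(MAPPING[k]),
                   default="unknown")

    print(classify("vision_tower.merger.mlp.0.weight"))       # connector
    print(classify("vision_tower.blocks.0.attn.qkv.weight"))  # tower_model
    print(classify("language_model.model.embed_tokens.weight"))  # language_model
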