[Model] Improve DotsOCRForCausalLM (#25466)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Jee Jee Li 2025-09-25 07:58:08 +08:00 committed by yewentao256
parent b95429c920
commit 1d6f767dc4


@@ -7,11 +7,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen2_vl import Qwen2VLProcessor
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import utils as dist_utils
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -19,10 +21,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention
from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder,
Qwen2VLMultiModalProcessor,
Qwen2VLProcessingInfo)
@@ -38,6 +44,8 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
DotsVisionConfig)
from .vision import run_dp_sharded_mrope_vision_model
IMAGE_TOKEN = "<|imgpad|>"
@@ -181,6 +189,8 @@ class PatchMerger(nn.Module):
context_dim: int,
spatial_merge_size: int = 2,
pre_norm="layernorm",
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -189,21 +199,21 @@
self.ln_q = LayerNorm(context_dim, eps=1e-6)
elif self.pre_norm == "rmsnorm":
self.ln_q = RMSNorm(context_dim, eps=1e-6)
else:
print("no norm in patch merger")
self.mlp = nn.Sequential(
ColumnParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
return_bias=False,
disable_tp=True),
prefix=f"{prefix}.0",
disable_tp=use_data_parallel),
nn.GELU(),
RowParallelLinear(self.hidden_size,
dim,
bias=True,
return_bias=False,
disable_tp=True),
prefix=f"{prefix}.2",
disable_tp=use_data_parallel),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -223,38 +233,36 @@ class DotsVisionAttention(nn.Module):
bias: bool = True,
*,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
prefix: str = "",
use_data_parallel: bool = False) -> None:
super().__init__()
from vllm.distributed import (parallel_state,
tensor_model_parallel_all_gather)
from vllm.distributed import utils as dist_utils
self.embed_dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
self.num_heads_per_partition = dist_utils.divide(
self.tp_size = (1 if use_data_parallel else
get_tensor_model_parallel_world_size())
self.tp_rank = (0 if use_data_parallel else
get_tensor_model_parallel_rank())
self.hidden_size_per_attention_head = dist_utils.divide(dim, num_heads)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, self.tp_size)
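# Illustrative sketch, not part of this commit: how the head partitioning
# above behaves under the two encoder modes (hypothetical values; mirrors
# the tp_size / num_attention_heads_per_partition logic).
def heads_per_partition(num_heads, tp_size, use_data_parallel):
    # With data parallelism the vision tower is replicated, so every rank
    # keeps all heads; with tensor parallelism the heads are split evenly.
    effective_tp = 1 if use_data_parallel else tp_size
    assert num_heads % effective_tp == 0
    return num_heads // effective_tp

assert heads_per_partition(16, 4, use_data_parallel=True) == 16
assert heads_per_partition(16, 4, use_data_parallel=False) == 4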
# qkv/proj follow Qwen2-VL style; bias controlled by arg
self.qkv = QKVParallelLinear(hidden_size=dim,
head_size=dim // num_heads,
self.qkv = QKVParallelLinear(
hidden_size=dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv")
prefix=f"{prefix}.qkv",
disable_tp=use_data_parallel)
self.proj = RowParallelLinear(input_size=dim,
output_size=dim,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.proj")
self._all_gather = tensor_model_parallel_all_gather
self._split_last = dist_utils.split_tensor_along_last_dim
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel)
# Select attention backend
self.attn_backend = get_vit_attn_backend(self.head_dim,
torch.get_default_dtype())
self.attn_backend = get_vit_attn_backend(
self.hidden_size_per_attention_head, torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(torch.get_default_dtype()):
@@ -270,19 +278,6 @@ class DotsVisionAttention(nn.Module):
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}
def _split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# qkv: [S, B, 3*dim]
seq_len, bs, _ = qkv.shape
if self.tp_size > 1:
qkv = self._all_gather(qkv)
q, k, v = qkv.chunk(3, dim=2)
if self.tp_size > 1:
q = self._split_last(q, num_partitions=self.tp_size)[self.tp_rank]
k = self._split_last(k, num_partitions=self.tp_size)[self.tp_rank]
v = self._split_last(v, num_partitions=self.tp_size)[self.tp_rank]
new_shape = (seq_len, bs, self.num_heads_per_partition, self.head_dim)
return (q.view(*new_shape), k.view(*new_shape), v.view(*new_shape))
def forward(
self,
hidden_states: torch.Tensor,
@@ -295,7 +290,7 @@ class DotsVisionAttention(nn.Module):
# [S, C] -> [S, B=1, C]
x = hidden_states.unsqueeze(1)
x, _ = self.qkv(x)
q, k, v = self._split_qkv(x)
q, k, v = Qwen2_5_VisionAttention.split_qkv(self, x)
bs = q.shape[1]
# [S,B,H,D] -> [B,S,H,D]
q = q.permute(1, 0, 2, 3).contiguous()
@@ -327,8 +322,9 @@ class DotsVisionAttention(nn.Module):
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False)
context_layer = output.view(bs, -1, self.num_heads_per_partition,
self.head_dim)
context_layer = output.view(bs, -1,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
elif self.attn_backend == _Backend.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
@@ -368,7 +364,8 @@ class DotsSwiGLUFFN(nn.Module):
config,
*,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
use_data_parallel: bool = False):
super().__init__()
hidden_features = config.intermediate_size
in_features = config.embed_dim
@@ -380,13 +377,13 @@ class DotsSwiGLUFFN(nn.Module):
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc13",
disable_tp=True)
disable_tp=use_data_parallel)
self.fc2 = RowParallelLinear(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
disable_tp=True)
disable_tp=use_data_parallel)
self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -397,28 +394,36 @@ class DotsSwiGLUFFN(nn.Module):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params = dict(self.named_parameters())
loaded: set[str] = set()
for name, w in weights:
# Map fc1 -> fc13 (shard 0)
if name.startswith("fc1."):
tgt = name.replace("fc1.", "fc13.")
if tgt in params:
params[tgt].weight_loader(params[tgt], w, 0)
loaded.add(tgt)
stacked_params_mapping = [
("fc13", "fc1", 0),
("fc13", "fc3", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# Map fc3 -> fc13 (shard 1)
if name.startswith("fc3."):
tgt = name.replace("fc3.", "fc13.")
if tgt in params:
params[tgt].weight_loader(params[tgt], w, 1)
loaded.add(tgt)
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Pass-through for fc2 and others
if name in params:
params[name].weight_loader(params[name], w)
loaded.add(name)
return loaded
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
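# Illustrative sketch, not part of this commit: how the stacked_params_mapping
# above renames checkpoint weights before loading (hypothetical helper; the
# real loader also calls each parameter's weight_loader with the shard id).
def remap(name):
    stacked_params_mapping = [("fc13", "fc1", 0), ("fc13", "fc3", 1)]
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # e.g. fc2 falls through to default_weight_loader

assert remap("fc1.weight") == ("fc13.weight", 0)  # gate half of fused fc13
assert remap("fc3.weight") == ("fc13.weight", 1)  # up half of fused fc13
assert remap("fc2.weight") == ("fc2.weight", None)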
class DotsPatchEmbed(nn.Module):
@@ -463,25 +468,28 @@ class DotsViTPreprocessor(nn.Module):
class DotsVisionBlock(nn.Module):
def __init__(self,
def __init__(
self,
config,
*,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.attn = DotsVisionAttention(
config,
self.attn = DotsVisionAttention(config,
config.embed_dim,
num_heads=config.num_attention_heads,
bias=config.use_bias,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
use_data_parallel=use_data_parallel)
self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
self.mlp = DotsSwiGLUFFN(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel)
self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
def forward(self,
@@ -502,7 +510,7 @@ class DotsVisionBlock(nn.Module):
return hidden_states
class DotsVisionTransformer(PreTrainedModel):
class DotsVisionTransformer(nn.Module):
def __init__(
self,
@@ -512,8 +520,9 @@ class DotsVisionTransformer(PreTrainedModel):
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__(config)
super().__init__()
self.config = config
self.spatial_merge_size = config.spatial_merge_size
@@ -526,14 +535,15 @@ class DotsVisionTransformer(PreTrainedModel):
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.out_hidden_size = config.hidden_size
# Keep blocks for compatibility with other vision towers
num_layers = (config.num_hidden_layers if num_hidden_layers_override
is None else num_hidden_layers_override)
self.blocks = nn.ModuleList([
DotsVisionBlock(config,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{i}")
prefix=f"{prefix}.blocks.{i}",
use_data_parallel=use_data_parallel)
for i in range(num_layers)
])
if require_post_norm is None:
@@ -548,6 +558,7 @@ class DotsVisionTransformer(PreTrainedModel):
dim=config.hidden_size,
context_dim=config.embed_dim,
spatial_merge_size=config.spatial_merge_size,
use_data_parallel=use_data_parallel,
)
@property
@@ -604,7 +615,11 @@ class DotsVisionTransformer(PreTrainedModel):
return max_seqlen, seqlens
def forward(self, hidden_states: torch.Tensor,
grid_thw: torch.Tensor) -> torch.Tensor:
grid_thw: list[list[int]]) -> torch.Tensor:
# Convert grid_thw to tensor (always expecting list format now)
grid_thw = torch.tensor(grid_thw,
device=hidden_states.device,
dtype=torch.long)
hidden_states = hidden_states.to(self.dtype)
hidden_states = self.patch_embed(hidden_states, grid_thw)
@@ -638,7 +653,8 @@ class DotsVisionTransformer(PreTrainedModel):
info=DotsOCRProcessingInfo,
dummy_inputs=DotsOCRDummyInputsBuilder,
)
class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
".attn.qkv_proj.": ".attn.qkv.",
@@ -650,6 +666,21 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
},
)
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
".attn.qkv": [".attn.qkv"],
"fc13": ["fc1", "fc3"],
}
supports_encoder_tp_data = True
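# Illustrative sketch, not part of this commit: packed_modules_mapping tells
# the LoRA machinery which unfused Hugging Face module names were packed into
# a single fused vLLM module, so adapters trained against e.g. q_proj/k_proj/
# v_proj can be stacked onto qkv_proj (hypothetical lookup):
def fused_module_for(target, packed_modules_mapping):
    for fused, originals in packed_modules_mapping.items():
        if target in originals:
            return fused
    return target  # modules that were never fused keep their own name

assert fused_module_for(
    "q_proj", {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}) == "qkv_proj"
assert fused_module_for("fc1", {"fc13": ["fc1", "fc3"]}) == "fc13"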
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
@@ -660,19 +691,18 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.config: DotsOCRConfig = vllm_config.model_config.hf_config
self.quant_config = vllm_config.quant_config
self.multimodal_config = vllm_config.model_config.multimodal_config
multimodal_config = vllm_config.model_config.multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
if isinstance(self.config.vision_config, dict):
vision_config = DotsVisionConfig(**self.config.vision_config)
self.config.vision_config = vision_config
else:
vision_config = self.config.vision_config
self.vision_tower = DotsVisionTransformer(
vision_config,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
use_data_parallel=self.use_data_parallel)
self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=self.config,
@@ -744,6 +774,15 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
else:
pixel_values = image_input["pixel_values"].type(
self.vision_tower.dtype)
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.vision_tower,
pixel_values,
grid_thw_list,
rope_type="rope_3d",
)
else:
image_embeds = self.vision_tower(
pixel_values, grid_thw)[:, :self.config.hidden_size]
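# Illustrative sketch, not part of this commit: with mm_encoder_tp_mode ==
# "data" the tensor-parallel ranks are reused as a data-parallel group for
# the vision encoder, so each rank runs the full (unsharded) tower on a
# subset of the images instead of a slice of every weight. A simplified,
# hypothetical assignment of images to ranks:
def shard_images(grid_thw_list, tp_size, tp_rank):
    # Round-robin by image index; the real helper handles the actual work
    # distribution and gathers the per-rank embeddings back together.
    return [i for i in range(len(grid_thw_list)) if i % tp_size == tp_rank]

assert shard_images([[1, 16, 16]] * 4, tp_size=2, tp_rank=0) == [0, 2]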
@@ -822,3 +861,13 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="vision_tower.merger",
tower_model="vision_tower.",
)
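# Illustrative sketch, not part of this commit: the prefixes returned by
# get_mm_mapping() let callers (e.g. LoRA) classify a parameter by the
# sub-model it belongs to (hypothetical helper and parameter names):
def classify(param_name):
    if param_name.startswith("vision_tower.merger"):
        return "connector"
    if param_name.startswith("vision_tower."):
        return "tower_model"
    if param_name.startswith("language_model"):
        return "language_model"
    return "other"

assert classify("vision_tower.merger.mlp.0.weight") == "connector"
assert classify("vision_tower.blocks.0.attn.qkv.weight") == "tower_model"
assert classify("language_model.model.layers.0.mlp.gate_up_proj.weight") \
    == "language_model"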