mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-23 07:35:01 +08:00
[Model] Improve DotsOCRForCausalLM (#25466)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
b95429c920
commit
1d6f767dc4
@ -7,11 +7,13 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import LayerNorm
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.models.qwen2_vl import Qwen2VLProcessor
|
||||
|
||||
from vllm.attention.layer import check_upstream_fa_availability
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -19,10 +21,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsPP)
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
|
||||
from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention
|
||||
from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder,
|
||||
Qwen2VLMultiModalProcessor,
|
||||
Qwen2VLProcessingInfo)
|
||||
@ -38,6 +44,8 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
|
||||
DotsVisionConfig)
|
||||
|
||||
from .vision import run_dp_sharded_mrope_vision_model
|
||||
|
||||
IMAGE_TOKEN = "<|imgpad|>"
|
||||
|
||||
|
||||
@ -181,6 +189,8 @@ class PatchMerger(nn.Module):
|
||||
context_dim: int,
|
||||
spatial_merge_size: int = 2,
|
||||
pre_norm="layernorm",
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = context_dim * (spatial_merge_size**2)
|
||||
@ -189,21 +199,21 @@ class PatchMerger(nn.Module):
|
||||
self.ln_q = LayerNorm(context_dim, eps=1e-6)
|
||||
elif self.pre_norm == "rmsnorm":
|
||||
self.ln_q = RMSNorm(context_dim, eps=1e-6)
|
||||
else:
|
||||
print("no norm in patch merger")
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
ColumnParallelLinear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
return_bias=False,
|
||||
disable_tp=True),
|
||||
prefix=f"{prefix}.0",
|
||||
disable_tp=use_data_parallel),
|
||||
nn.GELU(),
|
||||
RowParallelLinear(self.hidden_size,
|
||||
dim,
|
||||
bias=True,
|
||||
return_bias=False,
|
||||
disable_tp=True),
|
||||
prefix=f"{prefix}.2",
|
||||
disable_tp=use_data_parallel),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -223,38 +233,36 @@ class DotsVisionAttention(nn.Module):
|
||||
bias: bool = True,
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False) -> None:
|
||||
super().__init__()
|
||||
from vllm.distributed import (parallel_state,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.distributed import utils as dist_utils
|
||||
|
||||
self.embed_dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
self.num_heads_per_partition = dist_utils.divide(
|
||||
self.tp_size = (1 if use_data_parallel else
|
||||
get_tensor_model_parallel_world_size())
|
||||
self.tp_rank = (0 if use_data_parallel else
|
||||
get_tensor_model_parallel_rank())
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(dim, num_heads)
|
||||
self.num_attention_heads_per_partition = dist_utils.divide(
|
||||
num_heads, self.tp_size)
|
||||
|
||||
# qkv/proj follow Qwen2-VL style; bias controlled by arg
|
||||
self.qkv = QKVParallelLinear(hidden_size=dim,
|
||||
head_size=dim // num_heads,
|
||||
self.qkv = QKVParallelLinear(
|
||||
hidden_size=dim,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
total_num_heads=num_heads,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv")
|
||||
prefix=f"{prefix}.qkv",
|
||||
disable_tp=use_data_parallel)
|
||||
self.proj = RowParallelLinear(input_size=dim,
|
||||
output_size=dim,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.proj")
|
||||
self._all_gather = tensor_model_parallel_all_gather
|
||||
self._split_last = dist_utils.split_tensor_along_last_dim
|
||||
|
||||
prefix=f"{prefix}.proj",
|
||||
disable_tp=use_data_parallel)
|
||||
# Select attention backend
|
||||
self.attn_backend = get_vit_attn_backend(self.head_dim,
|
||||
torch.get_default_dtype())
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
self.hidden_size_per_attention_head, torch.get_default_dtype())
|
||||
self.use_upstream_fa = False
|
||||
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||
check_upstream_fa_availability(torch.get_default_dtype()):
|
||||
@ -270,19 +278,6 @@ class DotsVisionAttention(nn.Module):
|
||||
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
|
||||
}
|
||||
|
||||
def _split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
# qkv: [S, B, 3*dim]
|
||||
seq_len, bs, _ = qkv.shape
|
||||
if self.tp_size > 1:
|
||||
qkv = self._all_gather(qkv)
|
||||
q, k, v = qkv.chunk(3, dim=2)
|
||||
if self.tp_size > 1:
|
||||
q = self._split_last(q, num_partitions=self.tp_size)[self.tp_rank]
|
||||
k = self._split_last(k, num_partitions=self.tp_size)[self.tp_rank]
|
||||
v = self._split_last(v, num_partitions=self.tp_size)[self.tp_rank]
|
||||
new_shape = (seq_len, bs, self.num_heads_per_partition, self.head_dim)
|
||||
return (q.view(*new_shape), k.view(*new_shape), v.view(*new_shape))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -295,7 +290,7 @@ class DotsVisionAttention(nn.Module):
|
||||
# [S, C] -> [S, B=1, C]
|
||||
x = hidden_states.unsqueeze(1)
|
||||
x, _ = self.qkv(x)
|
||||
q, k, v = self._split_qkv(x)
|
||||
q, k, v = Qwen2_5_VisionAttention.split_qkv(self, x)
|
||||
bs = q.shape[1]
|
||||
# [S,B,H,D] -> [B,S,H,D]
|
||||
q = q.permute(1, 0, 2, 3).contiguous()
|
||||
@ -327,8 +322,9 @@ class DotsVisionAttention(nn.Module):
|
||||
max_seqlen_k=max_seqlen,
|
||||
dropout_p=0.0,
|
||||
causal=False)
|
||||
context_layer = output.view(bs, -1, self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
context_layer = output.view(bs, -1,
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
outputs = []
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
@ -368,7 +364,8 @@ class DotsSwiGLUFFN(nn.Module):
|
||||
config,
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False):
|
||||
super().__init__()
|
||||
hidden_features = config.intermediate_size
|
||||
in_features = config.embed_dim
|
||||
@ -380,13 +377,13 @@ class DotsSwiGLUFFN(nn.Module):
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc13",
|
||||
disable_tp=True)
|
||||
disable_tp=use_data_parallel)
|
||||
self.fc2 = RowParallelLinear(hidden_features,
|
||||
in_features,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc2",
|
||||
disable_tp=True)
|
||||
disable_tp=use_data_parallel)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -397,28 +394,36 @@ class DotsSwiGLUFFN(nn.Module):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
params = dict(self.named_parameters())
|
||||
loaded: set[str] = set()
|
||||
for name, w in weights:
|
||||
# Map fc1 -> fc13 (shard 0)
|
||||
if name.startswith("fc1."):
|
||||
tgt = name.replace("fc1.", "fc13.")
|
||||
if tgt in params:
|
||||
params[tgt].weight_loader(params[tgt], w, 0)
|
||||
loaded.add(tgt)
|
||||
stacked_params_mapping = [
|
||||
("fc13", "fc1", 0),
|
||||
("fc13", "fc3", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# Map fc3 -> fc13 (shard 1)
|
||||
if name.startswith("fc3."):
|
||||
tgt = name.replace("fc3.", "fc13.")
|
||||
if tgt in params:
|
||||
params[tgt].weight_loader(params[tgt], w, 1)
|
||||
loaded.add(tgt)
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Pass-through for fc2 and others
|
||||
if name in params:
|
||||
params[name].weight_loader(params[name], w)
|
||||
loaded.add(name)
|
||||
return loaded
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class DotsPatchEmbed(nn.Module):
|
||||
@ -463,25 +468,28 @@ class DotsViTPreprocessor(nn.Module):
|
||||
|
||||
class DotsVisionBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn = DotsVisionAttention(
|
||||
config,
|
||||
self.attn = DotsVisionAttention(config,
|
||||
config.embed_dim,
|
||||
num_heads=config.num_attention_heads,
|
||||
bias=config.use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
use_data_parallel=use_data_parallel)
|
||||
self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
|
||||
self.mlp = DotsSwiGLUFFN(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
prefix=f"{prefix}.mlp",
|
||||
use_data_parallel=use_data_parallel)
|
||||
self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(self,
|
||||
@ -502,7 +510,7 @@ class DotsVisionBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DotsVisionTransformer(PreTrainedModel):
|
||||
class DotsVisionTransformer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -512,8 +520,9 @@ class DotsVisionTransformer(PreTrainedModel):
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
require_post_norm: Optional[bool] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__(config)
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.spatial_merge_size = config.spatial_merge_size
|
||||
|
||||
@ -526,14 +535,15 @@ class DotsVisionTransformer(PreTrainedModel):
|
||||
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||
check_upstream_fa_availability(torch.get_default_dtype()):
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
|
||||
self.out_hidden_size = config.hidden_size
|
||||
# Keep blocks for compatibility with other vision towers
|
||||
num_layers = (config.num_hidden_layers if num_hidden_layers_override
|
||||
is None else num_hidden_layers_override)
|
||||
self.blocks = nn.ModuleList([
|
||||
DotsVisionBlock(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{i}")
|
||||
prefix=f"{prefix}.blocks.{i}",
|
||||
use_data_parallel=use_data_parallel)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
if require_post_norm is None:
|
||||
@ -548,6 +558,7 @@ class DotsVisionTransformer(PreTrainedModel):
|
||||
dim=config.hidden_size,
|
||||
context_dim=config.embed_dim,
|
||||
spatial_merge_size=config.spatial_merge_size,
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
|
||||
@property
|
||||
@ -604,7 +615,11 @@ class DotsVisionTransformer(PreTrainedModel):
|
||||
return max_seqlen, seqlens
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
grid_thw: list[list[int]]) -> torch.Tensor:
|
||||
# Convert grid_thw to tensor (always expecting list format now)
|
||||
grid_thw = torch.tensor(grid_thw,
|
||||
device=hidden_states.device,
|
||||
dtype=torch.long)
|
||||
hidden_states = hidden_states.to(self.dtype)
|
||||
hidden_states = self.patch_embed(hidden_states, grid_thw)
|
||||
|
||||
@ -638,7 +653,8 @@ class DotsVisionTransformer(PreTrainedModel):
|
||||
info=DotsOCRProcessingInfo,
|
||||
dummy_inputs=DotsOCRDummyInputsBuilder,
|
||||
)
|
||||
class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
SupportsLoRA):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
".attn.qkv_proj.": ".attn.qkv.",
|
||||
@ -650,6 +666,21 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
},
|
||||
)
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
".attn.qkv": [".attn.qkv"],
|
||||
"fc13": ["fc1", "fc3"],
|
||||
}
|
||||
supports_encoder_tp_data = True
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
@ -660,19 +691,18 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
self.config: DotsOCRConfig = vllm_config.model_config.hf_config
|
||||
self.quant_config = vllm_config.quant_config
|
||||
self.multimodal_config = vllm_config.model_config.multimodal_config
|
||||
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
if isinstance(self.config.vision_config, dict):
|
||||
vision_config = DotsVisionConfig(**self.config.vision_config)
|
||||
self.config.vision_config = vision_config
|
||||
else:
|
||||
vision_config = self.config.vision_config
|
||||
|
||||
self.vision_tower = DotsVisionTransformer(
|
||||
vision_config,
|
||||
quant_config=self.quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_tower"),
|
||||
)
|
||||
use_data_parallel=self.use_data_parallel)
|
||||
self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=self.config,
|
||||
@ -744,6 +774,15 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
else:
|
||||
pixel_values = image_input["pixel_values"].type(
|
||||
self.vision_tower.dtype)
|
||||
|
||||
if self.use_data_parallel:
|
||||
return run_dp_sharded_mrope_vision_model(
|
||||
self.vision_tower,
|
||||
pixel_values,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d",
|
||||
)
|
||||
else:
|
||||
image_embeds = self.vision_tower(
|
||||
pixel_values, grid_thw)[:, :self.config.hidden_size]
|
||||
|
||||
@ -822,3 +861,13 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
Get the module prefix in multimodal models
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector="vision_tower.merger",
|
||||
tower_model="vision_tower.",
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user