[Model] Improve DotsOCRForCausalLM (#25466)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent b95429c920
commit 1d6f767dc4
@@ -7,11 +7,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import LayerNorm
-from transformers.modeling_utils import PreTrainedModel
 from transformers.models.qwen2_vl import Qwen2VLProcessor
 
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
+from vllm.distributed import utils as dist_utils
+from vllm.distributed.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -19,10 +21,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (MultiModalEmbeddings,
+                                                    SupportsLoRA,
                                                     SupportsMultiModal,
                                                     SupportsPP)
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
+from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention
 from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder,
                                                  Qwen2VLMultiModalProcessor,
                                                  Qwen2VLProcessingInfo)
@@ -38,6 +44,8 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
                                                      DotsVisionConfig)
 
+from .vision import run_dp_sharded_mrope_vision_model
+
 IMAGE_TOKEN = "<|imgpad|>"
 
 
@@ -181,6 +189,8 @@ class PatchMerger(nn.Module):
         context_dim: int,
         spatial_merge_size: int = 2,
         pre_norm="layernorm",
+        prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -189,21 +199,21 @@ class PatchMerger(nn.Module):
             self.ln_q = LayerNorm(context_dim, eps=1e-6)
         elif self.pre_norm == "rmsnorm":
             self.ln_q = RMSNorm(context_dim, eps=1e-6)
-        else:
-            print("no norm in patch merger")
 
         self.mlp = nn.Sequential(
             ColumnParallelLinear(self.hidden_size,
                                  self.hidden_size,
                                  bias=True,
                                  return_bias=False,
-                                 disable_tp=True),
+                                 prefix=f"{prefix}.0",
+                                 disable_tp=use_data_parallel),
             nn.GELU(),
             RowParallelLinear(self.hidden_size,
                               dim,
                               bias=True,
                               return_bias=False,
-                              disable_tp=True),
+                              prefix=f"{prefix}.2",
+                              disable_tp=use_data_parallel),
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -223,38 +233,36 @@ class DotsVisionAttention(nn.Module):
                  bias: bool = True,
                  *,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "") -> None:
+                 prefix: str = "",
+                 use_data_parallel: bool = False) -> None:
         super().__init__()
-        from vllm.distributed import (parallel_state,
-                                      tensor_model_parallel_all_gather)
-        from vllm.distributed import utils as dist_utils
 
         self.embed_dim = dim
-        self.num_heads = num_heads
-        self.head_dim = dim // num_heads
-        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
-        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
-        self.num_heads_per_partition = dist_utils.divide(
+        self.tp_size = (1 if use_data_parallel else
+                        get_tensor_model_parallel_world_size())
+        self.tp_rank = (0 if use_data_parallel else
+                        get_tensor_model_parallel_rank())
+        self.hidden_size_per_attention_head = dist_utils.divide(dim, num_heads)
+        self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)
 
         # qkv/proj follow Qwen2-VL style; bias controlled by arg
-        self.qkv = QKVParallelLinear(hidden_size=dim,
-                                     head_size=dim // num_heads,
-                                     total_num_heads=num_heads,
-                                     bias=bias,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.qkv")
+        self.qkv = QKVParallelLinear(
+            hidden_size=dim,
+            head_size=self.hidden_size_per_attention_head,
+            total_num_heads=num_heads,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
+            disable_tp=use_data_parallel)
         self.proj = RowParallelLinear(input_size=dim,
                                       output_size=dim,
                                       bias=bias,
                                       quant_config=quant_config,
-                                      prefix=f"{prefix}.proj")
-        self._all_gather = tensor_model_parallel_all_gather
-        self._split_last = dist_utils.split_tensor_along_last_dim
+                                      prefix=f"{prefix}.proj",
+                                      disable_tp=use_data_parallel)
 
         # Select attention backend
-        self.attn_backend = get_vit_attn_backend(self.head_dim,
-                                                 torch.get_default_dtype())
+        self.attn_backend = get_vit_attn_backend(
+            self.hidden_size_per_attention_head, torch.get_default_dtype())
         self.use_upstream_fa = False
         if self.attn_backend != _Backend.FLASH_ATTN and \
             check_upstream_fa_availability(torch.get_default_dtype()):
@@ -270,19 +278,6 @@ class DotsVisionAttention(nn.Module):
             _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
         }
 
-    def _split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
-        # qkv: [S, B, 3*dim]
-        seq_len, bs, _ = qkv.shape
-        if self.tp_size > 1:
-            qkv = self._all_gather(qkv)
-        q, k, v = qkv.chunk(3, dim=2)
-        if self.tp_size > 1:
-            q = self._split_last(q, num_partitions=self.tp_size)[self.tp_rank]
-            k = self._split_last(k, num_partitions=self.tp_size)[self.tp_rank]
-            v = self._split_last(v, num_partitions=self.tp_size)[self.tp_rank]
-        new_shape = (seq_len, bs, self.num_heads_per_partition, self.head_dim)
-        return (q.view(*new_shape), k.view(*new_shape), v.view(*new_shape))
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -295,7 +290,7 @@ class DotsVisionAttention(nn.Module):
         # [S, C] -> [S, B=1, C]
         x = hidden_states.unsqueeze(1)
         x, _ = self.qkv(x)
-        q, k, v = self._split_qkv(x)
+        q, k, v = Qwen2_5_VisionAttention.split_qkv(self, x)
         bs = q.shape[1]
         # [S,B,H,D] -> [B,S,H,D]
         q = q.permute(1, 0, 2, 3).contiguous()
@@ -327,8 +322,9 @@ class DotsVisionAttention(nn.Module):
                                            max_seqlen_k=max_seqlen,
                                            dropout_p=0.0,
                                            causal=False)
-            context_layer = output.view(bs, -1, self.num_heads_per_partition,
-                                        self.head_dim)
+            context_layer = output.view(bs, -1,
+                                        self.num_attention_heads_per_partition,
+                                        self.hidden_size_per_attention_head)
         elif self.attn_backend == _Backend.TORCH_SDPA:
             outputs = []
             for i in range(1, len(cu_seqlens)):
@@ -368,7 +364,8 @@ class DotsSwiGLUFFN(nn.Module):
                  config,
                  *,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         hidden_features = config.intermediate_size
         in_features = config.embed_dim
@@ -380,13 +377,13 @@ class DotsSwiGLUFFN(nn.Module):
                                                bias=bias,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.fc13",
-                                               disable_tp=True)
+                                               disable_tp=use_data_parallel)
         self.fc2 = RowParallelLinear(hidden_features,
                                      in_features,
                                      bias=bias,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.fc2",
-                                     disable_tp=True)
+                                     disable_tp=use_data_parallel)
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -397,28 +394,36 @@ class DotsSwiGLUFFN(nn.Module):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params = dict(self.named_parameters())
-        loaded: set[str] = set()
-        for name, w in weights:
-            # Map fc1 -> fc13 (shard 0)
-            if name.startswith("fc1."):
-                tgt = name.replace("fc1.", "fc13.")
-                if tgt in params:
-                    params[tgt].weight_loader(params[tgt], w, 0)
-                    loaded.add(tgt)
-                continue
-            # Map fc3 -> fc13 (shard 1)
-            if name.startswith("fc3."):
-                tgt = name.replace("fc3.", "fc13.")
-                if tgt in params:
-                    params[tgt].weight_loader(params[tgt], w, 1)
-                    loaded.add(tgt)
-                continue
-            # Pass-through for fc2 and others
-            if name in params:
-                params[name].weight_loader(params[name], w)
-                loaded.add(name)
-        return loaded
+        stacked_params_mapping = [
+            ("fc13", "fc1", 0),
+            ("fc13", "fc3", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 class DotsPatchEmbed(nn.Module):
@@ -463,25 +468,28 @@ class DotsViTPreprocessor(nn.Module):
 
 class DotsVisionBlock(nn.Module):
 
-    def __init__(self,
-                 config,
-                 *,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        config,
+        *,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
         super().__init__()
 
-        self.attn = DotsVisionAttention(
-            config,
-            config.embed_dim,
-            num_heads=config.num_attention_heads,
-            bias=config.use_bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.attn",
-        )
+        self.attn = DotsVisionAttention(config,
+                                        config.embed_dim,
+                                        num_heads=config.num_attention_heads,
+                                        bias=config.use_bias,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.attn",
+                                        use_data_parallel=use_data_parallel)
         self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
         self.mlp = DotsSwiGLUFFN(config,
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.mlp")
+                                 prefix=f"{prefix}.mlp",
+                                 use_data_parallel=use_data_parallel)
         self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
 
     def forward(self,
@@ -502,7 +510,7 @@ class DotsVisionBlock(nn.Module):
         return hidden_states
 
 
-class DotsVisionTransformer(PreTrainedModel):
+class DotsVisionTransformer(nn.Module):
 
     def __init__(
         self,
@@ -512,8 +520,9 @@ class DotsVisionTransformer(PreTrainedModel):
         num_hidden_layers_override: Optional[int] = None,
         require_post_norm: Optional[bool] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
-        super().__init__(config)
+        super().__init__()
         self.config = config
         self.spatial_merge_size = config.spatial_merge_size
 
@@ -526,14 +535,15 @@ class DotsVisionTransformer(PreTrainedModel):
         if self.attn_backend != _Backend.FLASH_ATTN and \
             check_upstream_fa_availability(torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
+        self.out_hidden_size = config.hidden_size
         # Keep blocks for compatibility with other vision towers
         num_layers = (config.num_hidden_layers if num_hidden_layers_override
                       is None else num_hidden_layers_override)
         self.blocks = nn.ModuleList([
             DotsVisionBlock(config,
                             quant_config=quant_config,
-                            prefix=f"{prefix}.blocks.{i}")
+                            prefix=f"{prefix}.blocks.{i}",
+                            use_data_parallel=use_data_parallel)
             for i in range(num_layers)
         ])
         if require_post_norm is None:
@@ -548,6 +558,7 @@ class DotsVisionTransformer(PreTrainedModel):
             dim=config.hidden_size,
             context_dim=config.embed_dim,
             spatial_merge_size=config.spatial_merge_size,
+            use_data_parallel=use_data_parallel,
         )
 
     @property
@@ -604,7 +615,11 @@ class DotsVisionTransformer(PreTrainedModel):
         return max_seqlen, seqlens
 
     def forward(self, hidden_states: torch.Tensor,
-                grid_thw: torch.Tensor) -> torch.Tensor:
+                grid_thw: list[list[int]]) -> torch.Tensor:
+        # Convert grid_thw to tensor (always expecting list format now)
+        grid_thw = torch.tensor(grid_thw,
+                                device=hidden_states.device,
+                                dtype=torch.long)
         hidden_states = hidden_states.to(self.dtype)
         hidden_states = self.patch_embed(hidden_states, grid_thw)
 
@@ -638,7 +653,8 @@ class DotsVisionTransformer(PreTrainedModel):
     info=DotsOCRProcessingInfo,
     dummy_inputs=DotsOCRDummyInputsBuilder,
 )
-class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
+                         SupportsLoRA):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             ".attn.qkv_proj.": ".attn.qkv.",
@@ -650,6 +666,21 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         },
     )
 
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+        ".attn.qkv": [".attn.qkv"],
+        "fc13": ["fc1", "fc3"],
+    }
+    supports_encoder_tp_data = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
@@ -660,19 +691,18 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
         self.config: DotsOCRConfig = vllm_config.model_config.hf_config
         self.quant_config = vllm_config.quant_config
-        self.multimodal_config = vllm_config.model_config.multimodal_config
+        multimodal_config = vllm_config.model_config.multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
         if isinstance(self.config.vision_config, dict):
             vision_config = DotsVisionConfig(**self.config.vision_config)
             self.config.vision_config = vision_config
         else:
             vision_config = self.config.vision_config
 
         self.vision_tower = DotsVisionTransformer(
             vision_config,
             quant_config=self.quant_config,
             prefix=maybe_prefix(prefix, "vision_tower"),
-        )
+            use_data_parallel=self.use_data_parallel)
         self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=self.config,
@@ -744,8 +774,17 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         else:
             pixel_values = image_input["pixel_values"].type(
                 self.vision_tower.dtype)
-            image_embeds = self.vision_tower(
-                pixel_values, grid_thw)[:, :self.config.hidden_size]
+
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(
+                    self.vision_tower,
+                    pixel_values,
+                    grid_thw_list,
+                    rope_type="rope_3d",
+                )
+            else:
+                image_embeds = self.vision_tower(
+                    pixel_values, grid_thw)[:, :self.config.hidden_size]
 
         # Split concatenated embeddings for each image item.
         merge_size = self.vision_tower.spatial_merge_size
@@ -822,3 +861,13 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
                          torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="vision_tower.merger",
+            tower_model="vision_tower.",
+        )
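In short, the commit replaces the hand-rolled QKV split with Qwen2_5_VisionAttention.split_qkv, adopts the stacked-parameter weight loader, marks the model as SupportsLoRA, and threads a use_data_parallel flag through the vision tower, driven by multimodal_config.mm_encoder_tp_mode == "data". Below is a minimal usage sketch, not part of this commit: it assumes mm_encoder_tp_mode is exposed as an engine argument that maps onto that config field, and the checkpoint id is illustrative.

# Hypothetical sketch: enable the data-parallel vision encoder path that
# DotsOCRForCausalLM gains in this commit. Argument and model names outside
# the diff are assumptions, not the documented interface.
from vllm import LLM

llm = LLM(
    model="rednote-hilab/dots.ocr",  # illustrative checkpoint id
    tensor_parallel_size=2,
    # Assumed engine argument; the model reads
    # multimodal_config.mm_encoder_tp_mode == "data" to choose the DP ViT path.
    mm_encoder_tp_mode="data",
    # The commit also adds SupportsLoRA, so adapters could be enabled as well.
    enable_lora=True,
)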