[Model] Add Ovis2.5 PP support (#23405)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Isotr0py 2025-08-23 01:46:34 +08:00 committed by GitHub
parent 22cf679aad
commit 32d2b4064f
5 changed files with 185 additions and 105 deletions
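In practice, this change lets Ovis2.5 be sharded across pipeline stages like the other multimodal models in the test list below. A minimal offline-inference sketch of what the commit enables (illustrative only: the prompt and max_tokens are arbitrary, and pipeline_parallel_size=2 assumes at least two visible GPUs):

from vllm import LLM, SamplingParams

# Hypothetical usage: split Ovis2.5 across two pipeline stages.
llm = LLM(
    model="AIDC-AI/Ovis2.5-2B",
    trust_remote_code=True,
    pipeline_parallel_size=2,
)
outputs = llm.generate("Describe pipeline parallelism in one sentence.",
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)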


@@ -233,6 +233,7 @@ MULTIMODAL_MODELS = {
     "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
     "allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
     "AIDC-AI/Ovis2-1B": PPTestSettings.fast(),
+    "AIDC-AI/Ovis2.5-2B": PPTestSettings.fast(),
     "microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
     "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
     "Qwen/Qwen-VL-Chat": PPTestSettings.fast(),


@@ -11,7 +11,6 @@ from pathlib import PosixPath
 import pytest
 from transformers import (AutoModel, AutoModelForImageTextToText,
                           AutoModelForTextToWaveform, AutoModelForVision2Seq)
-from transformers.utils import is_flash_attn_2_available

 from vllm.platforms import current_platform
 from vllm.utils import identity
@@ -637,10 +636,7 @@ VLM_TEST_SETTINGS = {
         dtype="half",
         num_logprobs=10,
         patch_hf_runner=model_utils.ovis2_5_patch_hf_runner,
-        marks=[pytest.mark.skipif(
-            not is_flash_attn_2_available(),
-            reason="HF model needs `flash_attn` installed"
-        )],
+        hf_model_kwargs={"revision": "refs/pr/5"},
     ),
     "phi3v": VLMTestInfo(
         models=["microsoft/Phi-3.5-vision-instruct"],


@@ -468,9 +468,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
                                        "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}),  # noqa: E501
     "Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B",
-                               trust_remote_code=True,
-                               max_transformers_version="4.53",
-                               transformers_version_reason="HF model is not compatible"),  # noqa: E501
+                               trust_remote_code=True),
     "PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224",  # noqa: E501
                                                          extras={"v2": "google/paligemma2-3b-ft-docci-448"}),  # noqa: E501
     "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",


@@ -30,7 +30,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor

-from .interfaces import MultiModalEmbeddings, SupportsMultiModal
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP

 IMAGE_TOKEN = "<image>"
 VIDEO_TOKEN = "<video>"
@@ -70,6 +70,7 @@ class VisualTokenizer(torch.nn.Module):
         visual_vocab_size: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.config = config
@@ -77,6 +78,7 @@ class VisualTokenizer(torch.nn.Module):
             config=config,
             quant_config=quant_config,
             prefix=f"{prefix}.vit",
+            use_data_parallel=use_data_parallel,
         )
         # reserved tokens for INDICATOR_IDS
         head_dim = visual_vocab_size - len(INDICATOR_IDS)
@@ -93,31 +95,33 @@ class VisualTokenizer(torch.nn.Module):
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         model_type = config.model_type
         if model_type == "siglip2_navit":
-            return Siglip2NavitModel(config=config, )
+            return Siglip2NavitModel(config=config,
+                                     quant_config=quant_config,
+                                     prefix=prefix,
+                                     use_data_parallel=use_data_parallel)
         raise ValueError(
             f"Unsupported visual tokenizer model_type: {model_type}")

     @property
-    def dtype(self):
+    def dtype(self) -> torch.dtype:
         return next(self.head.parameters()).dtype

     @property
-    def device(self):
+    def device(self) -> torch.device:
         return next(self.head.parameters()).device

-    def tokenize(self, logits):
+    def tokenize(self, logits: torch.Tensor) -> torch.Tensor:
         tokens = torch.softmax(logits, dim=-1,
                                dtype=torch.float32).to(logits.dtype)
         return tokens

-    def encode(self, pixel_values, grid_thws):
-        features = self.vit(pixel_values,
-                            grid_thws,
-                            output_hidden_states=True,
-                            return_dict=True)
+    def encode(self, pixel_values: torch.Tensor,
+               grid_thws: torch.Tensor) -> torch.Tensor:
+        features = self.vit(pixel_values, grid_thws)
         # refer to qwen2.5-vl patchmerger
         seq_len, _ = features.shape
         features = features.reshape(seq_len // (self.config.hidden_stride**2),
@@ -125,7 +129,8 @@ class VisualTokenizer(torch.nn.Module):
         return features

-    def forward(self, pixel_values, grid_thws) -> torch.Tensor:
+    def forward(self, pixel_values: torch.Tensor,
+                grid_thws: torch.Tensor) -> torch.Tensor:
         features = self.encode(pixel_values, grid_thws)
         logits = self.head(features)
         tokens = self.tokenize(logits)
@@ -395,7 +400,7 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo]
 @MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor,
                                         info=Ovis2_5ProcessingInfo,
                                         dummy_inputs=Ovis2_5DummyInputsBuilder)
-class Ovis2_5(nn.Module, SupportsMultiModal):
+class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -421,9 +426,8 @@ class Ovis2_5(nn.Module, SupportsMultiModal):
         text_model_type = self.config.get_text_config().model_type
         self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]

-        # TODO(Isotr0py): PP support
-        # self.make_empty_intermediate_tensors = (
-        #     self.language_model.make_empty_intermediate_tensors)
+        self.make_empty_intermediate_tensors = (
+            self.get_language_model().make_empty_intermediate_tensors)

     def _parse_and_validate_visual_input(
             self, is_video,
@@ -567,4 +571,4 @@ class Ovis2_5(nn.Module, SupportsMultiModal):
         return loader.load_weights(weights)

     def get_language_model(self) -> torch.nn.Module:
         return self.llm
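The key PP hook wired up above is make_empty_intermediate_tensors, which Ovis2_5 now delegates to its language model. Conceptually, every pipeline rank except the first needs correctly shaped placeholder buffers to receive hidden states from the previous stage. A framework-agnostic sketch of that idea (names, keys, and shapes below are illustrative, not vLLM's actual IntermediateTensors):

import torch

# Illustrative sketch only: what an "empty intermediate tensors" factory
# provides for pipeline parallelism -- zero-filled buffers with the right
# shape/dtype/device for receiving the previous stage's output.
def make_empty_intermediate_tensors(num_tokens: int, hidden_size: int,
                                    dtype: torch.dtype, device: str):
    return {
        "hidden_states": torch.zeros(num_tokens, hidden_size,
                                     dtype=dtype, device=device),
        "residual": torch.zeros(num_tokens, hidden_size,
                                dtype=dtype, device=device),
    }

buffers = make_empty_intermediate_tensors(8, 1536, torch.float16, "cpu")
print(buffers["hidden_states"].shape)  # torch.Size([8, 1536])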


@@ -3,16 +3,24 @@
 """Implementation of SiglipVisionModel intended to be only used
 within a vision language model."""
-from typing import Optional, Union
+from collections.abc import Iterable
+from typing import Optional

 import torch
 from einops import rearrange, repeat
 from torch import nn
 from torch.nn import functional as F
-from transformers.activations import ACT2FN
+from transformers import Siglip2VisionConfig
 from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_outputs import BaseModelOutputWithNoAttention

+from vllm.config import QuantizationConfig
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               LinearBase, QKVParallelLinear,
+                                               ReplicatedLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.platforms import _Backend

 from .vision import get_vit_attn_backend
@@ -48,10 +56,11 @@ class Siglip2VisionEmbeddings(nn.Module):
         # siglip2 naflex
         if self.num_patches > 0:
-            self.patch_embedding = nn.Linear(
-                in_features=config.num_channels * self.patch_size *
-                self.patch_size,
-                out_features=self.embed_dim,
+            self.patch_embedding = ReplicatedLinear(
+                input_size=config.num_channels * self.patch_size *
+                self.patch_size,
+                output_size=self.embed_dim,
+                return_bias=False,
             )

         if self.preserve_original_pe:
             self.position_embedding_size = int(self.num_patches**0.5)
@@ -89,7 +98,7 @@ class Siglip2VisionEmbeddings(nn.Module):
         # Apply patch embeddings to already patchified pixel values
         target_dtype = self.patch_embedding.weight.dtype
-        if isinstance(self.patch_embedding, nn.Linear):
+        if isinstance(self.patch_embedding, LinearBase):
             patch_embeds = self.patch_embedding(
                 pixel_values.to(dtype=target_dtype))
         elif isinstance(self.patch_embedding, nn.Conv2d):
@@ -184,7 +193,13 @@
 class Siglip2Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

-    def __init__(self, config):
+    def __init__(
+        self,
+        config: Siglip2VisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -199,11 +214,25 @@ class Siglip2Attention(nn.Module):
         self.dropout = config.attention_dropout
         self.is_causal = False

-        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
-        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
-        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
-        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        # TODO(Isotr0py): Enable data parallel after we support
+        # disabling TP on parallel linear layer
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=self.embed_dim,
+            head_size=self.head_dim,
+            total_num_heads=self.num_heads,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.out_proj = RowParallelLinear(
+            input_size=self.embed_dim,
+            output_size=self.embed_dim,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+        )
+        self.tp_size = (1 if use_data_parallel else
+                        get_tensor_model_parallel_world_size())
+        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

         self.use_rope = config.use_rope

         # Detect attention implementation.
@@ -228,13 +257,15 @@ class Siglip2Attention(nn.Module):
         seq_length, embed_dim = hidden_states.shape

-        queries = self.q_proj(hidden_states)
-        keys = self.k_proj(hidden_states)
-        values = self.v_proj(hidden_states)
+        qkv_states, _ = self.qkv_proj(hidden_states)
+        queries, keys, values = qkv_states.chunk(3, dim=-1)

-        queries = queries.view(seq_length, self.num_heads, self.head_dim)
-        keys = keys.view(seq_length, self.num_heads, self.head_dim)
-        values = values.view(seq_length, self.num_heads, self.head_dim)
+        queries = queries.view(seq_length, self.num_heads_per_partition,
+                               self.head_dim)
+        keys = keys.view(seq_length, self.num_heads_per_partition,
+                         self.head_dim)
+        values = values.view(seq_length, self.num_heads_per_partition,
+                             self.head_dim)

         if self.use_rope:
             cos, sin = position_embeddings
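Because the fused qkv_proj output concatenates the per-rank query, key, and value projections with equal widths, chunk(3, dim=-1) recovers the three pieces, and each is then viewed into per-partition heads. A standalone tensor-shape sketch of that split (the sizes below are made up for illustration):

import torch

# Illustrative only: mimic splitting a fused [q|k|v] projection output.
seq_length, num_heads_per_partition, head_dim = 10, 4, 16
embed_dim = num_heads_per_partition * head_dim
qkv_states = torch.randn(seq_length, 3 * embed_dim)

queries, keys, values = qkv_states.chunk(3, dim=-1)
queries = queries.view(seq_length, num_heads_per_partition, head_dim)
print(queries.shape)  # torch.Size([10, 4, 16])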
@@ -276,41 +307,72 @@ class Siglip2Attention(nn.Module):
                     v_i,
                     dropout_p=0.0)
                 # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
-                output_i = output_i.transpose(1, 2).reshape(-1, self.embed_dim)
+                output_i = output_i.transpose(1, 2).reshape(
+                    end_idx - start_idx, -1)
                 outputs.append(output_i)
             attn_output = torch.cat(outputs, dim=0)
-        attn_output = self.out_proj(attn_output)
+        attn_output, _ = self.out_proj(attn_output)
         return attn_output


 class Siglip2MLP(nn.Module):

-    def __init__(self, config):
+    def __init__(
+        self,
+        config: Siglip2VisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
         super().__init__()
         self.config = config
-        self.activation_fn = ACT2FN[config.hidden_act]
-        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.activation_fn = get_act_fn(config.hidden_act)
+        # TODO(Isotr0py): Enable data parallel after we support
+        # disabling TP on parallel linear layer
+        self.fc1 = ColumnParallelLinear(
+            config.hidden_size,
+            config.intermediate_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
+        )
+        self.fc2 = RowParallelLinear(
+            config.intermediate_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
+        )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.fc1(hidden_states)
+        hidden_states, _ = self.fc1(hidden_states)
         hidden_states = self.activation_fn(hidden_states)
-        hidden_states = self.fc2(hidden_states)
+        hidden_states, _ = self.fc2(hidden_states)
         return hidden_states


 class Siglip2EncoderLayer(nn.Module):

-    def __init__(self, config: PretrainedConfig):
+    def __init__(
+        self,
+        config: Siglip2VisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
         super().__init__()
         self.embed_dim = config.hidden_size
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
-        self.self_attn = Siglip2Attention(config)
+        self.self_attn = Siglip2Attention(config,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.self_attn",
+                                          use_data_parallel=use_data_parallel)
         self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
-        self.mlp = Siglip2MLP(config)
+        self.mlp = Siglip2MLP(config,
+                              quant_config=quant_config,
+                              prefix=f"{prefix}.mlp",
+                              use_data_parallel=use_data_parallel)

     def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                 position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]:
@@ -347,14 +409,22 @@ class Siglip2Encoder(nn.Module):
         config: PretrainedConfig
     """

-    def __init__(self, config: PretrainedConfig):
+    def __init__(
+        self,
+        config: Siglip2VisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
         super().__init__()
         self.config = config
         self.layers = nn.ModuleList([
-            Siglip2EncoderLayer(config)
-            for _ in range(config.num_hidden_layers)
+            Siglip2EncoderLayer(config,
+                                quant_config=quant_config,
+                                prefix=f"{prefix}.layers.{idx}",
+                                use_data_parallel=use_data_parallel)
+            for idx in range(config.num_hidden_layers)
         ])
-        self.gradient_checkpointing = False
         self.rotary_pos_emb = VisionRotaryEmbedding(
             config.hidden_size // config.num_attention_heads // 2)
@@ -445,13 +515,11 @@ class Siglip2Encoder(nn.Module):
         return window_index, cu_window_seqlens

-    # Ignore copy
     def forward(
         self,
-        inputs_embeds,
+        inputs_embeds: torch.Tensor,
         grid_thws: torch.Tensor,
-        output_hidden_states: bool = False,
-    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, ...]]]:
+    ) -> torch.Tensor:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape
@@ -506,7 +574,6 @@ class Siglip2Encoder(nn.Module):
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
         reverse_indices = torch.argsort(window_index)

-        encoder_states = () if output_hidden_states else None
         hidden_states = inputs_embeds

         for index, block in enumerate(self.layers):
@@ -517,45 +584,40 @@ class Siglip2Encoder(nn.Module):
                 cu_seqlens_tmp = cu_window_seqlens
             hidden_states = block(hidden_states, cu_seqlens_tmp,
                                   position_embeddings)
-            if output_hidden_states:
-                hidden_states_ = hidden_states.reshape(
-                    seq_len // self.spatial_merge_unit,
-                    self.spatial_merge_unit, -1)
-                encoder_states += (hidden_states_[reverse_indices, :].reshape(
-                    seq_len, -1), )
-        # tokens = self.post_trunk_norm(tokens)

         hidden_states = hidden_states.reshape(
             seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
         hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)
-        return hidden_states, encoder_states
+        return hidden_states


 class Siglip2VisionTransformer(nn.Module):

-    def __init__(self, config: PretrainedConfig):
+    def __init__(
+        self,
+        config: Siglip2VisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
         super().__init__()
         self.config = config
         embed_dim = config.hidden_size
         self.embeddings = Siglip2VisionEmbeddings(config)
-        self.encoder = Siglip2Encoder(config)
+        self.encoder = Siglip2Encoder(config,
+                                      quant_config=quant_config,
+                                      prefix=f"{prefix}.encoder",
+                                      use_data_parallel=use_data_parallel)
         self.post_layernorm = nn.LayerNorm(embed_dim,
                                            eps=config.layer_norm_eps)
-        self._use_flash_attention_2 = \
-            (config._attn_implementation == "flash_attention_2")

     def forward(
         self,
         pixel_values: torch.FloatTensor,
         grid_thws: torch.LongTensor,
-        output_hidden_states: Optional[bool] = True,
-        return_dict: Optional[bool] = True,
-    ) -> Union[
-        tuple[torch.Tensor],
-        tuple[torch.Tensor, tuple[torch.Tensor, ...]],
-        BaseModelOutputWithNoAttention,
-    ]:
+    ) -> torch.Tensor:
         r"""
         spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
             Tensor containing the spatial dimensions (height, width)
@@ -563,45 +625,64 @@ class Siglip2VisionTransformer(nn.Module):
         """
         hidden_states = self.embeddings(pixel_values, grid_thws)
-        last_hidden_state, hidden_states = self.encoder(
-            hidden_states, grid_thws, output_hidden_states)
+        last_hidden_state = self.encoder(hidden_states, grid_thws)
         last_hidden_state = self.post_layernorm(last_hidden_state)

-        if not return_dict:
-            output = (last_hidden_state, )
-            output += (hidden_states, ) if output_hidden_states else ()
-            return output
         return last_hidden_state


 class Siglip2NavitModel(torch.nn.Module):

-    def __init__(self, config: PretrainedConfig):
+    def __init__(
+        self,
+        config: Siglip2VisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
         super().__init__()
-        self.vision_model = Siglip2VisionTransformer(config)
+        self.vision_model = Siglip2VisionTransformer(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.vision_model",
+            use_data_parallel=use_data_parallel)

     def forward(
         self,
         pixel_values: torch.FloatTensor,
         grid_thws: torch.LongTensor,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[
-        tuple[torch.Tensor],
-        tuple[torch.Tensor, tuple[torch.Tensor, ...]],
-        BaseModelOutputWithNoAttention,
-    ]:
-        if output_hidden_states is None:
-            output_hidden_states = self.config.output_hidden_states
-        if return_dict is None:
-            return_dict = self.config.use_return_dict
+    ) -> torch.Tensor:
         return self.vision_model(
             pixel_values=pixel_values,
             grid_thws=grid_thws,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
         )
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
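The new load_weights is needed because the checkpoint still stores separate q_proj/k_proj/v_proj tensors while the module now has a single fused qkv_proj; the stacked_params_mapping rewrites checkpoint names and tags each tensor with the shard it fills. A small standalone illustration of that name rewriting (the weight name below is hypothetical):

# Illustrative only: how stacked_params_mapping routes a checkpoint tensor
# named "...k_proj.weight" into the "k" shard of the fused qkv_proj parameter.
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]
name = "vision_model.encoder.layers.0.self_attn.k_proj.weight"  # hypothetical
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        fused_name = name.replace(weight_name, param_name)
        print(fused_name, shard_id)
        # vision_model.encoder.layers.0.self_attn.qkv_proj.weight k
        break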