mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:04:53 +08:00
[Model] Add Ovis2.5 PP support (#23405)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
22cf679aad
commit
32d2b4064f
@ -233,6 +233,7 @@ MULTIMODAL_MODELS = {
|
|||||||
"openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
|
"openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
|
||||||
"allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
|
"allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
|
||||||
"AIDC-AI/Ovis2-1B": PPTestSettings.fast(),
|
"AIDC-AI/Ovis2-1B": PPTestSettings.fast(),
|
||||||
|
"AIDC-AI/Ovis2.5-2B": PPTestSettings.fast(),
|
||||||
"microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
|
"microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
|
||||||
"mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
|
"mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
|
||||||
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
|
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
|
||||||
|
|||||||
@ -11,7 +11,6 @@ from pathlib import PosixPath
|
|||||||
import pytest
|
import pytest
|
||||||
from transformers import (AutoModel, AutoModelForImageTextToText,
|
from transformers import (AutoModel, AutoModelForImageTextToText,
|
||||||
AutoModelForTextToWaveform, AutoModelForVision2Seq)
|
AutoModelForTextToWaveform, AutoModelForVision2Seq)
|
||||||
from transformers.utils import is_flash_attn_2_available
|
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import identity
|
from vllm.utils import identity
|
||||||
@ -637,10 +636,7 @@ VLM_TEST_SETTINGS = {
|
|||||||
dtype="half",
|
dtype="half",
|
||||||
num_logprobs=10,
|
num_logprobs=10,
|
||||||
patch_hf_runner=model_utils.ovis2_5_patch_hf_runner,
|
patch_hf_runner=model_utils.ovis2_5_patch_hf_runner,
|
||||||
marks=[pytest.mark.skipif(
|
hf_model_kwargs={"revision": "refs/pr/5"},
|
||||||
not is_flash_attn_2_available(),
|
|
||||||
reason="HF model needs `flash_attn` installed"
|
|
||||||
)],
|
|
||||||
),
|
),
|
||||||
"phi3v": VLMTestInfo(
|
"phi3v": VLMTestInfo(
|
||||||
models=["microsoft/Phi-3.5-vision-instruct"],
|
models=["microsoft/Phi-3.5-vision-instruct"],
|
||||||
|
|||||||
@ -468,9 +468,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
|||||||
extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
|
extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
|
||||||
"1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
|
"1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
|
||||||
"Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B",
|
"Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True),
|
||||||
max_transformers_version="4.53",
|
|
||||||
transformers_version_reason="HF model is not compatible"), # noqa: E501
|
|
||||||
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501
|
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501
|
||||||
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
|
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
|
||||||
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
|
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
|
||||||
|
|||||||
@ -30,7 +30,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
|
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
|
||||||
|
|
||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
|
|
||||||
IMAGE_TOKEN = "<image>"
|
IMAGE_TOKEN = "<image>"
|
||||||
VIDEO_TOKEN = "<video>"
|
VIDEO_TOKEN = "<video>"
|
||||||
@ -70,6 +70,7 @@ class VisualTokenizer(torch.nn.Module):
|
|||||||
visual_vocab_size: int,
|
visual_vocab_size: int,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
use_data_parallel: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -77,6 +78,7 @@ class VisualTokenizer(torch.nn.Module):
|
|||||||
config=config,
|
config=config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.vit",
|
prefix=f"{prefix}.vit",
|
||||||
|
use_data_parallel=use_data_parallel,
|
||||||
)
|
)
|
||||||
# reserved tokens for INDICATOR_IDS
|
# reserved tokens for INDICATOR_IDS
|
||||||
head_dim = visual_vocab_size - len(INDICATOR_IDS)
|
head_dim = visual_vocab_size - len(INDICATOR_IDS)
|
||||||
@ -93,31 +95,33 @@ class VisualTokenizer(torch.nn.Module):
|
|||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
use_data_parallel: bool = False,
|
||||||
):
|
):
|
||||||
model_type = config.model_type
|
model_type = config.model_type
|
||||||
if model_type == "siglip2_navit":
|
if model_type == "siglip2_navit":
|
||||||
return Siglip2NavitModel(config=config, )
|
return Siglip2NavitModel(config=config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=prefix,
|
||||||
|
use_data_parallel=use_data_parallel)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported visual tokenizer model_type: {model_type}")
|
f"Unsupported visual tokenizer model_type: {model_type}")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self):
|
def dtype(self) -> torch.dtype:
|
||||||
return next(self.head.parameters()).dtype
|
return next(self.head.parameters()).dtype
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self):
|
def device(self) -> torch.device:
|
||||||
return next(self.head.parameters()).device
|
return next(self.head.parameters()).device
|
||||||
|
|
||||||
def tokenize(self, logits):
|
def tokenize(self, logits: torch.Tensor) -> torch.Tensor:
|
||||||
tokens = torch.softmax(logits, dim=-1,
|
tokens = torch.softmax(logits, dim=-1,
|
||||||
dtype=torch.float32).to(logits.dtype)
|
dtype=torch.float32).to(logits.dtype)
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
def encode(self, pixel_values, grid_thws):
|
def encode(self, pixel_values: torch.Tensor,
|
||||||
features = self.vit(pixel_values,
|
grid_thws: torch.Tensor) -> torch.Tensor:
|
||||||
grid_thws,
|
features = self.vit(pixel_values, grid_thws)
|
||||||
output_hidden_states=True,
|
|
||||||
return_dict=True)
|
|
||||||
# refer to qwen2.5-vl patchmerger
|
# refer to qwen2.5-vl patchmerger
|
||||||
seq_len, _ = features.shape
|
seq_len, _ = features.shape
|
||||||
features = features.reshape(seq_len // (self.config.hidden_stride**2),
|
features = features.reshape(seq_len // (self.config.hidden_stride**2),
|
||||||
@ -125,7 +129,8 @@ class VisualTokenizer(torch.nn.Module):
|
|||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
def forward(self, pixel_values, grid_thws) -> torch.Tensor:
|
def forward(self, pixel_values: torch.Tensor,
|
||||||
|
grid_thws: torch.Tensor) -> torch.Tensor:
|
||||||
features = self.encode(pixel_values, grid_thws)
|
features = self.encode(pixel_values, grid_thws)
|
||||||
logits = self.head(features)
|
logits = self.head(features)
|
||||||
tokens = self.tokenize(logits)
|
tokens = self.tokenize(logits)
|
||||||
@ -395,7 +400,7 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo]
|
|||||||
@MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor,
|
@MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor,
|
||||||
info=Ovis2_5ProcessingInfo,
|
info=Ovis2_5ProcessingInfo,
|
||||||
dummy_inputs=Ovis2_5DummyInputsBuilder)
|
dummy_inputs=Ovis2_5DummyInputsBuilder)
|
||||||
class Ovis2_5(nn.Module, SupportsMultiModal):
|
class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -421,9 +426,8 @@ class Ovis2_5(nn.Module, SupportsMultiModal):
|
|||||||
text_model_type = self.config.get_text_config().model_type
|
text_model_type = self.config.get_text_config().model_type
|
||||||
self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
|
self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
|
||||||
|
|
||||||
# TODO(Isotr0py): PP support
|
self.make_empty_intermediate_tensors = (
|
||||||
# self.make_empty_intermediate_tensors = (
|
self.get_language_model().make_empty_intermediate_tensors)
|
||||||
# self.language_model.make_empty_intermediate_tensors)
|
|
||||||
|
|
||||||
def _parse_and_validate_visual_input(
|
def _parse_and_validate_visual_input(
|
||||||
self, is_video,
|
self, is_video,
|
||||||
@ -567,4 +571,4 @@ class Ovis2_5(nn.Module, SupportsMultiModal):
|
|||||||
return loader.load_weights(weights)
|
return loader.load_weights(weights)
|
||||||
|
|
||||||
def get_language_model(self) -> torch.nn.Module:
|
def get_language_model(self) -> torch.nn.Module:
|
||||||
return self.llm
|
return self.llm
|
||||||
|
|||||||
@ -3,16 +3,24 @@
|
|||||||
"""Implementation of SiglipVisionModel intended to be only used
|
"""Implementation of SiglipVisionModel intended to be only used
|
||||||
within a vision language model."""
|
within a vision language model."""
|
||||||
|
|
||||||
from typing import Optional, Union
|
from collections.abc import Iterable
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from transformers.activations import ACT2FN
|
from transformers import Siglip2VisionConfig
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithNoAttention
|
|
||||||
|
|
||||||
|
from vllm.config import QuantizationConfig
|
||||||
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
|
LinearBase, QKVParallelLinear,
|
||||||
|
ReplicatedLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.platforms import _Backend
|
from vllm.platforms import _Backend
|
||||||
|
|
||||||
from .vision import get_vit_attn_backend
|
from .vision import get_vit_attn_backend
|
||||||
@ -48,10 +56,11 @@ class Siglip2VisionEmbeddings(nn.Module):
|
|||||||
|
|
||||||
# siglip2 naflex
|
# siglip2 naflex
|
||||||
if self.num_patches > 0:
|
if self.num_patches > 0:
|
||||||
self.patch_embedding = nn.Linear(
|
self.patch_embedding = ReplicatedLinear(
|
||||||
in_features=config.num_channels * self.patch_size *
|
input_size=config.num_channels * self.patch_size *
|
||||||
self.patch_size,
|
self.patch_size,
|
||||||
out_features=self.embed_dim,
|
output_size=self.embed_dim,
|
||||||
|
return_bias=False,
|
||||||
)
|
)
|
||||||
if self.preserve_original_pe:
|
if self.preserve_original_pe:
|
||||||
self.position_embedding_size = int(self.num_patches**0.5)
|
self.position_embedding_size = int(self.num_patches**0.5)
|
||||||
@ -89,7 +98,7 @@ class Siglip2VisionEmbeddings(nn.Module):
|
|||||||
|
|
||||||
# Apply patch embeddings to already patchified pixel values
|
# Apply patch embeddings to already patchified pixel values
|
||||||
target_dtype = self.patch_embedding.weight.dtype
|
target_dtype = self.patch_embedding.weight.dtype
|
||||||
if isinstance(self.patch_embedding, nn.Linear):
|
if isinstance(self.patch_embedding, LinearBase):
|
||||||
patch_embeds = self.patch_embedding(
|
patch_embeds = self.patch_embedding(
|
||||||
pixel_values.to(dtype=target_dtype))
|
pixel_values.to(dtype=target_dtype))
|
||||||
elif isinstance(self.patch_embedding, nn.Conv2d):
|
elif isinstance(self.patch_embedding, nn.Conv2d):
|
||||||
@ -184,7 +193,13 @@ def apply_rotary_pos_emb(
|
|||||||
class Siglip2Attention(nn.Module):
|
class Siglip2Attention(nn.Module):
|
||||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Siglip2VisionConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
use_data_parallel: bool = False,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
@ -199,11 +214,25 @@ class Siglip2Attention(nn.Module):
|
|||||||
self.dropout = config.attention_dropout
|
self.dropout = config.attention_dropout
|
||||||
self.is_causal = False
|
self.is_causal = False
|
||||||
|
|
||||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
# TODO(Isotr0py): Enable data parallel after we support
|
||||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
# disabling TP on parallel linear layer
|
||||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
self.qkv_proj = QKVParallelLinear(
|
||||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
hidden_size=self.embed_dim,
|
||||||
|
head_size=self.head_dim,
|
||||||
|
total_num_heads=self.num_heads,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.qkv_proj",
|
||||||
|
)
|
||||||
|
self.out_proj = RowParallelLinear(
|
||||||
|
input_size=self.embed_dim,
|
||||||
|
output_size=self.embed_dim,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.out_proj",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.tp_size = (1 if use_data_parallel else
|
||||||
|
get_tensor_model_parallel_world_size())
|
||||||
|
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||||
self.use_rope = config.use_rope
|
self.use_rope = config.use_rope
|
||||||
|
|
||||||
# Detect attention implementation.
|
# Detect attention implementation.
|
||||||
@ -228,13 +257,15 @@ class Siglip2Attention(nn.Module):
|
|||||||
|
|
||||||
seq_length, embed_dim = hidden_states.shape
|
seq_length, embed_dim = hidden_states.shape
|
||||||
|
|
||||||
queries = self.q_proj(hidden_states)
|
qkv_states, _ = self.qkv_proj(hidden_states)
|
||||||
keys = self.k_proj(hidden_states)
|
queries, keys, values = qkv_states.chunk(3, dim=-1)
|
||||||
values = self.v_proj(hidden_states)
|
|
||||||
|
|
||||||
queries = queries.view(seq_length, self.num_heads, self.head_dim)
|
queries = queries.view(seq_length, self.num_heads_per_partition,
|
||||||
keys = keys.view(seq_length, self.num_heads, self.head_dim)
|
self.head_dim)
|
||||||
values = values.view(seq_length, self.num_heads, self.head_dim)
|
keys = keys.view(seq_length, self.num_heads_per_partition,
|
||||||
|
self.head_dim)
|
||||||
|
values = values.view(seq_length, self.num_heads_per_partition,
|
||||||
|
self.head_dim)
|
||||||
|
|
||||||
if self.use_rope:
|
if self.use_rope:
|
||||||
cos, sin = position_embeddings
|
cos, sin = position_embeddings
|
||||||
@ -276,41 +307,72 @@ class Siglip2Attention(nn.Module):
|
|||||||
v_i,
|
v_i,
|
||||||
dropout_p=0.0)
|
dropout_p=0.0)
|
||||||
# (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
|
# (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
|
||||||
output_i = output_i.transpose(1, 2).reshape(-1, self.embed_dim)
|
output_i = output_i.transpose(1, 2).reshape(
|
||||||
|
end_idx - start_idx, -1)
|
||||||
outputs.append(output_i)
|
outputs.append(output_i)
|
||||||
|
|
||||||
attn_output = torch.cat(outputs, dim=0)
|
attn_output = torch.cat(outputs, dim=0)
|
||||||
attn_output = self.out_proj(attn_output)
|
attn_output, _ = self.out_proj(attn_output)
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
class Siglip2MLP(nn.Module):
|
class Siglip2MLP(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Siglip2VisionConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
use_data_parallel: bool = False,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.activation_fn = ACT2FN[config.hidden_act]
|
self.activation_fn = get_act_fn(config.hidden_act)
|
||||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
# TODO(Isotr0py): Enable data parallel after we support
|
||||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
# disabling TP on parallel linear layer
|
||||||
|
self.fc1 = ColumnParallelLinear(
|
||||||
|
config.hidden_size,
|
||||||
|
config.intermediate_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.fc1",
|
||||||
|
)
|
||||||
|
self.fc2 = RowParallelLinear(
|
||||||
|
config.intermediate_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.fc2",
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.fc1(hidden_states)
|
hidden_states, _ = self.fc1(hidden_states)
|
||||||
hidden_states = self.activation_fn(hidden_states)
|
hidden_states = self.activation_fn(hidden_states)
|
||||||
hidden_states = self.fc2(hidden_states)
|
hidden_states, _ = self.fc2(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class Siglip2EncoderLayer(nn.Module):
|
class Siglip2EncoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config: PretrainedConfig):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Siglip2VisionConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
use_data_parallel: bool = False,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
self.self_attn = Siglip2Attention(config)
|
self.self_attn = Siglip2Attention(config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
use_data_parallel=use_data_parallel)
|
||||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
|
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
self.mlp = Siglip2MLP(config)
|
self.mlp = Siglip2MLP(config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.mlp",
|
||||||
|
use_data_parallel=use_data_parallel)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
|
def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
|
||||||
position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]:
|
position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]:
|
||||||
@ -347,14 +409,22 @@ class Siglip2Encoder(nn.Module):
|
|||||||
config: PretrainedConfig
|
config: PretrainedConfig
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: PretrainedConfig):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Siglip2VisionConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
use_data_parallel: bool = False,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
Siglip2EncoderLayer(config)
|
Siglip2EncoderLayer(config,
|
||||||
for _ in range(config.num_hidden_layers)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.layers.{idx}",
|
||||||
|
use_data_parallel=use_data_parallel)
|
||||||
|
for idx in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.gradient_checkpointing = False
|
|
||||||
|
|
||||||
self.rotary_pos_emb = VisionRotaryEmbedding(
|
self.rotary_pos_emb = VisionRotaryEmbedding(
|
||||||
config.hidden_size // config.num_attention_heads // 2)
|
config.hidden_size // config.num_attention_heads // 2)
|
||||||
@ -445,13 +515,11 @@ class Siglip2Encoder(nn.Module):
|
|||||||
|
|
||||||
return window_index, cu_window_seqlens
|
return window_index, cu_window_seqlens
|
||||||
|
|
||||||
# Ignore copy
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
inputs_embeds,
|
inputs_embeds: torch.Tensor,
|
||||||
grid_thws: torch.Tensor,
|
grid_thws: torch.Tensor,
|
||||||
output_hidden_states: bool = False,
|
) -> torch.Tensor:
|
||||||
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, ...]]]:
|
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
inputs_embeds (`torch.FloatTensor` of shape
|
inputs_embeds (`torch.FloatTensor` of shape
|
||||||
@ -506,7 +574,6 @@ class Siglip2Encoder(nn.Module):
|
|||||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||||
|
|
||||||
reverse_indices = torch.argsort(window_index)
|
reverse_indices = torch.argsort(window_index)
|
||||||
encoder_states = () if output_hidden_states else None
|
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
for index, block in enumerate(self.layers):
|
for index, block in enumerate(self.layers):
|
||||||
@ -517,45 +584,40 @@ class Siglip2Encoder(nn.Module):
|
|||||||
cu_seqlens_tmp = cu_window_seqlens
|
cu_seqlens_tmp = cu_window_seqlens
|
||||||
hidden_states = block(hidden_states, cu_seqlens_tmp,
|
hidden_states = block(hidden_states, cu_seqlens_tmp,
|
||||||
position_embeddings)
|
position_embeddings)
|
||||||
if output_hidden_states:
|
|
||||||
hidden_states_ = hidden_states.reshape(
|
|
||||||
seq_len // self.spatial_merge_unit,
|
|
||||||
self.spatial_merge_unit, -1)
|
|
||||||
encoder_states += (hidden_states_[reverse_indices, :].reshape(
|
|
||||||
seq_len, -1), )
|
|
||||||
# tokens = self.post_trunk_norm(tokens)
|
|
||||||
hidden_states = hidden_states.reshape(
|
hidden_states = hidden_states.reshape(
|
||||||
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
||||||
hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)
|
hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)
|
||||||
|
|
||||||
return hidden_states, encoder_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class Siglip2VisionTransformer(nn.Module):
|
class Siglip2VisionTransformer(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config: PretrainedConfig):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Siglip2VisionConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
use_data_parallel: bool = False,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
embed_dim = config.hidden_size
|
embed_dim = config.hidden_size
|
||||||
|
|
||||||
self.embeddings = Siglip2VisionEmbeddings(config)
|
self.embeddings = Siglip2VisionEmbeddings(config)
|
||||||
self.encoder = Siglip2Encoder(config)
|
self.encoder = Siglip2Encoder(config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.encoder",
|
||||||
|
use_data_parallel=use_data_parallel)
|
||||||
self.post_layernorm = nn.LayerNorm(embed_dim,
|
self.post_layernorm = nn.LayerNorm(embed_dim,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
self._use_flash_attention_2 = \
|
|
||||||
(config._attn_implementation == "flash_attention_2")
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
grid_thws: torch.LongTensor,
|
grid_thws: torch.LongTensor,
|
||||||
output_hidden_states: Optional[bool] = True,
|
) -> torch.Tensor:
|
||||||
return_dict: Optional[bool] = True,
|
|
||||||
) -> Union[
|
|
||||||
tuple[torch.Tensor],
|
|
||||||
tuple[torch.Tensor, tuple[torch.Tensor, ...]],
|
|
||||||
BaseModelOutputWithNoAttention,
|
|
||||||
]:
|
|
||||||
r"""
|
r"""
|
||||||
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
|
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
|
||||||
Tensor containing the spatial dimensions (height, width)
|
Tensor containing the spatial dimensions (height, width)
|
||||||
@ -563,45 +625,64 @@ class Siglip2VisionTransformer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
hidden_states = self.embeddings(pixel_values, grid_thws)
|
hidden_states = self.embeddings(pixel_values, grid_thws)
|
||||||
|
|
||||||
last_hidden_state, hidden_states = self.encoder(
|
last_hidden_state = self.encoder(hidden_states, grid_thws)
|
||||||
hidden_states, grid_thws, output_hidden_states)
|
|
||||||
last_hidden_state = self.post_layernorm(last_hidden_state)
|
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (last_hidden_state, )
|
|
||||||
output += (hidden_states, ) if output_hidden_states else ()
|
|
||||||
return output
|
|
||||||
|
|
||||||
return last_hidden_state
|
return last_hidden_state
|
||||||
|
|
||||||
|
|
||||||
class Siglip2NavitModel(torch.nn.Module):
|
class Siglip2NavitModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, config: PretrainedConfig):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Siglip2VisionConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
use_data_parallel: bool = False,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.vision_model = Siglip2VisionTransformer(config)
|
self.vision_model = Siglip2VisionTransformer(
|
||||||
|
config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.vision_model",
|
||||||
|
use_data_parallel=use_data_parallel)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
grid_thws: torch.LongTensor,
|
grid_thws: torch.LongTensor,
|
||||||
output_hidden_states: Optional[bool] = None,
|
) -> torch.Tensor:
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
) -> Union[
|
|
||||||
tuple[torch.Tensor],
|
|
||||||
tuple[torch.Tensor, tuple[torch.Tensor, ...]],
|
|
||||||
BaseModelOutputWithNoAttention,
|
|
||||||
]:
|
|
||||||
|
|
||||||
if output_hidden_states is None:
|
|
||||||
output_hidden_states = self.config.output_hidden_states
|
|
||||||
if return_dict is None:
|
|
||||||
return_dict = self.config.use_return_dict
|
|
||||||
|
|
||||||
return self.vision_model(
|
return self.vision_model(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
grid_thws=grid_thws,
|
grid_thws=grid_thws,
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
|
torch.Tensor]]) -> set[str]:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
loaded_params: set[str] = set()
|
||||||
|
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
return loaded_params
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user