mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 08:44:58 +08:00
[Model][LoRA]LoRA support added for MiniCPMV2.6 (#8943)
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
b6d7392579
commit
8e60afa15e
@ -65,11 +65,10 @@ class Idefics2VisionEmbeddings(nn.Module):
|
|||||||
self.position_embedding = nn.Embedding(self.num_positions,
|
self.position_embedding = nn.Embedding(self.num_positions,
|
||||||
self.embed_dim)
|
self.embed_dim)
|
||||||
|
|
||||||
def forward(
|
def forward(self,
|
||||||
self,
|
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
patch_attention_mask: torch.BoolTensor,
|
patch_attention_mask: torch.BoolTensor,
|
||||||
) -> torch.Tensor:
|
tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
|
||||||
batch_size, _, max_im_h, max_im_w = pixel_values.shape
|
batch_size, _, max_im_h, max_im_w = pixel_values.shape
|
||||||
patch_embeds = self.patch_embedding(pixel_values)
|
patch_embeds = self.patch_embedding(pixel_values)
|
||||||
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||||
@ -84,6 +83,11 @@ class Idefics2VisionEmbeddings(nn.Module):
|
|||||||
fill_value=0)
|
fill_value=0)
|
||||||
|
|
||||||
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
|
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
|
||||||
|
|
||||||
|
if tgt_sizes is not None:
|
||||||
|
nb_patches_h = tgt_sizes[batch_idx][0]
|
||||||
|
nb_patches_w = tgt_sizes[batch_idx][1]
|
||||||
|
else:
|
||||||
nb_patches_h = p_attn_mask[:, 0].sum()
|
nb_patches_h = p_attn_mask[:, 0].sum()
|
||||||
nb_patches_w = p_attn_mask[0].sum()
|
nb_patches_w = p_attn_mask[0].sum()
|
||||||
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
|
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
|
||||||
@ -287,10 +291,12 @@ class Idefics2VisionTransformer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
pixel_values,
|
pixel_values,
|
||||||
patch_attention_mask: Optional[torch.BoolTensor] = None,
|
patch_attention_mask: Optional[torch.BoolTensor] = None,
|
||||||
) -> torch.tensor:
|
tgt_sizes: Optional[torch.IntTensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embeddings(
|
hidden_states = self.embeddings(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
patch_attention_mask=patch_attention_mask)
|
patch_attention_mask=patch_attention_mask,
|
||||||
|
tgt_sizes=tgt_sizes)
|
||||||
encoder_outputs = self.encoder(hidden_states)
|
encoder_outputs = self.encoder(hidden_states)
|
||||||
last_hidden_state = self.post_layernorm(encoder_outputs)
|
last_hidden_state = self.post_layernorm(encoder_outputs)
|
||||||
return last_hidden_state
|
return last_hidden_state
|
||||||
|
|||||||
@ -31,17 +31,15 @@ import torch
|
|||||||
import torch.types
|
import torch.types
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.init import trunc_normal_
|
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from typing_extensions import NotRequired
|
from typing_extensions import NotRequired
|
||||||
|
|
||||||
from vllm.attention import AttentionMetadata
|
from vllm.attention import AttentionMetadata
|
||||||
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
|
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
|
||||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.resampler import (Resampler2,
|
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
|
||||||
get_2d_sincos_pos_embed)
|
get_2d_sincos_pos_embed)
|
||||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
@ -106,58 +104,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
|
|||||||
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
|
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
|
||||||
|
|
||||||
|
|
||||||
class BaseResampler(nn.Module):
|
|
||||||
"""
|
|
||||||
A 2D perceiver-resampler network with one cross attention layers by
|
|
||||||
(grid_size**2) learnable queries and 2d sincos pos_emb
|
|
||||||
Outputs:
|
|
||||||
A tensor with the shape of (grid_size**2, embed_dim)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_queries: int,
|
|
||||||
embed_dim: int,
|
|
||||||
num_heads: int,
|
|
||||||
kv_dim: Optional[int] = None,
|
|
||||||
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.num_queries = num_queries
|
|
||||||
self.embed_dim = embed_dim
|
|
||||||
self.num_heads = num_heads
|
|
||||||
|
|
||||||
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
|
|
||||||
trunc_normal_(self.query, std=0.02)
|
|
||||||
if kv_dim is not None and kv_dim != embed_dim:
|
|
||||||
self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
|
|
||||||
else:
|
|
||||||
# Maintain the same return value with ReplicatedLinear.forward
|
|
||||||
self.kv_proj = lambda *args, **kwargs: (
|
|
||||||
nn.Identity()(*args, **kwargs),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
|
|
||||||
self.ln_q = norm_layer(embed_dim)
|
|
||||||
self.ln_kv = norm_layer(embed_dim)
|
|
||||||
self.ln_post = norm_layer(embed_dim)
|
|
||||||
self.proj = nn.Parameter(
|
|
||||||
(embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
|
|
||||||
|
|
||||||
def _init_weights(self, m: nn.Module) -> None:
|
|
||||||
if isinstance(m, nn.Linear):
|
|
||||||
trunc_normal_(m.weight, std=0.02)
|
|
||||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.LayerNorm):
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
nn.init.constant_(m.weight, 1.0)
|
|
||||||
|
|
||||||
def _repeat(self, query, N: int):
|
|
||||||
return query.unsqueeze(1).repeat(1, N, 1)
|
|
||||||
|
|
||||||
|
|
||||||
class Resampler2_5(BaseResampler):
|
class Resampler2_5(BaseResampler):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -869,7 +815,35 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
|
|||||||
return "resampler" in name
|
return "resampler" in name
|
||||||
|
|
||||||
|
|
||||||
class MiniCPMV2_6(MiniCPMVBaseModel):
|
class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": [
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
],
|
||||||
|
"gate_up_proj": [
|
||||||
|
"gate_proj",
|
||||||
|
"up_proj",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
# LoRA specific attributes
|
||||||
|
supported_lora_modules = [
|
||||||
|
# vision encoder
|
||||||
|
"fc1",
|
||||||
|
"fc2",
|
||||||
|
"out_proj",
|
||||||
|
# language model
|
||||||
|
"qkv_proj", # same name with vision encoder
|
||||||
|
"o_proj",
|
||||||
|
"gate_up_proj",
|
||||||
|
"down_proj",
|
||||||
|
# resampler
|
||||||
|
"kv_proj",
|
||||||
|
]
|
||||||
|
|
||||||
|
embedding_modules = {}
|
||||||
|
embedding_padding_modules = []
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -894,15 +868,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
|||||||
name="model")
|
name="model")
|
||||||
|
|
||||||
def init_vision_module(self) -> nn.Module:
|
def init_vision_module(self) -> nn.Module:
|
||||||
# A custom version of SiglipVisionTransformer, won't work with TP
|
|
||||||
from vllm.model_executor.models.na_vit import SiglipVisionTransformer
|
|
||||||
|
|
||||||
if self.config._attn_implementation == "flash_attention_2":
|
model = Idefics2VisionTransformer(self.config.vision_config)
|
||||||
self.config.vision_config._attn_implementation = "flash_attention_2"
|
|
||||||
else:
|
|
||||||
# not support sdpa
|
|
||||||
self.config.vision_config._attn_implementation = "eager"
|
|
||||||
model = SiglipVisionTransformer(self.config.vision_config)
|
|
||||||
if self.config.drop_vision_last_layer:
|
if self.config.drop_vision_last_layer:
|
||||||
model.encoder.layers = model.encoder.layers[:-1]
|
model.encoder.layers = model.encoder.layers[:-1]
|
||||||
return model
|
return model
|
||||||
@ -928,7 +895,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
|||||||
pixel_values,
|
pixel_values,
|
||||||
patch_attention_mask=patch_attn_mask,
|
patch_attention_mask=patch_attn_mask,
|
||||||
tgt_sizes=tgt_sizes,
|
tgt_sizes=tgt_sizes,
|
||||||
).last_hidden_state
|
)
|
||||||
return vision_embedding
|
return vision_embedding
|
||||||
|
|
||||||
def get_vision_hidden_states(
|
def get_vision_hidden_states(
|
||||||
@ -960,12 +927,12 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
|||||||
all_pixel_values.type(dtype),
|
all_pixel_values.type(dtype),
|
||||||
patch_attention_mask=patch_attn_mask,
|
patch_attention_mask=patch_attn_mask,
|
||||||
tgt_sizes=tgt_sizes,
|
tgt_sizes=tgt_sizes,
|
||||||
).last_hidden_state
|
)
|
||||||
|
|
||||||
return self.resampler(vision_embedding, tgt_sizes)
|
return self.resampler(vision_embedding, tgt_sizes)
|
||||||
|
|
||||||
def is_default_weight_loading(self, name: str) -> bool:
|
def is_default_weight_loading(self, name: str) -> bool:
|
||||||
return "resampler" in name or "vpm" in name
|
return "resampler" in name
|
||||||
|
|
||||||
|
|
||||||
_SUPPORT_VERSION = {
|
_SUPPORT_VERSION = {
|
||||||
|
|||||||
@ -1,804 +0,0 @@
|
|||||||
import logging
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import warnings
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn.init import _calculate_fan_in_and_fan_out
|
|
||||||
from transformers.activations import ACT2FN
|
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
|
||||||
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
|
||||||
from transformers.modeling_outputs import (BaseModelOutput,
|
|
||||||
BaseModelOutputWithPooling)
|
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
|
||||||
from transformers.utils import (ModelOutput, is_flash_attn_2_available,
|
|
||||||
replace_return_docstrings)
|
|
||||||
|
|
||||||
logger = logging.getLogger("vllm")
|
|
||||||
|
|
||||||
|
|
||||||
# For Siglip: copied from
|
|
||||||
# HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes
|
|
||||||
# Remove hints as there's little possibility to change these code.
|
|
||||||
class SiglipVisionConfig(PretrainedConfig):
|
|
||||||
|
|
||||||
model_type = "siglip_vision_model"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
hidden_size=768,
|
|
||||||
intermediate_size=3072,
|
|
||||||
num_hidden_layers=12,
|
|
||||||
num_attention_heads=12,
|
|
||||||
num_channels=3,
|
|
||||||
image_size=224,
|
|
||||||
patch_size=16,
|
|
||||||
hidden_act="gelu_pytorch_tanh",
|
|
||||||
layer_norm_eps=1e-6,
|
|
||||||
attention_dropout=0.0,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.intermediate_size = intermediate_size
|
|
||||||
self.num_hidden_layers = num_hidden_layers
|
|
||||||
self.num_attention_heads = num_attention_heads
|
|
||||||
self.num_channels = num_channels
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.image_size = image_size
|
|
||||||
self.attention_dropout = attention_dropout
|
|
||||||
self.layer_norm_eps = layer_norm_eps
|
|
||||||
self.hidden_act = hidden_act
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str,
|
|
||||||
os.PathLike],
|
|
||||||
**kwargs) -> "PretrainedConfig":
|
|
||||||
cls._set_token_in_kwargs(kwargs)
|
|
||||||
|
|
||||||
config_dict, kwargs = cls.get_config_dict(
|
|
||||||
pretrained_model_name_or_path, **kwargs)
|
|
||||||
|
|
||||||
# get the vision config dict if we are loading from SiglipConfig
|
|
||||||
if config_dict.get("model_type") == "siglip":
|
|
||||||
config_dict = config_dict["vision_config"]
|
|
||||||
|
|
||||||
if "model_type" in config_dict and hasattr(
|
|
||||||
cls,
|
|
||||||
"model_type") and config_dict["model_type"] != cls.model_type:
|
|
||||||
logger.warning(
|
|
||||||
"You are using a model of type %s to "
|
|
||||||
"instantiate a model of type %s. "
|
|
||||||
"This is not supported for all configurations"
|
|
||||||
"of models and can yield errors.", config_dict['model_type'],
|
|
||||||
cls.model_type)
|
|
||||||
|
|
||||||
return cls.from_dict(config_dict, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
|
|
||||||
|
|
||||||
SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|
||||||
"google/siglip-base-patch16-224",
|
|
||||||
# See all SigLIP models at https://huggingface.co/models?filter=siglip
|
|
||||||
]
|
|
||||||
|
|
||||||
if is_flash_attn_2_available():
|
|
||||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
|
||||||
from flash_attn.bert_padding import pad_input # noqa
|
|
||||||
from flash_attn.bert_padding import index_first_axis, unpad_input
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
|
||||||
def _get_unpad_data(attention_mask):
|
|
||||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
|
||||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
|
||||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
|
||||||
cu_seqlens = F.pad(
|
|
||||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
|
||||||
return (
|
|
||||||
indices,
|
|
||||||
cu_seqlens,
|
|
||||||
max_seqlen_in_batch,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _trunc_normal_(tensor, mean, std, a, b):
|
|
||||||
|
|
||||||
def norm_cdf(x):
|
|
||||||
# Computes standard normal cumulative distribution function
|
|
||||||
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
|
||||||
|
|
||||||
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
|
||||||
warnings.warn(
|
|
||||||
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
|
||||||
"The distribution of values may be incorrect.",
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Values are generated by using a truncated uniform distribution and
|
|
||||||
# then using the inverse CDF for the normal distribution.
|
|
||||||
# Get upper and lower cdf values
|
|
||||||
l_ = norm_cdf((a - mean) / std)
|
|
||||||
u = norm_cdf((b - mean) / std)
|
|
||||||
|
|
||||||
# Uniformly fill tensor with values from [l, u], then translate to
|
|
||||||
# [2l-1, 2u-1].
|
|
||||||
tensor.uniform_(2 * l_ - 1, 2 * u - 1)
|
|
||||||
|
|
||||||
# Use inverse cdf transform for normal distribution to get truncated
|
|
||||||
# standard normal
|
|
||||||
if tensor.dtype in [torch.float16, torch.bfloat16]:
|
|
||||||
# The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
|
|
||||||
og_dtype = tensor.dtype
|
|
||||||
tensor = tensor.to(torch.float32)
|
|
||||||
tensor.erfinv_()
|
|
||||||
tensor = tensor.to(og_dtype)
|
|
||||||
else:
|
|
||||||
tensor.erfinv_()
|
|
||||||
|
|
||||||
# Transform to proper mean, std
|
|
||||||
tensor.mul_(std * math.sqrt(2.0))
|
|
||||||
tensor.add_(mean)
|
|
||||||
|
|
||||||
# Clamp to ensure it's in the proper range
|
|
||||||
if tensor.dtype == torch.float16:
|
|
||||||
# The `clamp_` op is not (yet?) defined in float16+cpu
|
|
||||||
tensor = tensor.to(torch.float32)
|
|
||||||
tensor.clamp_(min=a, max=b)
|
|
||||||
tensor = tensor.to(torch.float16)
|
|
||||||
else:
|
|
||||||
tensor.clamp_(min=a, max=b)
|
|
||||||
|
|
||||||
|
|
||||||
def trunc_normal_tf_(tensor: torch.Tensor,
|
|
||||||
mean: float = 0.0,
|
|
||||||
std: float = 1.0,
|
|
||||||
a: float = -2.0,
|
|
||||||
b: float = 2.0) -> torch.Tensor:
|
|
||||||
with torch.no_grad():
|
|
||||||
_trunc_normal_(tensor, 0, 1.0, a, b)
|
|
||||||
tensor.mul_(std).add_(mean)
|
|
||||||
|
|
||||||
|
|
||||||
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
|
|
||||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
|
||||||
if mode == "fan_in":
|
|
||||||
denom = fan_in
|
|
||||||
elif mode == "fan_out":
|
|
||||||
denom = fan_out
|
|
||||||
elif mode == "fan_avg":
|
|
||||||
denom = (fan_in + fan_out) / 2
|
|
||||||
|
|
||||||
variance = scale / denom
|
|
||||||
|
|
||||||
if distribution == "truncated_normal":
|
|
||||||
# constant is stddev of standard normal truncated to (-2, 2)
|
|
||||||
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
|
|
||||||
elif distribution == "normal":
|
|
||||||
with torch.no_grad():
|
|
||||||
tensor.normal_(std=math.sqrt(variance))
|
|
||||||
elif distribution == "uniform":
|
|
||||||
bound = math.sqrt(3 * variance)
|
|
||||||
with torch.no_grad():
|
|
||||||
tensor.uniform_(-bound, bound)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"invalid distribution {distribution}")
|
|
||||||
|
|
||||||
|
|
||||||
def lecun_normal_(tensor):
|
|
||||||
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
|
|
||||||
|
|
||||||
|
|
||||||
def default_flax_embed_init(tensor):
|
|
||||||
variance_scaling_(tensor, mode="fan_in", distribution="normal")
|
|
||||||
|
|
||||||
|
|
||||||
class SiglipVisionModelOutput(ModelOutput):
|
|
||||||
image_embeds: Optional[torch.FloatTensor] = None
|
|
||||||
last_hidden_state: torch.FloatTensor = None
|
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
|
||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class SiglipVisionEmbeddings(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, config: SiglipVisionConfig):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.embed_dim = config.hidden_size
|
|
||||||
self.image_size = config.image_size
|
|
||||||
self.patch_size = config.patch_size
|
|
||||||
|
|
||||||
self.patch_embedding = nn.Conv2d(
|
|
||||||
in_channels=config.num_channels,
|
|
||||||
out_channels=self.embed_dim,
|
|
||||||
kernel_size=self.patch_size,
|
|
||||||
stride=self.patch_size,
|
|
||||||
padding="valid",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.num_patches_per_side = self.image_size // self.patch_size
|
|
||||||
self.num_patches = self.num_patches_per_side**2
|
|
||||||
self.num_positions = self.num_patches
|
|
||||||
self.position_embedding = nn.Embedding(self.num_positions,
|
|
||||||
self.embed_dim)
|
|
||||||
|
|
||||||
def forward(self,
|
|
||||||
pixel_values: torch.FloatTensor,
|
|
||||||
patch_attention_mask: torch.BoolTensor,
|
|
||||||
tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
|
|
||||||
batch_size = pixel_values.size(0)
|
|
||||||
|
|
||||||
patch_embeds = self.patch_embedding(pixel_values)
|
|
||||||
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
|
||||||
|
|
||||||
max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
|
|
||||||
max_nb_patches_h, max_nb_patches_w = (max_im_h // self.patch_size,
|
|
||||||
max_im_w // self.patch_size)
|
|
||||||
boundaries = torch.arange(1 / self.num_patches_per_side, 1.0,
|
|
||||||
1 / self.num_patches_per_side)
|
|
||||||
position_ids = torch.full(
|
|
||||||
size=(
|
|
||||||
batch_size,
|
|
||||||
max_nb_patches_h * max_nb_patches_w,
|
|
||||||
),
|
|
||||||
fill_value=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
|
|
||||||
if tgt_sizes is not None:
|
|
||||||
nb_patches_h = tgt_sizes[batch_idx][0]
|
|
||||||
nb_patches_w = tgt_sizes[batch_idx][1]
|
|
||||||
else:
|
|
||||||
nb_patches_h = p_attn_mask[:, 0].sum()
|
|
||||||
nb_patches_w = p_attn_mask[0].sum()
|
|
||||||
|
|
||||||
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
|
|
||||||
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
|
|
||||||
|
|
||||||
bucket_coords_h = torch.bucketize(fractional_coords_h,
|
|
||||||
boundaries,
|
|
||||||
right=True)
|
|
||||||
bucket_coords_w = torch.bucketize(fractional_coords_w,
|
|
||||||
boundaries,
|
|
||||||
right=True)
|
|
||||||
|
|
||||||
pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side +
|
|
||||||
bucket_coords_w).flatten()
|
|
||||||
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
|
|
||||||
|
|
||||||
position_ids = position_ids.to(self.position_embedding.weight.device)
|
|
||||||
|
|
||||||
embeddings = embeddings + self.position_embedding(position_ids)
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
class SiglipAttention(nn.Module):
|
|
||||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
|
||||||
|
|
||||||
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.embed_dim = config.hidden_size
|
|
||||||
self.num_heads = config.num_attention_heads
|
|
||||||
self.head_dim = self.embed_dim // self.num_heads
|
|
||||||
if self.head_dim * self.num_heads != self.embed_dim:
|
|
||||||
raise ValueError(
|
|
||||||
"embed_dim must be divisible by num_heads (got `embed_dim`: "
|
|
||||||
f"{self.embed_dim} and `num_heads`:"
|
|
||||||
f" {self.num_heads}).")
|
|
||||||
self.scale = self.head_dim**-0.5
|
|
||||||
self.dropout = config.attention_dropout
|
|
||||||
|
|
||||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
|
||||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
|
||||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
|
||||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
output_attentions: Optional[bool] = False,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
|
|
||||||
Optional[Tuple[torch.Tensor]]]:
|
|
||||||
"""Input shape: Batch x Time x Channel"""
|
|
||||||
|
|
||||||
batch_size, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
query_states = self.q_proj(hidden_states)
|
|
||||||
key_states = self.k_proj(hidden_states)
|
|
||||||
value_states = self.v_proj(hidden_states)
|
|
||||||
|
|
||||||
query_states = query_states.view(batch_size, q_len, self.num_heads,
|
|
||||||
self.head_dim).transpose(1, 2)
|
|
||||||
key_states = key_states.view(batch_size, q_len, self.num_heads,
|
|
||||||
self.head_dim).transpose(1, 2)
|
|
||||||
value_states = value_states.view(batch_size, q_len, self.num_heads,
|
|
||||||
self.head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
k_v_seq_len = key_states.shape[-2]
|
|
||||||
attn_weights = torch.matmul(query_states, key_states.transpose(
|
|
||||||
2, 3)) * self.scale
|
|
||||||
|
|
||||||
if attn_weights.size() != (batch_size, self.num_heads, q_len,
|
|
||||||
k_v_seq_len):
|
|
||||||
raise ValueError(
|
|
||||||
"Attention weights should be of size "
|
|
||||||
f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
|
|
||||||
f" {attn_weights.size()}")
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
|
|
||||||
raise ValueError(
|
|
||||||
"Attention mask should be of size "
|
|
||||||
f"{(batch_size, 1, q_len, k_v_seq_len)}",
|
|
||||||
f"but is {attention_mask.size()}")
|
|
||||||
attn_weights = attn_weights + attention_mask
|
|
||||||
|
|
||||||
# upcast attention to fp32
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights,
|
|
||||||
dim=-1,
|
|
||||||
dtype=torch.float32).to(
|
|
||||||
query_states.dtype)
|
|
||||||
attn_weights = nn.functional.dropout(attn_weights,
|
|
||||||
p=self.dropout,
|
|
||||||
training=self.training)
|
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
|
||||||
|
|
||||||
if attn_output.size() != (batch_size, self.num_heads, q_len,
|
|
||||||
self.head_dim):
|
|
||||||
raise ValueError(
|
|
||||||
"`attn_output` should be of size "
|
|
||||||
f"{(batch_size, self.num_heads, q_len, self.head_dim)}, "
|
|
||||||
"but is"
|
|
||||||
f" {attn_output.size()}")
|
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
||||||
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
|
|
||||||
|
|
||||||
attn_output = self.out_proj(attn_output)
|
|
||||||
|
|
||||||
return attn_output, attn_weights
|
|
||||||
|
|
||||||
|
|
||||||
class SiglipFlashAttention2(SiglipAttention):
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.is_causal = False # Hack to make sure we don't use a causal mask
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.LongTensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: bool = False,
|
|
||||||
use_cache: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
|
|
||||||
Optional[Tuple[torch.Tensor]]]:
|
|
||||||
output_attentions = False
|
|
||||||
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
query_states = self.q_proj(hidden_states)
|
|
||||||
key_states = self.k_proj(hidden_states)
|
|
||||||
value_states = self.v_proj(hidden_states)
|
|
||||||
|
|
||||||
query_states = query_states.view(bsz, q_len, self.num_heads,
|
|
||||||
self.head_dim).transpose(1, 2)
|
|
||||||
key_states = key_states.view(bsz, q_len, self.num_heads,
|
|
||||||
self.head_dim).transpose(1, 2)
|
|
||||||
value_states = value_states.view(bsz, q_len, self.num_heads,
|
|
||||||
self.head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
|
||||||
if past_key_value is not None:
|
|
||||||
kv_seq_len += past_key_value.get_usable_length(
|
|
||||||
kv_seq_len, self.layer_idx)
|
|
||||||
|
|
||||||
query_states = query_states.transpose(1, 2)
|
|
||||||
key_states = key_states.transpose(1, 2)
|
|
||||||
value_states = value_states.transpose(1, 2)
|
|
||||||
|
|
||||||
dropout_rate = self.dropout if self.training else 0.0
|
|
||||||
|
|
||||||
input_dtype = query_states.dtype
|
|
||||||
if input_dtype == torch.float32:
|
|
||||||
if torch.is_autocast_enabled():
|
|
||||||
target_dtype = torch.get_autocast_gpu_dtype()
|
|
||||||
# Handle the case where the model is quantized
|
|
||||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
|
||||||
target_dtype = self.config._pre_quantization_dtype
|
|
||||||
else:
|
|
||||||
target_dtype = self.q_proj.weight.dtype
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
"The input hidden states seems to be "
|
|
||||||
"silently casted in float32, "
|
|
||||||
"this might be related to the fact "
|
|
||||||
"you have upcasted embedding or layer norm layers in float32. "
|
|
||||||
"We will cast back the input in"
|
|
||||||
" %s.", target_dtype)
|
|
||||||
|
|
||||||
query_states = query_states.to(target_dtype)
|
|
||||||
key_states = key_states.to(target_dtype)
|
|
||||||
value_states = value_states.to(target_dtype)
|
|
||||||
|
|
||||||
attn_output = self._flash_attention_forward(query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
attention_mask,
|
|
||||||
q_len,
|
|
||||||
dropout=dropout_rate)
|
|
||||||
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len,
|
|
||||||
self.embed_dim).contiguous()
|
|
||||||
attn_output = self.out_proj(attn_output)
|
|
||||||
|
|
||||||
if not output_attentions:
|
|
||||||
attn_weights = None
|
|
||||||
|
|
||||||
return attn_output, attn_weights
|
|
||||||
|
|
||||||
def _flash_attention_forward(self,
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
attention_mask,
|
|
||||||
query_length,
|
|
||||||
dropout=0.0,
|
|
||||||
softmax_scale=None):
|
|
||||||
causal = self.is_causal and query_length != 1
|
|
||||||
|
|
||||||
# Contains at least one padding token in the sequence
|
|
||||||
if attention_mask is not None:
|
|
||||||
batch_size = query_states.shape[0]
|
|
||||||
(query_states, key_states, value_states, indices_q, cu_seq_lens,
|
|
||||||
max_seq_lens) = self._upad_input(query_states, key_states,
|
|
||||||
value_states, attention_mask,
|
|
||||||
query_length)
|
|
||||||
|
|
||||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
|
||||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
|
||||||
|
|
||||||
attn_output_unpad = flash_attn_varlen_func(
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
cu_seqlens_q=cu_seqlens_q,
|
|
||||||
cu_seqlens_k=cu_seqlens_k,
|
|
||||||
max_seqlen_q=max_seqlen_in_batch_q,
|
|
||||||
max_seqlen_k=max_seqlen_in_batch_k,
|
|
||||||
dropout_p=dropout,
|
|
||||||
softmax_scale=softmax_scale,
|
|
||||||
causal=causal,
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
|
|
||||||
query_length)
|
|
||||||
else:
|
|
||||||
attn_output = flash_attn_func(query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
dropout,
|
|
||||||
softmax_scale=softmax_scale,
|
|
||||||
causal=causal)
|
|
||||||
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
|
|
||||||
query_length):
|
|
||||||
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
|
|
||||||
attention_mask)
|
|
||||||
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
|
||||||
|
|
||||||
key_layer = index_first_axis(
|
|
||||||
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
|
|
||||||
head_dim), indices_k)
|
|
||||||
value_layer = index_first_axis(
|
|
||||||
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
|
|
||||||
head_dim), indices_k)
|
|
||||||
if query_length == kv_seq_len:
|
|
||||||
query_layer = index_first_axis(
|
|
||||||
query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
|
|
||||||
head_dim), indices_k)
|
|
||||||
cu_seqlens_q = cu_seqlens_k
|
|
||||||
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
|
||||||
indices_q = indices_k
|
|
||||||
elif query_length == 1:
|
|
||||||
max_seqlen_in_batch_q = 1
|
|
||||||
cu_seqlens_q = torch.arange(
|
|
||||||
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
|
||||||
) # There is a memcpy here, that is very bad.
|
|
||||||
indices_q = cu_seqlens_q[:-1]
|
|
||||||
query_layer = query_layer.squeeze(1)
|
|
||||||
else:
|
|
||||||
# The -q_len: slice assumes left padding.
|
|
||||||
attention_mask = attention_mask[:, -query_length:]
|
|
||||||
(query_layer, indices_q, cu_seqlens_q,
|
|
||||||
max_seqlen_in_batch_q) = unpad_input(query_layer, attention_mask)
|
|
||||||
|
|
||||||
return (
|
|
||||||
query_layer,
|
|
||||||
key_layer,
|
|
||||||
value_layer,
|
|
||||||
indices_q,
|
|
||||||
(cu_seqlens_q, cu_seqlens_k),
|
|
||||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
|
|
||||||
class SiglipMLP(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.activation_fn = ACT2FN[config.hidden_act]
|
|
||||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
|
||||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
||||||
hidden_states = self.fc1(hidden_states)
|
|
||||||
hidden_states = self.activation_fn(hidden_states)
|
|
||||||
hidden_states = self.fc2(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer
|
|
||||||
# with CLIP->Siglip
|
|
||||||
class SiglipEncoderLayer(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, config: SiglipVisionConfig):
|
|
||||||
super().__init__()
|
|
||||||
self.embed_dim = config.hidden_size
|
|
||||||
self._use_flash_attention_2 = (
|
|
||||||
config._attn_implementation == "flash_attention_2")
|
|
||||||
self.self_attn = (SiglipAttention(config)
|
|
||||||
if not self._use_flash_attention_2 else
|
|
||||||
SiglipFlashAttention2(config))
|
|
||||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
|
||||||
eps=config.layer_norm_eps)
|
|
||||||
self.mlp = SiglipMLP(config)
|
|
||||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
|
|
||||||
eps=config.layer_norm_eps)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: torch.Tensor,
|
|
||||||
output_attentions: Optional[bool] = False,
|
|
||||||
) -> Tuple[torch.FloatTensor]:
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
hidden_states = self.layer_norm1(hidden_states)
|
|
||||||
hidden_states, attn_weights = self.self_attn(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
residual = hidden_states
|
|
||||||
hidden_states = self.layer_norm2(hidden_states)
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
outputs = (hidden_states, )
|
|
||||||
|
|
||||||
if output_attentions:
|
|
||||||
outputs += (attn_weights, )
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
class SiglipPreTrainedModel(PreTrainedModel):
|
|
||||||
config_class = SiglipVisionConfig
|
|
||||||
base_model_prefix = "siglip"
|
|
||||||
supports_gradient_checkpointing = True
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
|
||||||
"""Initialize the weights"""
|
|
||||||
|
|
||||||
if isinstance(module, SiglipVisionEmbeddings):
|
|
||||||
width = self.config.hidden_size
|
|
||||||
nn.init.normal_(module.position_embedding.weight,
|
|
||||||
std=1 / np.sqrt(width))
|
|
||||||
elif isinstance(module, nn.Embedding):
|
|
||||||
default_flax_embed_init(module.weight)
|
|
||||||
elif isinstance(module, SiglipAttention):
|
|
||||||
nn.init.normal_(module.q_proj.weight)
|
|
||||||
nn.init.normal_(module.k_proj.weight)
|
|
||||||
nn.init.normal_(module.v_proj.weight)
|
|
||||||
nn.init.normal_(module.out_proj.weight)
|
|
||||||
nn.init.zeros_(module.q_proj.bias)
|
|
||||||
nn.init.zeros_(module.k_proj.bias)
|
|
||||||
nn.init.zeros_(module.v_proj.bias)
|
|
||||||
nn.init.zeros_(module.out_proj.bias)
|
|
||||||
elif isinstance(module, SiglipMLP):
|
|
||||||
nn.init.normal_(module.fc1.weight)
|
|
||||||
nn.init.normal_(module.fc2.weight)
|
|
||||||
nn.init.normal_(module.fc1.bias, std=1e-6)
|
|
||||||
nn.init.normal_(module.fc2.bias, std=1e-6)
|
|
||||||
elif isinstance(module, (nn.Linear, nn.Conv2d)):
|
|
||||||
lecun_normal_(module.weight)
|
|
||||||
if module.bias is not None:
|
|
||||||
nn.init.zeros_(module.bias)
|
|
||||||
elif isinstance(module, nn.LayerNorm):
|
|
||||||
module.bias.data.zero_()
|
|
||||||
module.weight.data.fill_(1.0)
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder
|
|
||||||
# with CLIP->Siglip
|
|
||||||
class SiglipEncoder(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, config: SiglipVisionConfig):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.layers = nn.ModuleList([
|
|
||||||
SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)
|
|
||||||
])
|
|
||||||
self.gradient_checkpointing = False
|
|
||||||
|
|
||||||
# Ignore copy
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
inputs_embeds,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
) -> Union[Tuple, BaseModelOutput]:
|
|
||||||
output_attentions = output_attentions if output_attentions is not None \
|
|
||||||
else self.config.output_attentions
|
|
||||||
output_hidden_states = (output_hidden_states
|
|
||||||
if output_hidden_states is not None else
|
|
||||||
self.config.output_hidden_states)
|
|
||||||
return_dict = return_dict if return_dict is not None \
|
|
||||||
else self.config.use_return_dict
|
|
||||||
|
|
||||||
encoder_states = () if output_hidden_states else None
|
|
||||||
all_attentions = () if output_attentions else None
|
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
|
||||||
for encoder_layer in self.layers:
|
|
||||||
if output_hidden_states:
|
|
||||||
encoder_states = encoder_states + (hidden_states, )
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
|
||||||
encoder_layer.__call__,
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
output_attentions,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = encoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
|
||||||
|
|
||||||
if output_attentions:
|
|
||||||
all_attentions = all_attentions + (layer_outputs[1], )
|
|
||||||
|
|
||||||
if output_hidden_states:
|
|
||||||
encoder_states = encoder_states + (hidden_states, )
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return tuple(
|
|
||||||
v for v in [hidden_states, encoder_states, all_attentions]
|
|
||||||
if v is not None)
|
|
||||||
return BaseModelOutput(last_hidden_state=hidden_states,
|
|
||||||
hidden_states=encoder_states,
|
|
||||||
attentions=all_attentions)
|
|
||||||
|
|
||||||
|
|
||||||
class SiglipVisionTransformer(SiglipPreTrainedModel):
|
|
||||||
config_class = SiglipVisionConfig
|
|
||||||
main_input_name = "pixel_values"
|
|
||||||
_supports_flash_attn_2 = True
|
|
||||||
|
|
||||||
def __init__(self, config: SiglipVisionConfig):
|
|
||||||
super().__init__(config)
|
|
||||||
self.config = config
|
|
||||||
embed_dim = config.hidden_size
|
|
||||||
|
|
||||||
self.embeddings = SiglipVisionEmbeddings(config)
|
|
||||||
self.encoder = SiglipEncoder(config)
|
|
||||||
self.post_layernorm = nn.LayerNorm(embed_dim,
|
|
||||||
eps=config.layer_norm_eps)
|
|
||||||
self._use_flash_attention_2 = (
|
|
||||||
config._attn_implementation == "flash_attention_2")
|
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
def get_input_embeddings(self) -> nn.Module:
|
|
||||||
return self.embeddings.patch_embedding
|
|
||||||
|
|
||||||
@replace_return_docstrings(output_type=BaseModelOutputWithPooling,
|
|
||||||
config_class=SiglipVisionConfig)
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
pixel_values,
|
|
||||||
patch_attention_mask: Optional[torch.BoolTensor] = None,
|
|
||||||
tgt_sizes: Optional[torch.IntTensor] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
|
||||||
r"""
|
|
||||||
Returns:
|
|
||||||
"""
|
|
||||||
output_attentions = output_attentions if output_attentions is not None \
|
|
||||||
else self.config.output_attentions
|
|
||||||
output_hidden_states = (output_hidden_states
|
|
||||||
if output_hidden_states is not None else
|
|
||||||
self.config.output_hidden_states)
|
|
||||||
return_dict = return_dict if return_dict is not None \
|
|
||||||
else self.config.use_return_dict
|
|
||||||
|
|
||||||
batch_size = pixel_values.size(0)
|
|
||||||
if patch_attention_mask is None:
|
|
||||||
patch_attention_mask = torch.ones(
|
|
||||||
size=(
|
|
||||||
batch_size,
|
|
||||||
pixel_values.size(2) // self.config.patch_size,
|
|
||||||
pixel_values.size(3) // self.config.patch_size,
|
|
||||||
),
|
|
||||||
dtype=torch.bool,
|
|
||||||
device=pixel_values.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.embeddings(
|
|
||||||
pixel_values=pixel_values,
|
|
||||||
patch_attention_mask=patch_attention_mask,
|
|
||||||
tgt_sizes=tgt_sizes)
|
|
||||||
|
|
||||||
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
|
|
||||||
# The call to `_upad_input` in `_flash_attention_forward` is expensive
|
|
||||||
# So when the `patch_attention_mask` is full of 1s
|
|
||||||
# (i.e. attending to the whole sequence),
|
|
||||||
# avoiding passing the attention_mask,
|
|
||||||
# which is equivalent to attending to the full sequence
|
|
||||||
if not torch.any(~patch_attention_mask):
|
|
||||||
attention_mask = None
|
|
||||||
else:
|
|
||||||
attention_mask = (_prepare_4d_attention_mask(
|
|
||||||
patch_attention_mask, hidden_states.dtype)
|
|
||||||
if not self._use_flash_attention_2 else
|
|
||||||
patch_attention_mask)
|
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
|
||||||
inputs_embeds=hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
)
|
|
||||||
|
|
||||||
last_hidden_state = encoder_outputs[0]
|
|
||||||
last_hidden_state = self.post_layernorm(last_hidden_state)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return (last_hidden_state, None) + encoder_outputs[1:]
|
|
||||||
|
|
||||||
return BaseModelOutputWithPooling(
|
|
||||||
last_hidden_state=last_hidden_state,
|
|
||||||
pooler_output=None,
|
|
||||||
hidden_states=encoder_outputs.hidden_states,
|
|
||||||
attentions=encoder_outputs.attentions,
|
|
||||||
)
|
|
||||||
Loading…
x
Reference in New Issue
Block a user