mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 15:36:29 +08:00
[Bugfix] Clean up MiniCPM-V (#6939)
Co-authored-by: hezhihui <hzh7269@modelbest.cn> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent
6512937de1
commit
2f4e108f75
@ -222,9 +222,13 @@ Vision Language Models
|
||||
-
|
||||
* - :code:`MiniCPM-V`
|
||||
- MiniCPM-V
|
||||
- :code:`openbmb/MiniCPM-V-2`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
|
||||
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
|
||||
-
|
||||
|
||||
.. note::
|
||||
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
|
||||
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
|
||||
|
||||
----
|
||||
|
||||
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
|
||||
|
||||
@ -418,11 +418,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
input_embeds: Optional[torch.Tensor] = None
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
model_output = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
input_embeds)
|
||||
attn_metadata, intermediate_tensors)
|
||||
return model_output
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
|
||||
@ -370,6 +370,7 @@ class MiniCPMModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is not None:
|
||||
@ -463,11 +464,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
input_embeds: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, input_embeds)
|
||||
attn_metadata, intermediate_tensors)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
|
||||
@ -20,32 +20,34 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only MiniCPM-V-2 model compatible with HuggingFace weights."""
|
||||
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
|
||||
import math
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.types
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from torch.nn.init import trunc_normal_
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.models.idefics2.modeling_idefics2 import (
|
||||
Idefics2VisionTransformer)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import SupportsVision
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.models.minicpm import MiniCPMModel
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Model
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import (cached_get_image_processor,
|
||||
@ -53,12 +55,12 @@ from vllm.multimodal.image import (cached_get_image_processor,
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
|
||||
|
||||
_KEYS_TO_MODIFY_MAPPING = {
|
||||
"language_model.lm_head": "lm_head",
|
||||
"language_model.model": "language_model",
|
||||
"llm.lm_head": "lm_head",
|
||||
"llm.model": "llm",
|
||||
}
|
||||
|
||||
|
||||
def get_abs_pos(abs_pos, tgt_size):
|
||||
def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
|
||||
# abs_pos: L, C
|
||||
# tgt_size: (H, W)
|
||||
# return: M, C
|
||||
@ -75,10 +77,10 @@ def get_abs_pos(abs_pos, tgt_size):
|
||||
|
||||
|
||||
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
||||
def get_2d_sincos_pos_embed(embed_dim,
|
||||
grid_size,
|
||||
cls_token=False,
|
||||
version=2.0):
|
||||
def get_2d_sincos_pos_embed(embed_dim: int,
|
||||
grid_size: Union[int, Tuple[int, int]],
|
||||
cls_token: bool = False,
|
||||
version: Tuple[int, int] = (2, 0)):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
@ -95,7 +97,7 @@ def get_2d_sincos_pos_embed(embed_dim,
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
if version == 2.0:
|
||||
if version == (2, 0):
|
||||
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
|
||||
if cls_token:
|
||||
@ -106,7 +108,9 @@ def get_2d_sincos_pos_embed(embed_dim,
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version=2.0):
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
|
||||
grid: Union[int, Tuple[int, int]],
|
||||
version: Tuple[int, int] = (2, 0)):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
@ -115,14 +119,16 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version=2.0):
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(
|
||||
embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2)
|
||||
|
||||
if version == 2.0:
|
||||
if version == (2, 0):
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
else:
|
||||
emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, version=2.0):
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
|
||||
pos: int,
|
||||
version: Tuple[int, int] = (2, 0)):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,) / (H, W)
|
||||
@ -133,7 +139,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, version=2.0):
|
||||
omega /= embed_dim / 2.
|
||||
omega = 1. / 10000**omega # (D/2,)
|
||||
|
||||
if version == 2.0:
|
||||
if version == (2, 0):
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
@ -158,19 +164,19 @@ class Resampler(nn.Module):
|
||||
default_norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
||||
|
||||
def __init__(self,
|
||||
num_queries,
|
||||
grid_size,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
kv_dim=None,
|
||||
norm_layer=default_norm_layer,
|
||||
adaptive=False,
|
||||
max_size=(70, 70),
|
||||
version=2.0):
|
||||
num_queries: int,
|
||||
grid_size: int,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
kv_dim: Optional[int] = None,
|
||||
norm_layer: nn.Module = default_norm_layer,
|
||||
adaptive: bool = False,
|
||||
max_size: Tuple[int, int] = (70, 70),
|
||||
version: Tuple[int, int] = (2, 0)):
|
||||
super().__init__()
|
||||
|
||||
self.version = version
|
||||
if self.version == 2.0:
|
||||
if self.version == (2, 0):
|
||||
self.num_queries = grid_size**2
|
||||
else:
|
||||
self.num_queries = num_queries
|
||||
@ -195,7 +201,7 @@ class Resampler(nn.Module):
|
||||
self.proj = nn.Parameter(
|
||||
(embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
|
||||
|
||||
if self.version == 2.0:
|
||||
if self.version == (2, 0):
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.from_numpy(
|
||||
get_2d_sincos_pos_embed(
|
||||
@ -206,14 +212,17 @@ class Resampler(nn.Module):
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _set_2d_pos_cache(self, max_size, device='cpu'):
|
||||
def _set_2d_pos_cache(self,
|
||||
max_size: Tuple[int, int],
|
||||
device: torch.types.Device = 'cpu'):
|
||||
pos_embed = torch.from_numpy(
|
||||
get_2d_sincos_pos_embed(self.embed_dim,
|
||||
max_size,
|
||||
version=self.version)).float().to(device)
|
||||
self.register_buffer("pos_embed", pos_embed, persistent=False)
|
||||
|
||||
def _adjust_pos_cache(self, tgt_sizes, device):
|
||||
def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
|
||||
device: torch.types.Device):
|
||||
max_h = torch.max(tgt_sizes[:, 0])
|
||||
max_w = torch.max(tgt_sizes[:, 1])
|
||||
if max_h > self.max_size[0] or max_w > self.max_size[1]:
|
||||
@ -223,7 +232,7 @@ class Resampler(nn.Module):
|
||||
]
|
||||
self._set_2d_pos_cache(self.max_size, device)
|
||||
|
||||
def _init_weights(self, m):
|
||||
def _init_weights(self, m: nn.Module):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
@ -232,7 +241,9 @@ class Resampler(nn.Module):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def forward_2_5(self, x, tgt_sizes=None):
|
||||
def forward_2_5(self,
|
||||
x: torch.Tensor,
|
||||
tgt_sizes: Optional[torch.Tensor] = None):
|
||||
assert x.shape[0] == tgt_sizes.shape[0]
|
||||
bs = x.shape[0]
|
||||
|
||||
@ -278,7 +289,10 @@ class Resampler(nn.Module):
|
||||
x = x @ self.proj
|
||||
return x
|
||||
|
||||
def forward_2(self, x, tgt_sizes=None, attn_mask=None):
|
||||
def forward_2(self,
|
||||
x: torch.Tensor,
|
||||
tgt_sizes: Optional[torch.Tensor] = None,
|
||||
attn_mask: Optional[torch.Tensor] = None):
|
||||
if self.adaptive:
|
||||
pos_embed = torch.Tensor(
|
||||
get_2d_sincos_pos_embed(self.embed_dim,
|
||||
@ -302,8 +316,11 @@ class Resampler(nn.Module):
|
||||
x = x @ self.proj
|
||||
return x
|
||||
|
||||
def forward(self, x, tgt_sizes=None, attn_mask=None):
|
||||
if self.version == 2.0:
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
tgt_sizes: Optional[torch.Tensor] = None,
|
||||
attn_mask: Optional[torch.Tensor] = None):
|
||||
if self.version == (2, 0):
|
||||
return self.forward_2(x, tgt_sizes=tgt_sizes, attn_mask=attn_mask)
|
||||
else:
|
||||
return self.forward_2_5(x, tgt_sizes=tgt_sizes)
|
||||
@ -322,7 +339,7 @@ def dummy_seq_data_for_minicpmv(seq_len: int):
|
||||
return SequenceData(token_ids)
|
||||
|
||||
|
||||
def dummy_image_for_minicpmv(hf_config):
|
||||
def dummy_image_for_minicpmv(hf_config: PretrainedConfig):
|
||||
width = height = hf_config.image_size
|
||||
image = Image.new("RGB", (width, height), color=0)
|
||||
return {"image": image}
|
||||
@ -381,7 +398,7 @@ class MiniCPMV(nn.Module, SupportsVision):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
config: PretrainedConfig,
|
||||
multimodal_config: MultiModalConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
@ -390,30 +407,48 @@ class MiniCPMV(nn.Module, SupportsVision):
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.version = float(self.config.version)
|
||||
if not hasattr(self.config, "version"):
|
||||
if self.config.hidden_size == 2304 and self.config.query_num == 64:
|
||||
self.version = (2, 0)
|
||||
else:
|
||||
self.version = (2, 5)
|
||||
else:
|
||||
self.version = str(self.config.version).split(".")
|
||||
self.version = tuple([int(x) for x in self.version])
|
||||
self.llm = self.init_llm(config, cache_config, quant_config)
|
||||
self.vpm = self.init_vision_module()
|
||||
param_dtype = torch.get_default_dtype()
|
||||
self.vpm.to(dtype=param_dtype)
|
||||
self.vision_dim = self.vpm.embed_dim if self.version == 2.0 \
|
||||
self.vision_dim = self.vpm.embed_dim if self.version == (2, 0) \
|
||||
else self.vpm.embeddings.embed_dim
|
||||
self.embed_dim = self.llm.config.hidden_size
|
||||
self.embed_dim = self.config.hidden_size
|
||||
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
|
||||
self.resampler.to(device="cuda", dtype=param_dtype)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def init_llm(self, config, cache_config, quant_config):
|
||||
if self.version == 2.0:
|
||||
return MiniCPMForCausalLM(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
def init_llm(self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
if self.version == (2, 0):
|
||||
return MiniCPMModel(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
elif self.version == (2, 5):
|
||||
return LlamaModel(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
return LlamaForCausalLM(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
return Qwen2Model(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
def init_vision_module(self):
|
||||
if self.version == 2.0:
|
||||
if self.version == (2, 0):
|
||||
try:
|
||||
import timm
|
||||
except ImportError:
|
||||
@ -433,16 +468,30 @@ class MiniCPMV(nn.Module, SupportsVision):
|
||||
|
||||
if self.config.drop_vision_last_layer:
|
||||
model.blocks = model.blocks[:-1]
|
||||
else:
|
||||
elif self.version == (2, 5):
|
||||
from transformers.models.idefics2.modeling_idefics2 import (
|
||||
Idefics2VisionTransformer)
|
||||
model = Idefics2VisionTransformer(self.config.vision_config)
|
||||
if self.config.drop_vision_last_layer:
|
||||
model.encoder.layers = model.encoder.layers[:-1]
|
||||
else:
|
||||
from vllm.model_executor.models.na_vit import (
|
||||
SiglipVisionTransformer)
|
||||
if self.config._attn_implementation == 'flash_attention_2':
|
||||
self.config.vision_config._attn_implementation \
|
||||
= 'flash_attention_2'
|
||||
else:
|
||||
# not support sdpa
|
||||
self.config.vision_config._attn_implementation = 'eager'
|
||||
model = SiglipVisionTransformer(self.config.vision_config)
|
||||
if self.config.drop_vision_last_layer:
|
||||
model.encoder.layers = model.encoder.layers[:-1]
|
||||
return model
|
||||
|
||||
def init_resampler(self, embed_dim, vision_dim):
|
||||
def init_resampler(self, embed_dim: int, vision_dim: int):
|
||||
default_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(torch.float16)
|
||||
if self.version == 2.0:
|
||||
if self.version == (2, 0):
|
||||
resampler = Resampler(grid_size=int(
|
||||
math.sqrt(self.config.query_num)),
|
||||
num_queries=None,
|
||||
@ -463,11 +512,11 @@ class MiniCPMV(nn.Module, SupportsVision):
|
||||
return resampler
|
||||
|
||||
def get_vision_embedding(self,
|
||||
pixel_values,
|
||||
patch_attn_mask=None,
|
||||
tgt_sizes=None,
|
||||
version=2.0):
|
||||
if version == 2.0:
|
||||
pixel_values: List[List[torch.Tensor]],
|
||||
patch_attn_mask: Optional[torch.Tensor] = None,
|
||||
tgt_sizes: Optional[torch.Tensor] = None,
|
||||
version: Tuple[int, int] = (2, 0)):
|
||||
if version == (2, 0):
|
||||
res = []
|
||||
dtype = self.vpm.pos_embed.data.dtype
|
||||
for pixel_value in pixel_values:
|
||||
@ -484,21 +533,32 @@ class MiniCPMV(nn.Module, SupportsVision):
|
||||
num_prefix_tokens:]
|
||||
res.append(self.resampler(vision_embedding, tgt_size))
|
||||
return torch.vstack(res)
|
||||
else:
|
||||
elif version == (2, 5):
|
||||
vision_embedding = self.vpm(
|
||||
pixel_values.type(dtype),
|
||||
patch_attention_mask=patch_attn_mask).last_hidden_state
|
||||
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
|
||||
else:
|
||||
vision_embedding = self.vpm(pixel_values.type(dtype),
|
||||
patch_attention_mask=patch_attn_mask,
|
||||
tgt_sizes=tgt_sizes).last_hidden_state
|
||||
|
||||
def get_image_bounds(self, input_ids):
|
||||
def get_image_bounds(self, input_ids: torch.Tensor):
|
||||
tokenizer = cached_get_tokenizer(self.config._name_or_path,
|
||||
trust_remote_code=True)
|
||||
im_start_token_id = tokenizer.im_start_id
|
||||
im_end_token_id = tokenizer.im_end_id
|
||||
image_start_tokens = torch.where(input_ids == im_start_token_id)[0]
|
||||
if not hasattr(tokenizer, "slice_start_id"):
|
||||
start_cond = input_ids == tokenizer.im_start_id
|
||||
end_cond = input_ids == tokenizer.im_end_id
|
||||
else:
|
||||
start_cond = (input_ids == tokenizer.im_start_id) | (
|
||||
input_ids == tokenizer.slice_start_id)
|
||||
end_cond = (input_ids == tokenizer.im_end_id) | (
|
||||
input_ids == tokenizer.slice_end_id)
|
||||
|
||||
image_start_tokens = torch.where(start_cond)[0]
|
||||
image_start_tokens += 1
|
||||
image_end_tokens = torch.where(input_ids == im_end_token_id)[0]
|
||||
valid_image_nums = min(len(image_start_tokens), len(image_end_tokens))
|
||||
image_end_tokens = torch.where(end_cond)[0]
|
||||
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
|
||||
if valid_image_nums == 0:
|
||||
return []
|
||||
image_bound = torch.hstack([
|
||||
@ -508,12 +568,14 @@ class MiniCPMV(nn.Module, SupportsVision):
|
||||
|
||||
return image_bound
|
||||
|
||||
def get_vision_hidden_states(self, data):
|
||||
def get_vision_hidden_states(self, data: Dict[str,
|
||||
Union[List[torch.Tensor],
|
||||
torch.Tensor]]):
|
||||
if "vision_hidden_states" not in data:
|
||||
pixel_values = data["pixel_values"]
|
||||
tgt_sizes = data["tgt_sizes"]
|
||||
vision_hidden_states = []
|
||||
if self.version == 2.0:
|
||||
if self.version == (2, 0):
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
vision_hidden_states = self.get_vision_embedding(
|
||||
pixel_values)
|
||||
@ -534,17 +596,26 @@ class MiniCPMV(nn.Module, SupportsVision):
|
||||
B, L, _ = all_pixel_values.shape
|
||||
all_pixel_values = all_pixel_values.permute(
|
||||
0, 2, 1).reshape(B, 3, -1, L)
|
||||
|
||||
patch_attn_mask = torch.zeros((B, 1, max_patches),
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
for i in range(B):
|
||||
patch_attn_mask[i, :tgt_sizes[i][0] *
|
||||
tgt_sizes[i][1]] = True
|
||||
if self.version == (2, 5):
|
||||
for i in range(B):
|
||||
patch_attn_mask[i, :tgt_sizes[i][0] *
|
||||
tgt_sizes[i][1]] = True
|
||||
vision_embedding = self.vpm(
|
||||
all_pixel_values.type(dtype),
|
||||
patch_attention_mask=patch_attn_mask
|
||||
).last_hidden_state
|
||||
else:
|
||||
for i in range(B):
|
||||
patch_attn_mask[i, 0, :tgt_sizes[i][0] *
|
||||
tgt_sizes[i][1]] = True
|
||||
vision_embedding = self.vpm(
|
||||
all_pixel_values.type(dtype),
|
||||
patch_attention_mask=patch_attn_mask,
|
||||
tgt_sizes=tgt_sizes).last_hidden_state
|
||||
|
||||
vision_embedding = self.vpm(
|
||||
all_pixel_values.type(dtype),
|
||||
patch_attention_mask=patch_attn_mask).last_hidden_state
|
||||
vision_hidden_states = self.resampler(
|
||||
vision_embedding, tgt_sizes)
|
||||
|
||||
@ -556,7 +627,8 @@ class MiniCPMV(nn.Module, SupportsVision):
|
||||
|
||||
return vision_hidden_states
|
||||
|
||||
def get_embedding(self, data):
|
||||
def get_embedding(self, data: Dict[str, Union[List[torch.Tensor],
|
||||
torch.Tensor]]):
|
||||
input_ids = data["input_ids"]
|
||||
|
||||
vision_hidden_states = self.get_vision_hidden_states(data)
|
||||
@ -565,11 +637,11 @@ class MiniCPMV(nn.Module, SupportsVision):
|
||||
else:
|
||||
image_bounds = []
|
||||
|
||||
if hasattr(self.llm.config, 'scale_emb'):
|
||||
vlm_embedding = self.llm.model.embed_tokens(
|
||||
input_ids) * self.llm.config.scale_emb
|
||||
if hasattr(self.config, 'scale_emb'):
|
||||
vlm_embedding = self.llm.embed_tokens(
|
||||
input_ids) * self.config.scale_emb
|
||||
else:
|
||||
vlm_embedding = self.llm.model.embed_tokens(input_ids)
|
||||
vlm_embedding = self.llm.embed_tokens(input_ids)
|
||||
vision_hidden_states = [
|
||||
i.type(vlm_embedding.dtype) if isinstance(i, torch.Tensor) else i
|
||||
for i in vision_hidden_states
|
||||
@ -587,7 +659,9 @@ class MiniCPMV(nn.Module, SupportsVision):
|
||||
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]))
|
||||
return vlm_embedding, vision_hidden_states
|
||||
|
||||
def process_multimodal_inputs(self, inputs):
|
||||
def process_multimodal_inputs(self, inputs: Dict[str,
|
||||
Union[List[torch.Tensor],
|
||||
torch.Tensor]]):
|
||||
pixel_values = []
|
||||
tgt_sizes = []
|
||||
for b in range(len(inputs["pixel_values"])):
|
||||
@ -613,7 +687,6 @@ class MiniCPMV(nn.Module, SupportsVision):
|
||||
"input_ids": input_ids,
|
||||
"tgt_sizes": kwargs.pop("tgt_sizes", None),
|
||||
}
|
||||
|
||||
inputs = self.process_multimodal_inputs(inputs)
|
||||
|
||||
vlm_embeddings, vision_hidden_states = self.get_embedding(inputs)
|
||||
@ -623,19 +696,21 @@ class MiniCPMV(nn.Module, SupportsVision):
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
input_embeds=vlm_embeddings)
|
||||
inputs_embeds=vlm_embeddings)
|
||||
return output
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
return self.llm.compute_logits(hidden_states, sampling_metadata)
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.llm.sample(logits, sampling_metadata)
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
@ -649,9 +724,9 @@ class MiniCPMV(nn.Module, SupportsVision):
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
# for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
|
||||
# if key_to_modify in name:
|
||||
# name = name.replace(key_to_modify, new_key)
|
||||
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in name:
|
||||
name = name.replace(key_to_modify, new_key)
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
|
||||
804
vllm/model_executor/models/na_vit.py
Normal file
804
vllm/model_executor/models/na_vit.py
Normal file
@ -0,0 +1,804 @@
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn.init import _calculate_fan_in_and_fan_out
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from transformers.modeling_outputs import (BaseModelOutput,
|
||||
BaseModelOutputWithPooling)
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import (ModelOutput, is_flash_attn_2_available,
|
||||
replace_return_docstrings)
|
||||
|
||||
logger = logging.getLogger("vllm")
|
||||
|
||||
|
||||
# For Siglip: copied from
|
||||
# HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes
|
||||
# Remove hints as there's little possibility to change these code.
|
||||
class SiglipVisionConfig(PretrainedConfig):
|
||||
|
||||
model_type = "siglip_vision_model"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=768,
|
||||
intermediate_size=3072,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
num_channels=3,
|
||||
image_size=224,
|
||||
patch_size=16,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
layer_norm_eps=1e-6,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_channels = num_channels
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
self.attention_dropout = attention_dropout
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.hidden_act = hidden_act
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str,
|
||||
os.PathLike],
|
||||
**kwargs) -> "PretrainedConfig":
|
||||
cls._set_token_in_kwargs(kwargs)
|
||||
|
||||
config_dict, kwargs = cls.get_config_dict(
|
||||
pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# get the vision config dict if we are loading from SiglipConfig
|
||||
if config_dict.get("model_type") == "siglip":
|
||||
config_dict = config_dict["vision_config"]
|
||||
|
||||
if "model_type" in config_dict and hasattr(
|
||||
cls,
|
||||
"model_type") and config_dict["model_type"] != cls.model_type:
|
||||
logger.warning(
|
||||
"You are using a model of type %s to "
|
||||
"instantiate a model of type %s. "
|
||||
"This is not supported for all configurations"
|
||||
"of models and can yield errors.", config_dict['model_type'],
|
||||
cls.model_type)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
|
||||
|
||||
SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"google/siglip-base-patch16-224",
|
||||
# See all SigLIP models at https://huggingface.co/models?filter=siglip
|
||||
]
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import pad_input # noqa
|
||||
from flash_attn.bert_padding import index_first_axis, unpad_input
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
||||
def _get_unpad_data(attention_mask):
|
||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(
|
||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
||||
return (
|
||||
indices,
|
||||
cu_seqlens,
|
||||
max_seqlen_in_batch,
|
||||
)
|
||||
|
||||
|
||||
def _trunc_normal_(tensor, mean, std, a, b):
|
||||
|
||||
def norm_cdf(x):
|
||||
# Computes standard normal cumulative distribution function
|
||||
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
||||
|
||||
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
||||
warnings.warn(
|
||||
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
||||
"The distribution of values may be incorrect.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Values are generated by using a truncated uniform distribution and
|
||||
# then using the inverse CDF for the normal distribution.
|
||||
# Get upper and lower cdf values
|
||||
l_ = norm_cdf((a - mean) / std)
|
||||
u = norm_cdf((b - mean) / std)
|
||||
|
||||
# Uniformly fill tensor with values from [l, u], then translate to
|
||||
# [2l-1, 2u-1].
|
||||
tensor.uniform_(2 * l_ - 1, 2 * u - 1)
|
||||
|
||||
# Use inverse cdf transform for normal distribution to get truncated
|
||||
# standard normal
|
||||
if tensor.dtype in [torch.float16, torch.bfloat16]:
|
||||
# The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
|
||||
og_dtype = tensor.dtype
|
||||
tensor = tensor.to(torch.float32)
|
||||
tensor.erfinv_()
|
||||
tensor = tensor.to(og_dtype)
|
||||
else:
|
||||
tensor.erfinv_()
|
||||
|
||||
# Transform to proper mean, std
|
||||
tensor.mul_(std * math.sqrt(2.0))
|
||||
tensor.add_(mean)
|
||||
|
||||
# Clamp to ensure it's in the proper range
|
||||
if tensor.dtype == torch.float16:
|
||||
# The `clamp_` op is not (yet?) defined in float16+cpu
|
||||
tensor = tensor.to(torch.float32)
|
||||
tensor.clamp_(min=a, max=b)
|
||||
tensor = tensor.to(torch.float16)
|
||||
else:
|
||||
tensor.clamp_(min=a, max=b)
|
||||
|
||||
|
||||
def trunc_normal_tf_(tensor: torch.Tensor,
|
||||
mean: float = 0.0,
|
||||
std: float = 1.0,
|
||||
a: float = -2.0,
|
||||
b: float = 2.0) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
_trunc_normal_(tensor, 0, 1.0, a, b)
|
||||
tensor.mul_(std).add_(mean)
|
||||
|
||||
|
||||
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
|
||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
||||
if mode == "fan_in":
|
||||
denom = fan_in
|
||||
elif mode == "fan_out":
|
||||
denom = fan_out
|
||||
elif mode == "fan_avg":
|
||||
denom = (fan_in + fan_out) / 2
|
||||
|
||||
variance = scale / denom
|
||||
|
||||
if distribution == "truncated_normal":
|
||||
# constant is stddev of standard normal truncated to (-2, 2)
|
||||
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
|
||||
elif distribution == "normal":
|
||||
with torch.no_grad():
|
||||
tensor.normal_(std=math.sqrt(variance))
|
||||
elif distribution == "uniform":
|
||||
bound = math.sqrt(3 * variance)
|
||||
with torch.no_grad():
|
||||
tensor.uniform_(-bound, bound)
|
||||
else:
|
||||
raise ValueError(f"invalid distribution {distribution}")
|
||||
|
||||
|
||||
def lecun_normal_(tensor):
|
||||
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
|
||||
|
||||
|
||||
def default_flax_embed_init(tensor):
|
||||
variance_scaling_(tensor, mode="fan_in", distribution="normal")
|
||||
|
||||
|
||||
class SiglipVisionModelOutput(ModelOutput):
|
||||
image_embeds: Optional[torch.FloatTensor] = None
|
||||
last_hidden_state: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
class SiglipVisionEmbeddings(nn.Module):
|
||||
|
||||
def __init__(self, config: SiglipVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
padding="valid",
|
||||
)
|
||||
|
||||
self.num_patches_per_side = self.image_size // self.patch_size
|
||||
self.num_patches = self.num_patches_per_side**2
|
||||
self.num_positions = self.num_patches
|
||||
self.position_embedding = nn.Embedding(self.num_positions,
|
||||
self.embed_dim)
|
||||
|
||||
def forward(self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
patch_attention_mask: torch.BoolTensor,
|
||||
tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
|
||||
batch_size = pixel_values.size(0)
|
||||
|
||||
patch_embeds = self.patch_embedding(pixel_values)
|
||||
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
|
||||
max_nb_patches_h, max_nb_patches_w = (max_im_h // self.patch_size,
|
||||
max_im_w // self.patch_size)
|
||||
boundaries = torch.arange(1 / self.num_patches_per_side, 1.0,
|
||||
1 / self.num_patches_per_side)
|
||||
position_ids = torch.full(
|
||||
size=(
|
||||
batch_size,
|
||||
max_nb_patches_h * max_nb_patches_w,
|
||||
),
|
||||
fill_value=0,
|
||||
)
|
||||
|
||||
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
|
||||
if tgt_sizes is not None:
|
||||
nb_patches_h = tgt_sizes[batch_idx][0]
|
||||
nb_patches_w = tgt_sizes[batch_idx][1]
|
||||
else:
|
||||
nb_patches_h = p_attn_mask[:, 0].sum()
|
||||
nb_patches_w = p_attn_mask[0].sum()
|
||||
|
||||
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
|
||||
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
|
||||
|
||||
bucket_coords_h = torch.bucketize(fractional_coords_h,
|
||||
boundaries,
|
||||
right=True)
|
||||
bucket_coords_w = torch.bucketize(fractional_coords_w,
|
||||
boundaries,
|
||||
right=True)
|
||||
|
||||
pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side +
|
||||
bucket_coords_w).flatten()
|
||||
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
|
||||
|
||||
position_ids = position_ids.to(self.position_embedding.weight.device)
|
||||
|
||||
embeddings = embeddings + self.position_embedding(position_ids)
|
||||
return embeddings
|
||||
|
||||
|
||||
class SiglipAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
"embed_dim must be divisible by num_heads (got `embed_dim`: "
|
||||
f"{self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads}).")
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
batch_size, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(batch_size, q_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(batch_size, q_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(batch_size, q_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
|
||||
k_v_seq_len = key_states.shape[-2]
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(
|
||||
2, 3)) * self.scale
|
||||
|
||||
if attn_weights.size() != (batch_size, self.num_heads, q_len,
|
||||
k_v_seq_len):
|
||||
raise ValueError(
|
||||
"Attention weights should be of size "
|
||||
f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
|
||||
f" {attn_weights.size()}")
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
|
||||
raise ValueError(
|
||||
"Attention mask should be of size "
|
||||
f"{(batch_size, 1, q_len, k_v_seq_len)}",
|
||||
f"but is {attention_mask.size()}")
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights,
|
||||
dim=-1,
|
||||
dtype=torch.float32).to(
|
||||
query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights,
|
||||
p=self.dropout,
|
||||
training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (batch_size, self.num_heads, q_len,
|
||||
self.head_dim):
|
||||
raise ValueError(
|
||||
"`attn_output` should be of size "
|
||||
f"{(batch_size, self.num_heads, q_len, self.head_dim)}, "
|
||||
"but is"
|
||||
f" {attn_output.size()}")
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class SiglipFlashAttention2(SiglipAttention):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.is_causal = False # Hack to make sure we don't use a causal mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[Tuple[torch.Tensor]]]:
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(
|
||||
kv_seq_len, self.layer_idx)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
dropout_rate = self.dropout if self.training else 0.0
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning(
|
||||
"The input hidden states seems to be "
|
||||
"silently casted in float32, "
|
||||
"this might be related to the fact "
|
||||
"you have upcasted embedding or layer norm layers in float32. "
|
||||
"We will cast back the input in"
|
||||
" %s.", target_dtype)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = self._flash_attention_forward(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len,
|
||||
self.embed_dim).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
def _flash_attention_forward(self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
query_length,
|
||||
dropout=0.0,
|
||||
softmax_scale=None):
|
||||
causal = self.is_causal and query_length != 1
|
||||
|
||||
# Contains at least one padding token in the sequence
|
||||
if attention_mask is not None:
|
||||
batch_size = query_states.shape[0]
|
||||
(query_states, key_states, value_states, indices_q, cu_seq_lens,
|
||||
max_seq_lens) = self._upad_input(query_states, key_states,
|
||||
value_states, attention_mask,
|
||||
query_length)
|
||||
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||
|
||||
attn_output_unpad = flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_in_batch_q,
|
||||
max_seqlen_k=max_seqlen_in_batch_k,
|
||||
dropout_p=dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
|
||||
query_length)
|
||||
else:
|
||||
attn_output = flash_attn_func(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal)
|
||||
|
||||
return attn_output
|
||||
|
||||
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
|
||||
query_length):
|
||||
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
|
||||
attention_mask)
|
||||
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
||||
|
||||
key_layer = index_first_axis(
|
||||
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
|
||||
head_dim), indices_k)
|
||||
value_layer = index_first_axis(
|
||||
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
|
||||
head_dim), indices_k)
|
||||
if query_length == kv_seq_len:
|
||||
query_layer = index_first_axis(
|
||||
query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
|
||||
head_dim), indices_k)
|
||||
cu_seqlens_q = cu_seqlens_k
|
||||
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
||||
indices_q = indices_k
|
||||
elif query_length == 1:
|
||||
max_seqlen_in_batch_q = 1
|
||||
cu_seqlens_q = torch.arange(
|
||||
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
||||
) # There is a memcpy here, that is very bad.
|
||||
indices_q = cu_seqlens_q[:-1]
|
||||
query_layer = query_layer.squeeze(1)
|
||||
else:
|
||||
# The -q_len: slice assumes left padding.
|
||||
attention_mask = attention_mask[:, -query_length:]
|
||||
(query_layer, indices_q, cu_seqlens_q,
|
||||
max_seqlen_in_batch_q) = unpad_input(query_layer, attention_mask)
|
||||
|
||||
return (
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
indices_q,
|
||||
(cu_seqlens_q, cu_seqlens_k),
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
|
||||
class SiglipMLP(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer
|
||||
# with CLIP->Siglip
|
||||
class SiglipEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: SiglipVisionConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self._use_flash_attention_2 = (
|
||||
config._attn_implementation == "flash_attention_2")
|
||||
self.self_attn = (SiglipAttention(config)
|
||||
if not self._use_flash_attention_2 else
|
||||
SiglipFlashAttention2(config))
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = SiglipMLP(config)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states, )
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attn_weights, )
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class SiglipPreTrainedModel(PreTrainedModel):
|
||||
config_class = SiglipVisionConfig
|
||||
base_model_prefix = "siglip"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
if isinstance(module, SiglipVisionEmbeddings):
|
||||
width = self.config.hidden_size
|
||||
nn.init.normal_(module.position_embedding.weight,
|
||||
std=1 / np.sqrt(width))
|
||||
elif isinstance(module, nn.Embedding):
|
||||
default_flax_embed_init(module.weight)
|
||||
elif isinstance(module, SiglipAttention):
|
||||
nn.init.normal_(module.q_proj.weight)
|
||||
nn.init.normal_(module.k_proj.weight)
|
||||
nn.init.normal_(module.v_proj.weight)
|
||||
nn.init.normal_(module.out_proj.weight)
|
||||
nn.init.zeros_(module.q_proj.bias)
|
||||
nn.init.zeros_(module.k_proj.bias)
|
||||
nn.init.zeros_(module.v_proj.bias)
|
||||
nn.init.zeros_(module.out_proj.bias)
|
||||
elif isinstance(module, SiglipMLP):
|
||||
nn.init.normal_(module.fc1.weight)
|
||||
nn.init.normal_(module.fc2.weight)
|
||||
nn.init.normal_(module.fc1.bias, std=1e-6)
|
||||
nn.init.normal_(module.fc2.bias, std=1e-6)
|
||||
elif isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
lecun_normal_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder
|
||||
# with CLIP->Siglip
|
||||
class SiglipEncoder(nn.Module):
|
||||
|
||||
def __init__(self, config: SiglipVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList([
|
||||
SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Ignore copy
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutput]:
|
||||
output_attentions = output_attentions if output_attentions is not None \
|
||||
else self.config.output_attentions
|
||||
output_hidden_states = (output_hidden_states
|
||||
if output_hidden_states is not None else
|
||||
self.config.output_hidden_states)
|
||||
return_dict = return_dict if return_dict is not None \
|
||||
else self.config.use_return_dict
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states, )
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1], )
|
||||
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states, )
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v for v in [hidden_states, encoder_states, all_attentions]
|
||||
if v is not None)
|
||||
return BaseModelOutput(last_hidden_state=hidden_states,
|
||||
hidden_states=encoder_states,
|
||||
attentions=all_attentions)
|
||||
|
||||
|
||||
class SiglipVisionTransformer(SiglipPreTrainedModel):
|
||||
config_class = SiglipVisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def __init__(self, config: SiglipVisionConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.embeddings = SiglipVisionEmbeddings(config)
|
||||
self.encoder = SiglipEncoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self._use_flash_attention_2 = (
|
||||
config._attn_implementation == "flash_attention_2")
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.embeddings.patch_embedding
|
||||
|
||||
@replace_return_docstrings(output_type=BaseModelOutputWithPooling,
|
||||
config_class=SiglipVisionConfig)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
patch_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
tgt_sizes: Optional[torch.IntTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None \
|
||||
else self.config.output_attentions
|
||||
output_hidden_states = (output_hidden_states
|
||||
if output_hidden_states is not None else
|
||||
self.config.output_hidden_states)
|
||||
return_dict = return_dict if return_dict is not None \
|
||||
else self.config.use_return_dict
|
||||
|
||||
batch_size = pixel_values.size(0)
|
||||
if patch_attention_mask is None:
|
||||
patch_attention_mask = torch.ones(
|
||||
size=(
|
||||
batch_size,
|
||||
pixel_values.size(2) // self.config.patch_size,
|
||||
pixel_values.size(3) // self.config.patch_size,
|
||||
),
|
||||
dtype=torch.bool,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
pixel_values=pixel_values,
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
tgt_sizes=tgt_sizes)
|
||||
|
||||
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
|
||||
# The call to `_upad_input` in `_flash_attention_forward` is expensive
|
||||
# So when the `patch_attention_mask` is full of 1s
|
||||
# (i.e. attending to the whole sequence),
|
||||
# avoiding passing the attention_mask,
|
||||
# which is equivalent to attending to the full sequence
|
||||
if not torch.any(~patch_attention_mask):
|
||||
attention_mask = None
|
||||
else:
|
||||
attention_mask = (_prepare_4d_attention_mask(
|
||||
patch_attention_mask, hidden_states.dtype)
|
||||
if not self._use_flash_attention_2 else
|
||||
patch_attention_mask)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, None) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=None,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
@ -342,7 +342,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
attn_metadata, intermediate_tensors)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user