# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Literal, Optional, TypeVar, Union

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from transformers import PretrainedConfig
from transformers.activations import GELUActivation
from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import (BaseModelOutput,
                                           BaseModelOutputWithPooling)
from transformers.utils import torch_int

from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQMarlinConfig)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                    MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargsItems, VideoItem)
from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
                                   ModalityDataItems, MultiModalDataItems,
                                   MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP)
from .siglip import SiglipMLP
from .utils import (AutoWeightsLoader, WeightsMapper,
                    init_vllm_registered_model, is_pp_missing_parameter,
                    maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vit_attn_backend

logger = init_logger(__name__)


def smart_resize(
    height: int,
    width: int,
    factor: int,
    min_pixels: int,
    max_pixels: int,
):
    if height < factor:
        logger.warning(
            "smart_resize: height=%s < factor=%s, reset height=factor",
            height,
            factor,
        )
        width = round((width * factor) / height)
        height = factor

    if width < factor:
        logger.warning(
            "smart_resize: width=%s < factor=%s, reset width=factor",
            width,
            factor,
        )
        height = round((height * factor) / width)
        width = factor

    if max(height, width) / min(height, width) > 200:
        raise ValueError("absolute aspect ratio must be smaller than 200, got "
                         f"{max(height, width) / min(height, width)}")
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar
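

# Worked example with illustrative values (not read from any model config):
# with factor=28, min_pixels=3136 and max_pixels=1003520, a 1036x2030 image
# first rounds to 1036x2016 (2,088,576 px), which exceeds max_pixels, so
# beta ~= 1.45 and the final size is 700x1400 -- both multiples of 28 and
# within the pixel budget.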


class KeyeImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size
        - np: Number of patches
        - c: Number of channels
        - ps: Patch size
        - ni: Number of images
        - g: Grid dimensions (3 for t, h, w)
    """
    type: Literal["pixel_values"]
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


class KeyeImageEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - nf: Number of image features
        - hs: Hidden size (must match the hidden size of language model
              backbone)
        - ni: Number of images
        - g: Grid dimensions (3 for t, h, w)
    """
    type: Literal["image_embeds"]
    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs]


class KeyeVideoPixelInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size
        - np: Number of patches
        - c: Number of channels
        - ps: Patch size
        - nv: Number of videos
        - g: Grid dimensions (3 for t, h, w)
    """
    type: Literal["pixel_values_videos"]
    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]


class KeyeVideoEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - nf: Number of video features
        - hs: Hidden size (must match the hidden size of language model
              backbone)
        - nv: Number of videos
        - g: Grid dimensions (3 for t, h, w)
    """
    type: Literal["video_embeds"]
    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]


KeyeVideoInputs = Union[KeyeVideoPixelInputs, KeyeVideoEmbeddingInputs]


class KeyeVisionEmbeddings(nn.Module):

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches
        self.cache_position_embedding = dict()
        self.cache_position_count = dict()
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)

        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(
        self,
        embeddings: torch.Tensor,
        height: int,
        width: int,
        is_after_patchify: bool = False,
    ) -> torch.Tensor:

        num_positions = self.position_embedding.weight.shape[0]

        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)

        dim = embeddings.shape[-1]

        if is_after_patchify:
            new_height = height
            new_width = width
        else:
            new_height = height // self.patch_size
            new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions,
                                                  sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bilinear",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed
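
    # Note: the learned position table covers a fixed
    # (image_size // patch_size) ** 2 square grid; the bilinear interpolation
    # above resamples that grid to the actual (new_height, new_width) patch
    # grid, so arbitrary input resolutions reuse the same embedding table.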

    def fetch_position_embedding_lfu_cache(self,
                                           embeddings,
                                           h,
                                           w,
                                           max_cache: int = 20):
        grid = (h, w)
        if grid in self.cache_position_embedding:
            self.cache_position_count[grid] += 1
            return self.cache_position_embedding[grid]

        if len(self.cache_position_embedding) >= max_cache:
            min_hit_grid = min(
                self.cache_position_count,
                key=self.cache_position_count.get,
            )
            self.cache_position_count.pop(min_hit_grid)
            self.cache_position_embedding.pop(min_hit_grid)

        position_embedding = self.interpolate_pos_encoding(
            embeddings, h, w, True)
        self.cache_position_count[grid] = 1
        self.cache_position_embedding[grid] = position_embedding
        return position_embedding

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        position_ids: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[list[Union[
            tuple[int, int, int],
            list[tuple[int, int, int]],
        ]]] = None,
        interpolate_pos_encoding=False,
    ) -> torch.Tensor:
        if pixel_values.dim() == 4:
            pixel_values = pixel_values.unsqueeze(0)
        if pixel_values.dim() == 5:
            if position_ids is None:
                raise ValueError(
                    "position_ids cannot be None when pixel_values.dim() is 5."
                )
            (
                batch_size,
                sequence_len,
                channel,
                height,
                width,
            ) = pixel_values.shape
            target_dtype = self.patch_embedding.weight.dtype
            pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
            patch_embeds = self.patch_embedding(
                pixel_values.to(dtype=target_dtype))
            embeddings = patch_embeds.flatten(-2).squeeze(-1)

            if interpolate_pos_encoding and image_grid_thw is not None:
                start = 0
                tmp_embeddings = list()
                for image_grid in image_grid_thw:
                    t, h, w = image_grid
                    end = start + t * h * w
                    image_embeddings = embeddings[start:end, :]
                    position_embedding = (self.interpolate_pos_encoding(
                        image_embeddings, h, w, True).squeeze(0).repeat(t, 1))
                    image_embeddings = image_embeddings + position_embedding
                    tmp_embeddings.append(image_embeddings)
                    start = end
                embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0)
            else:
                embeddings = embeddings + self.packing_position_embedding(
                    position_ids)
            return embeddings
        else:
            raise ValueError("Unsupported pixel_values dimension:"
                             f" {pixel_values.dim()}. Expected 4 or 5.")


def apply_rotary_pos_emb_flashatt(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    cos = cos.chunk(2, dim=-1)[0].contiguous()
    sin = sin.chunk(2, dim=-1)[0].contiguous()

    from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

    q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
    k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
    return q_embed, k_embed
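

# The encoder duplicates the (h, w) rotary frequencies via repeat(1, 2), so
# only the first half of cos/sin is kept here: the flash-attn rotary helper
# works with half-width cos/sin tables and the rotation is applied in fp32
# before casting back to the activation dtype.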


class KeyeSiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        hidden_size = config.hidden_size
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_attention_heads
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scale = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.head_dim, dtype=torch.get_default_dtype())

        self.use_upstream_fa = False
        if self.attn_backend != _Backend.FLASH_ATTN and \
            check_upstream_fa_availability(
                torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True

        if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
            raise RuntimeError(
                f"Keye-VL does not support {self.attn_backend} backend now.")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split(
            [self.q_size, self.kv_size, self.kv_size],
            dim=-1,
        )

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        batch_size = q.shape[0]

        if rope_emb is None:
            q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
            k = k.view(
                *k.shape[:-1],
                self.num_kv_heads,
                self.head_dim,
            )
            v = v.view(
                *v.shape[:-1],
                self.num_kv_heads,
                self.head_dim,
            )
        else:
            if cu_seqlens is None:
                raise ValueError(
                    "cu_seqlens cannot be None when rope_emb is not None.")
            cos, sin = rope_emb
            q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
            k = k.view(
                *k.shape[:-1],
                self.num_kv_heads,
                self.head_dim,
            )
            q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin)
            v = v.view(
                *v.shape[:-1],
                self.num_kv_heads,
                self.head_dim,
            )

        if self.attn_backend == _Backend.FLASH_ATTN:
            if self.use_upstream_fa:
                from flash_attn import flash_attn_varlen_func
            else:
                from vllm.vllm_flash_attn import flash_attn_varlen_func

            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=False,
                softmax_scale=self.scale,
            )
            context_layer = rearrange(output,
                                      "(b s) ... -> b s ...",
                                      b=batch_size)
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                       kv_seqlen=None,
                                                       device=q.device)

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None)

        context_layer = rearrange(context_layer,
                                  "b s h d -> b s (h d)").contiguous()

        output, _ = self.out_proj(context_layer)
        return output


class SigLIPRotaryEmbedding(nn.Module):

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.rope_init()

    def rope_init(self):
        inv_freq = 1.0 / (self.theta**(
            torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        seq = torch.arange(
            seqlen,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )
        freqs = torch.outer(seq, self.inv_freq)
        return freqs
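
    # forward() returns a (seqlen, dim // 2) table of angles; the vision
    # encoder indexes it with per-patch (h, w) positions and concatenates the
    # two lookups, so the rotary embedding covers both spatial axes.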


class KeyeSiglipEncoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.self_attn = KeyeSiglipAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> tuple[torch.FloatTensor]:

        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            cu_seqlens=cu_seqlens,
            rope_emb=rope_emb,
        )

        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)

        hidden_states = residual + hidden_states

        return hidden_states


class KeyeSiglipEncoder(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.layers = nn.ModuleList([
            KeyeSiglipEncoderLayer(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.layers.{layer_idx}",
            ) for layer_idx in range(config.num_hidden_layers)
        ])
        self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)

    @staticmethod
    def flatten_list(image_grid_thw):
        tmp_image_grid_thw = list()
        for image_grid in image_grid_thw:
            if isinstance(image_grid, list):
                tmp_image_grid_thw.extend(image_grid)
            else:
                tmp_image_grid_thw.append(image_grid)
        return tmp_image_grid_thw

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        image_grid_thw: Optional[list[Union[
            tuple[int, int, int],
            list[tuple[int, int, int]],
        ]]] = None,
        height_position_ids: Optional[torch.Tensor] = None,
        width_position_ids: Optional[torch.Tensor] = None,
        use_rope: Optional[bool] = False,
        window_size: Optional[int] = -1,
        vision_or_text: str = "vision",
    ) -> BaseModelOutput:
        device = inputs_embeds.device
        hidden_states = inputs_embeds
        if use_rope is True:
            flatten_image_grid_thw = self.flatten_list(image_grid_thw)

            if width_position_ids is None or height_position_ids is None:
                split_hids = list()
                split_wids = list()
                for t, h, w in flatten_image_grid_thw:
                    image_pids = torch.arange(t * h * w,
                                              device=device) % (h * w)
                    sample_hids = image_pids // w
                    sample_wids = image_pids % w
                    split_hids.append(sample_hids)
                    split_wids.append(sample_wids)
                width_position_ids = torch.concat(split_wids, dim=0)
                height_position_ids = torch.concat(split_hids, dim=0)

            pids = torch.stack(
                [height_position_ids, width_position_ids],
                dim=-1,
            )
            max_grid_size = pids.max() + 1
            rope_emb_max_grid = self.rotary_pos_emb(max_grid_size)
            rope_emb = rope_emb_max_grid[pids].flatten(1)
            rope_emb = rope_emb.repeat(1, 2)
            rope_emb = (rope_emb.cos(), rope_emb.sin())
        else:
            rope_emb = None

        attn_cu_seqlens = cu_seqlens
        hidden_states = inputs_embeds
        assert attention_mask is None

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
                output_attentions=output_attentions,
                cu_seqlens=attn_cu_seqlens,
                rope_emb=rope_emb,
            )
        return hidden_states


class KeyeSiglipVisionTransformer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = KeyeVisionEmbeddings(config)
        self.encoder = KeyeSiglipEncoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
        )
        self.post_layernorm = nn.LayerNorm(embed_dim,
                                           eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
        attention_mask: Optional[torch.Tensor] = None,
        sample_indices: Optional[torch.Tensor] = None,
        image_indices: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        height_position_ids: Optional[torch.Tensor] = None,
        width_position_ids: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        vision_return_embed_list: Optional[bool] = False,
        image_grid_thw: Optional[list[Union[
            tuple[int, int, int],
            list[tuple[int, int, int]],
        ]]] = None,
        return_pooler_output: Optional[bool] = True,
        use_rope: Optional[bool] = False,
        window_size: Optional[int] = -1,
    ) -> BaseModelOutputWithPooling:

        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            image_grid_thw=image_grid_thw,
        )

        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens,
            image_grid_thw=image_grid_thw,
            use_rope=use_rope,
            height_position_ids=height_position_ids,
            width_position_ids=width_position_ids,
            window_size=window_size,
            vision_or_text="vision",
        )

        last_hidden_state = self.post_layernorm(last_hidden_state)

        sample_hidden_state = list()
        if cu_seqlens is None:
            raise ValueError("cu_seqlens cannot be None for "
                             "SiglipVisionTransformer output processing.")
        for i in range(cu_seqlens.shape[0] - 1):
            start = cu_seqlens[i]
            end = cu_seqlens[i + 1]
            tensor = last_hidden_state[:, start:end, :].squeeze(0)
            sample_hidden_state.append(tensor)

        return sample_hidden_state
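
    # Despite the BaseModelOutputWithPooling annotation, this returns a plain
    # list with one (num_patches_i, hidden_size) tensor per sample, obtained
    # by splitting the packed sequence at the cu_seqlens boundaries.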


class KeyeSiglipVisionModel(nn.Module):
    config_class = PretrainedConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.vision_model = KeyeSiglipVisionTransformer(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.vision_model",
        )
        self.quant_config = quant_config

    @property
    def dtype(self) -> torch.dtype:
        return self.vision_model.embeddings.patch_embedding.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.vision_model.embeddings.patch_embedding.weight.device

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values,
        sample_indices: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        position_ids: Optional[torch.Tensor] = None,
        vision_return_embed_list: Optional[bool] = False,
        image_grid_thw: Optional[list[Union[
            tuple[int, int, int],
            list[tuple[int, int, int]],
        ]]] = None,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        return_pooler_output: Optional[bool] = True,
        use_rope: Optional[bool] = False,
        window_size: Optional[int] = -1,
    ) -> BaseModelOutputWithPooling:

        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            vision_return_embed_list=vision_return_embed_list,
            image_grid_thw=image_grid_thw,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            return_pooler_output=return_pooler_output,
            use_rope=use_rope,
            window_size=window_size,
        )

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "head.attention" in name or "head.layernorm" in name:
                continue
            if "head.mlp" in name or "head.probe" in name:
                continue
            if self.quant_config is not None and (
                    scale_name := self.quant_config.get_cache_scale(name)):
                param = params_dict[scale_name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for (
                    param_name,
                    weight_name,
                    shard_id,
            ) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if name.endswith(".bias") and name not in params_dict:
                    continue
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Projector(nn.Module):

    def __init__(
        self,
        text_config: PretrainedConfig,
        vision_config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.text_config = text_config
        self.vision_config = vision_config
        self.merge_kernel_size = (2, 2)

        self.hidden_size = (self.vision_config.hidden_size *
                            self.merge_kernel_size[0] *
                            self.merge_kernel_size[1])

        self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size,
                                           eps=1e-05)
        self.act = GELUActivation()

        self.linear_1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.linear_2 = RowParallelLinear(
            self.hidden_size,
            self.text_config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(
        self,
        image_features: Union[torch.Tensor, list[torch.Tensor]],
        image_grid_thw: list[tuple[int, int, int]],
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        m1, m2 = self.merge_kernel_size
        if isinstance(image_features, (list, tuple)):
            processed_features = list()
            for image_feature, image_grid in zip(image_features,
                                                 image_grid_thw):
                image_feature = self.pre_norm(image_feature)
                t, h, w = image_grid

                image_feature = rearrange(
                    image_feature,
                    "(t h p1 w p2) d -> (t h w) (p1 p2 d)",
                    t=t,
                    h=h // m1,
                    p1=m1,
                    w=w // m2,
                    p2=m2,
                )
                hidden_states, _ = self.linear_1(image_feature)
                hidden_states = self.act(hidden_states)
                hidden_states, _ = self.linear_2(hidden_states)
                processed_features.append(hidden_states)

            return processed_features

        dims = image_features.shape[:-1]
        dim = image_features.shape[-1]
        image_features = image_features.view(np.prod(dims), dim)
        hidden_states = self.pre_norm(image_features).view(
            -1, self.hidden_size)
        hidden_states, _ = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)

        return hidden_states.view(*dims, -1)
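
    # Example of the 2x2 spatial merge: a grid of t=1, h=28, w=42 patches
    # gives 1176 vision tokens from the encoder; the rearrange groups them
    # into 1 * 14 * 21 = 294 merged tokens whose feature size is 4x the ViT
    # hidden size, which linear_2 maps to the language model hidden size.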


def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
    image_grid_sizes = image_grid_thw.prod(-1)

    video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
    video_grid_sizes = video_grid_thw.prod(-1)

    return dict(
        pixel_values=MultiModalFieldConfig.flat_from_sizes(
            "image", image_grid_sizes),
        image_embeds=MultiModalFieldConfig.flat_from_sizes(
            "image", image_grid_sizes),
        image_grid_thw=MultiModalFieldConfig.batched("image"),
        pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
            "video", video_grid_sizes),
        video_embeds=MultiModalFieldConfig.flat_from_sizes(
            "video", video_grid_sizes),
        video_grid_thw=MultiModalFieldConfig.batched("video"),
    )


class KeyeMultiModalDataParser(MultiModalDataParser):

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> ModalityDataItems[Any, Any]:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="image",
                required_fields={
                    "image_embeds",
                    "image_grid_thw",
                },
                fields_factory=_keye_field_config,
            )

        return super()._parse_image_data(data)

    def _parse_video_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
    ) -> ModalityDataItems[Any, Any]:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="video",
                required_fields={
                    "video_embeds",
                    "video_grid_thw",
                },
                fields_factory=_keye_field_config,
            )

        return super()._parse_video_data(data)


class KeyeProcessingInfo(BaseProcessingInfo):

    def get_max_image_size(self) -> int:
        return 9999999  # _MAX_IMAGE_SIZE

    def get_max_frame_per_video(self) -> int:
        return 16  # _MAX_FRAMES_PER_VIDEO

    def get_image_processor(self, **kwargs: object):
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {
            "image": self.get_max_image_tokens(),
            "video": self.get_max_video_tokens(seq_len),
        }

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor,
    ) -> tuple[ImageSize, int]:
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = 1

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width,
                                          height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width,
                                          height=image_height)

        padded_num_frames = num_frames + num_frames % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens
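
    # For example (hypothetical config with patch_size=14 and
    # spatial_merge_size=2): a 700x1400 preprocessed image gives a 50x100
    # patch grid, i.e. 5000 patches, which the 2x2 merge reduces to 1250
    # vision tokens per frame.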

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor,
    ) -> int:
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor,
    ) -> int:
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        max_image_size, _ = self._get_vision_info(
            image_width=self.get_max_image_size(),
            image_height=self.get_max_image_size(),
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(self, seq_len: int) -> int:
        mm_config = self.ctx.get_mm_config()
        max_images = mm_config.get_limit_per_prompt("image")
        max_videos = mm_config.get_limit_per_prompt("video")

        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len -
                                                      max_image_tokens)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1),
            self.get_max_frame_per_video(),
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(self, seq_len: int) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len),
            image_processor=None,
        )


_I = TypeVar("_I", bound=KeyeProcessingInfo)


class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]):

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        hf_processor = self.info.get_hf_processor()
        image_token: str = hf_processor.image_token
        video_token: str = hf_processor.video_token

        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = (
            self.info.get_image_size_with_most_features())
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len)

        mm_data = {
            "image":
            self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
            ),
            "video":
            self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
            ),
        }

        return mm_data


class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]):
    ...


class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):

    def _get_data_parser(self) -> MultiModalDataParser:
        return KeyeMultiModalDataParser()

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(
            **hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        placeholder = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }

        merge_length = image_processor.merge_size**2

        def get_replacement_keye(item_idx: int, modality: str):
            out_item = out_mm_kwargs[modality][item_idx]
            grid_thw = out_item[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            num_tokens = int(grid_thw.prod()) // merge_length
            return [placeholder[modality]] * num_tokens

        return [
            PromptReplacement(
                modality=modality,
                target=[placeholder[modality]],
                replacement=partial(get_replacement_keye, modality=modality),
            ) for modality in ("image", "video")
        ]
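
    # Each single <|image_pad|> / <|video_pad|> placeholder in the prompt is
    # expanded to grid_thw.prod() // merge_size**2 copies, matching the number
    # of merged vision tokens the projector emits for that item.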

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _keye_field_config(hf_inputs)


class BaseKeyeModule(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "lm_head.": "language_model.lm_head.",
        "model.": "language_model.model.",
    })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
            return None
        return quant_config

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: PretrainedConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.visual = KeyeSiglipVisionModel(
            config.vision_config,
            quant_config=self._maybe_ignore_quant_config(quant_config),
            prefix=maybe_prefix(prefix, "visual"),
        )

        self.mlp_AR = self._build_projector(
            config,
            config.vision_config,
            quant_config=self._maybe_ignore_quant_config(quant_config),
            prefix=maybe_prefix(prefix, "mlp_AR"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen3ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @abstractmethod
    def _build_projector(self,
                         text_config: PretrainedConfig,
                         vision_config: PretrainedConfig,
                         quant_config: Optional[QuantizationConfig] = None,
                         prefix: str = "") -> nn.Module:
        raise ValueError("Need projector")

    def _process_image_input(self,
                             image_input: Any) -> tuple[torch.Tensor, ...]:
        siglip_position_ids = list()
        image_grid_hws = list()
        sample_indices = list()
        cu_seqlens = [0]

        image_grid_thw = image_input["image_grid_thw"]
        assert image_grid_thw.ndim == 2

        for idx, thw in enumerate(image_grid_thw):
            thw_tuple = tuple(thw.detach().cpu().numpy().tolist())
            numel = np.prod(thw_tuple)
            image_grid_hws.append(thw_tuple)
            image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
            siglip_position_ids.append(image_position_ids)
            sample_indices.append(torch.full((numel, ), idx,
                                             dtype=torch.int64))
            cu_seqlens.append(cu_seqlens[-1] + numel)

        if image_input["type"] == "image_embeds":
            raise ValueError(
                "Image embeddings are not supported for this processing path.")
        else:
            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
            siglip_position_ids = torch.concat(siglip_position_ids,
                                               dim=0).to(pixel_values.device)
            cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
                pixel_values.device)
            sample_indices = torch.concat(sample_indices,
                                          dim=0).to(pixel_values.device)

            image_embeds = self.visual(
                pixel_values=pixel_values,
                image_grid_thw=image_grid_hws,
                position_ids=siglip_position_ids,
                vision_return_embed_list=False,
                interpolate_pos_encoding=True,
                sample_indices=sample_indices,
                cu_seqlens=cu_seqlens,
                use_rope=True,
                window_size=-1,
            )
            image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw))
            return image_embeds

    def _process_video_embeds(
        self,
        video_type: Literal["video_embeds", "pixel_values_videos"],
        video_grid_thw: list[torch.Tensor],
        pixel_values_videos: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        siglip_position_ids = list()
        video_grid_hws = list()
        sample_indices = list()
        cu_seqlens = [0]

        assert video_grid_thw.ndim == 2
        for idx, sub_thw in enumerate(video_grid_thw):
            thw_tuple = tuple(sub_thw.detach().cpu().numpy().tolist())
            numel = np.prod(thw_tuple)

            video_grid_hws.append(thw_tuple)
            video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
            siglip_position_ids.append(video_position_ids)
            sample_indices.append(torch.full((numel, ), idx,
                                             dtype=torch.int64))
            cu_seqlens.append(cu_seqlens[-1] + numel)

        if video_type == "video_embeds":
            raise ValueError(
                "Video embeddings are not supported for this processing path.")
        else:
            pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
            siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(
                pixel_values_videos.device)
            cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
                pixel_values_videos.device)
            sample_indices = torch.concat(sample_indices,
                                          dim=0).to(pixel_values_videos.device)

            video_embeds = self.visual(
                pixel_values=pixel_values_videos,
                image_grid_thw=video_grid_hws,
                position_ids=siglip_position_ids,
                vision_return_embed_list=True,
                interpolate_pos_encoding=True,
                sample_indices=sample_indices,
                cu_seqlens=cu_seqlens,
                use_rope=True,
                window_size=-1,
            )
            video_embeds = self.mlp_AR(video_embeds, video_grid_thw)
            return video_embeds

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        for input_key in kwargs:
            if (input_key in ("pixel_values", "image_embeds")
                    and "images" not in modalities):
                modalities["images"] = self._parse_and_validate_image_input(
                    **kwargs)
            if (input_key in ("pixel_values_videos", "video_embeds")
                    and "videos" not in modalities):
                modalities["videos"] = self._parse_and_validate_video_input(
                    **kwargs)

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:

        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return None

        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                vision_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += vision_embeddings
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += video_embeddings
        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                [
                    self.config.image_token_id,
                    self.config.video_token_id,
                ],
            )
        return inputs_embeds

    def get_input_embeddings_v0(
        self,
        input_ids: torch.Tensor,
        image_input: Optional[Any] = None,
        video_input: Optional[Any] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.get_input_embeddings(input_ids)
        if image_input is not None:
            image_embeds = self._process_image_input(image_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                image_embeds,
                placeholder_token_id=self.config.image_token_id,
            )

        if video_input is not None:
            video_embeds = self._process_video_input(video_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                video_embeds,
                placeholder_token_id=self.config.video_token_id,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Keye-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            pixel_values: Pixel values to be fed to a model.
                `None` if no images are passed.
            image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
                `None` if no images are passed.
            pixel_values_videos: Pixel values of videos to be fed to a model.
                `None` if no videos are passed.
            video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
                `None` if no videos are passed.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        elif inputs_embeds is None:
            image_input = self._parse_and_validate_image_input(**kwargs)
            video_input = self._parse_and_validate_video_input(**kwargs)
            if image_input is None and video_input is None:
                inputs_embeds = None
            else:
                if uses_mrope(self.config):
                    assert positions.ndim == 2 and positions.size(0) == 3, (
                        "multimodal section rotary embedding requires "
                        f"(3, seq_len) positions, but got {positions.size()}")
                inputs_embeds = self.get_input_embeddings_v0(
                    input_ids,
                    image_input=image_input,
                    video_input=video_input,
                )
                input_ids = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get the module prefix in multimodal models."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="mlp_AR.",
            tower_model="visual.",
        )


@MULTIMODAL_REGISTRY.register_processor(
    KeyeMultiModalProcessor,
    info=KeyeProcessingInfo,
    dummy_inputs=KeyeDummyInputsBuilder,
)
class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
                                   SupportsLoRA, SupportsPP):

    def _build_projector(self,
                         text_config: PretrainedConfig,
                         vision_config: PretrainedConfig,
                         quant_config: Optional[QuantizationConfig] = None,
                         prefix: str = "") -> nn.Module:
        return Projector(text_config, vision_config, quant_config, prefix)

    def _validate_and_reshape_mm_tensor(
            self, mm_input: NestedTensors,
            name: str) -> Union[torch.Tensor, list[torch.Tensor]]:
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim == 5:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
                                 f"Got ndim: {mm_input.ndim} "
                                 f"(shape={mm_input.shape})")
            return mm_input.reshape(-1, mm_input.shape[-1])
        elif is_list_of(mm_input, torch.Tensor):
            if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
                                                          for p in mm_input):
                return mm_input
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[KeyeImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            return KeyeImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            image_embeds = self._validate_and_reshape_mm_tensor(
                image_embeds, "image embeds")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            return KeyeImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[KeyeVideoInputs]:
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos,
                "video pixel values",
            )
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return KeyeVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            video_embeds = self._validate_and_reshape_mm_tensor(
                video_embeds, "video embeds")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return KeyeVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

    def _process_video_input(
            self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]:
        video_type = video_input["type"]
        video_grid_thw = video_input["video_grid_thw"]
        pixel_values_videos = video_input.get("pixel_values_videos", None)

        return tuple(
            self._process_video_embeds(video_type, video_grid_thw,
                                       pixel_values_videos))