# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Literal, Optional, TypeVar, Union
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from transformers import PretrainedConfig
from transformers.activations import GELUActivation
from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import (BaseModelOutput,
BaseModelOutputWithPooling)
from transformers.utils import torch_int
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, VideoItem)
from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .siglip import SiglipMLP
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, is_pp_missing_parameter,
maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vit_attn_backend
logger = init_logger(__name__)
def smart_resize(
height: int,
width: int,
factor: int,
min_pixels: int,
max_pixels: int,
):
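"""Rescale (height, width) so that both sides are multiples of
`factor` and the total pixel count falls within
[min_pixels, max_pixels], keeping the aspect ratio as close to the
original as possible. Raises if the aspect ratio exceeds 200."""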
if height < factor:
logger.warning(
"smart_resize: height=%s < factor=%s, reset height=factor",
height,
factor,
)
width = round((width * factor) / height)
height = factor
if width < factor:
logger.warning(
"smart_resize: width=%s < factor=%s, reset width=factor",
width,
factor,
)
height = round((height * factor) / width)
width = factor
if max(height, width) / min(height, width) > 200:
raise ValueError("absolute aspect ratio must be smaller than 200, got "
f"{max(height, width) / min(height, width)}")
h_bar = round(height / factor) * factor
w_bar = round(width / factor) * factor
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = math.floor(height / beta / factor) * factor
w_bar = math.floor(width / beta / factor) * factor
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
w_bar = math.ceil(width * beta / factor) * factor
return h_bar, w_bar
class KeyeImagePixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- np: Number of patches
- c: Number of channels
- ps: Patch size
- ni: Number of images
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["pixel_values"]
pixel_values: Annotated[
torch.Tensor,
TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
class KeyeImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- nf: Number of image features
- hs: Hidden size (must match the hidden size of language model
backbone)
- ni: Number of images
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["image_embeds"]
image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs]
class KeyeVideoPixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- np: Number of patches
- c: Number of channels
- ps: Patch size
- nv: Number of videos
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["pixel_values_videos"]
pixel_values_videos: Annotated[
torch.Tensor,
TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
class KeyeVideoEmbeddingInputs(TensorSchema):
"""
Dimensions:
- nf: Number of video features
- hs: Hidden size (must match the hidden size of language model
backbone)
- nv: Number of videos
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["video_embeds"]
video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
KeyeVideoInputs = Union[KeyeVideoPixelInputs, KeyeVideoEmbeddingInputs]
class KeyeVisionEmbeddings(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.num_patches = (self.image_size // self.patch_size)**2
self.num_positions = self.num_patches
self.cache_position_embedding = dict()
self.cache_position_count = dict()
self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim)
self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)
self.register_buffer(
"position_ids",
torch.arange(self.num_positions).expand((1, -1)),
persistent=False,
)
def interpolate_pos_encoding(
self,
embeddings: torch.Tensor,
height: int,
width: int,
is_after_patchify: bool = False,
) -> torch.Tensor:
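"""Bilinearly interpolate the learned square grid of position
embeddings to a (new_height, new_width) patch grid. When
`is_after_patchify` is True, `height`/`width` are already patch
counts rather than pixel sizes."""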
num_positions = self.position_embedding.weight.shape[0]
patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
dim = embeddings.shape[-1]
if is_after_patchify:
new_height = height
new_width = width
else:
new_height = height // self.patch_size
new_width = width // self.patch_size
sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions,
sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(new_height, new_width),
mode="bilinear",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return patch_pos_embed
def fetch_position_embedding_lfu_cache(self,
embeddings,
h,
w,
max_cache: int = 20):
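"""Cache interpolated position embeddings keyed by (h, w), evicting
the least-frequently-used entry once `max_cache` entries are stored."""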
grid = (h, w)
if grid in self.cache_position_embedding:
self.cache_position_count[grid] += 1
return self.cache_position_embedding[grid]
if len(self.cache_position_embedding) >= max_cache:
min_hit_grid = min(
self.cache_position_count,
key=self.cache_position_count.get,
)
self.cache_position_count.pop(min_hit_grid)
self.cache_position_embedding.pop(min_hit_grid)
position_embedding = self.interpolate_pos_encoding(
embeddings, h, w, True)
self.cache_position_count[grid] = 1
self.cache_position_embedding[grid] = position_embedding
return position_embedding
def forward(
self,
pixel_values: torch.FloatTensor,
position_ids: Optional[torch.Tensor] = None,
image_grid_thw: Optional[list[Union[
tuple[int, int, int],
list[tuple[int, int, int]],
]]] = None,
interpolate_pos_encoding=False,
) -> torch.Tensor:
if pixel_values.dim() == 4:
pixel_values = pixel_values.unsqueeze(0)
if pixel_values.dim() == 5:
if position_ids is None:
raise ValueError(
"position_ids cannot be None when pixel_values.dim() is 5."
)
(
batch_size,
sequence_len,
channel,
height,
width,
) = pixel_values.shape
target_dtype = self.patch_embedding.weight.dtype
pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
patch_embeds = self.patch_embedding(
pixel_values.to(dtype=target_dtype))
embeddings = patch_embeds.flatten(-2).squeeze(-1)
if interpolate_pos_encoding and image_grid_thw is not None:
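# Interpolate a position embedding per (h, w) grid and repeat it
# across the t frames of each image/video segment.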
start = 0
tmp_embeddings = list()
for image_grid in image_grid_thw:
t, h, w = image_grid
end = start + t * h * w
image_embeddings = embeddings[start:end, :]
position_embedding = (self.interpolate_pos_encoding(
image_embeddings, h, w, True).squeeze(0).repeat(t, 1))
image_embeddings = image_embeddings + position_embedding
tmp_embeddings.append(image_embeddings)
start = end
embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0)
else:
embeddings = embeddings + self.packing_position_embedding(
position_ids)
return embeddings
else:
raise ValueError("Unsupported pixel_values dimension:"
f" {pixel_values.dim()}. Expected 4 or 5.")
def apply_rotary_pos_emb_flashatt(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
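# cos/sin arrive with their halves duplicated (see the `repeat(1, 2)`
# in KeyeSiglipEncoder.forward), so keep only the first half before
# applying the flash-attn rotary kernel.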
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
return q_embed, k_embed
class KeyeSiglipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You
Need' paper."""
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
hidden_size = config.hidden_size
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_attention_heads
if self.total_num_kv_heads >= tp_size:
assert self.total_num_kv_heads % tp_size == 0
else:
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scale = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
input_size=hidden_size,
output_size=hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
head_size=self.head_dim, dtype=torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True
if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
raise RuntimeError(
f"Keye-VL does not support {self.attn_backend} backend now.")
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split(
[self.q_size, self.kv_size, self.kv_size],
dim=-1,
)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
batch_size = q.shape[0]
if rope_emb is None:
q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
k = k.view(
*k.shape[:-1],
self.num_kv_heads,
self.head_dim,
)
v = v.view(
*v.shape[:-1],
self.num_kv_heads,
self.head_dim,
)
else:
if cu_seqlens is None:
raise ValueError(
"cu_seqlens cannot be None when rope_emb is not None.")
cos, sin = rope_emb
q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
k = k.view(
*k.shape[:-1],
self.num_kv_heads,
self.head_dim,
)
q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin)
v = v.view(
*v.shape[:-1],
self.num_kv_heads,
self.head_dim,
)
if self.attn_backend == _Backend.FLASH_ATTN:
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
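# Pack (batch, seq) into a single varlen dimension; cu_seqlens
# marks the per-sample boundaries for the kernel.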
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
causal=False,
softmax_scale=self.scale,
)
context_layer = rearrange(output,
"(b s) ... -> b s ...",
b=batch_size)
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
kv_seqlen=None,
device=q.device)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None)
context_layer = rearrange(context_layer,
"b s h d -> b s (h d)").contiguous()
output, _ = self.out_proj(context_layer)
return output
class SigLIPRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
self.dim = dim
self.theta = theta
self.rope_init()
def rope_init(self):
inv_freq = 1.0 / (self.theta**(
torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, seqlen: int) -> torch.Tensor:
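"""Return the rotary angle table of shape (seqlen, dim // 2)."""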
seq = torch.arange(
seqlen,
device=self.inv_freq.device,
dtype=self.inv_freq.dtype,
)
freqs = torch.outer(seq, self.inv_freq)
return freqs
class KeyeSiglipEncoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.embed_dim = config.hidden_size
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.self_attn = KeyeSiglipAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
cu_seqlens=cu_seqlens,
rope_emb=rope_emb,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class KeyeSiglipEncoder(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
embed_dim = config.hidden_size
num_heads = config.num_attention_heads
head_dim = embed_dim // num_heads
self.layers = nn.ModuleList([
KeyeSiglipEncoderLayer(
config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
) for layer_idx in range(config.num_hidden_layers)
])
self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)
@staticmethod
def flatten_list(image_grid_thw):
tmp_image_grid_thw = list()
for image_grid in image_grid_thw:
if isinstance(image_grid, list):
tmp_image_grid_thw.extend(image_grid)
else:
tmp_image_grid_thw.append(image_grid)
return tmp_image_grid_thw
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cu_seqlens: Optional[torch.Tensor] = None,
image_grid_thw: Optional[list[Union[
tuple[int, int, int],
list[tuple[int, int, int]],
]]] = None,
height_position_ids: Optional[torch.Tensor] = None,
width_position_ids: Optional[torch.Tensor] = None,
use_rope: Optional[bool] = False,
window_size: Optional[int] = -1,
vision_or_text: str = "vision",
) -> torch.Tensor:
device = inputs_embeds.device
hidden_states = inputs_embeds
if use_rope:
flatten_image_grid_thw = self.flatten_list(image_grid_thw)
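# Derive per-patch (row, col) ids from each (t, h, w) grid when
# they are not supplied, so 2D rotary embeddings can be looked up.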
if width_position_ids is None or height_position_ids is None:
split_hids = list()
split_wids = list()
for t, h, w in flatten_image_grid_thw:
image_pids = torch.arange(t * h * w,
device=device) % (h * w)
sample_hids = image_pids // w
sample_wids = image_pids % w
split_hids.append(sample_hids)
split_wids.append(sample_wids)
width_position_ids = torch.concat(split_wids, dim=0)
height_position_ids = torch.concat(split_hids, dim=0)
pids = torch.stack(
[height_position_ids, width_position_ids],
dim=-1,
)
max_grid_size = pids.max() + 1
rope_emb_max_grid = self.rotary_pos_emb(max_grid_size)
rope_emb = rope_emb_max_grid[pids].flatten(1)
rope_emb = rope_emb.repeat(1, 2)
rope_emb = (rope_emb.cos(), rope_emb.sin())
else:
rope_emb = None
attn_cu_seqlens = cu_seqlens
assert attention_mask is None
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
cu_seqlens=attn_cu_seqlens,
rope_emb=rope_emb,
)
return hidden_states
class KeyeSiglipVisionTransformer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = KeyeVisionEmbeddings(config)
self.encoder = KeyeSiglipEncoder(
config,
quant_config=quant_config,
prefix=f"{prefix}.encoder",
)
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
def forward(
self,
pixel_values,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
attention_mask: Optional[torch.Tensor] = None,
sample_indices: Optional[torch.Tensor] = None,
image_indices: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
height_position_ids: Optional[torch.Tensor] = None,
width_position_ids: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
vision_return_embed_list: Optional[bool] = False,
image_grid_thw: Optional[list[Union[
tuple[int, int, int],
list[tuple[int, int, int]],
]]] = None,
return_pooler_output: Optional[bool] = True,
use_rope: Optional[bool] = False,
window_size: Optional[int] = -1,
) -> list[torch.Tensor]:
hidden_states = self.embeddings(
pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
position_ids=position_ids,
image_grid_thw=image_grid_thw,
)
last_hidden_state = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
attention_mask=attention_mask,
cu_seqlens=cu_seqlens,
image_grid_thw=image_grid_thw,
use_rope=use_rope,
height_position_ids=height_position_ids,
width_position_ids=width_position_ids,
window_size=window_size,
vision_or_text="vision",
)
last_hidden_state = self.post_layernorm(last_hidden_state)
sample_hidden_state = list()
if cu_seqlens is None:
raise ValueError("cu_seqlens cannot be None for "
"SiglipVisionTransformer output processing.")
for i in range(cu_seqlens.shape[0] - 1):
start = cu_seqlens[i]
end = cu_seqlens[i + 1]
tensor = last_hidden_state[:, start:end, :].squeeze(0)
sample_hidden_state.append(tensor)
return sample_hidden_state
class KeyeSiglipVisionModel(nn.Module):
config_class = PretrainedConfig
main_input_name = "pixel_values"
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.vision_model = KeyeSiglipVisionTransformer(
config,
quant_config=quant_config,
prefix=f"{prefix}.vision_model",
)
self.quant_config = quant_config
@property
def dtype(self) -> torch.dtype:
return self.vision_model.embeddings.patch_embedding.weight.dtype
@property
def device(self) -> torch.device:
return self.vision_model.embeddings.patch_embedding.weight.device
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
def forward(
self,
pixel_values,
sample_indices: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
position_ids: Optional[torch.Tensor] = None,
vision_return_embed_list: Optional[bool] = False,
image_grid_thw: Optional[list[Union[
tuple[int, int, int],
list[tuple[int, int, int]],
]]] = None,
cu_seqlens: Optional[torch.Tensor] = None,
return_pooler_output: Optional[bool] = True,
use_rope: Optional[bool] = False,
window_size: Optional[int] = -1,
) -> list[torch.Tensor]:
return self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
position_ids=position_ids,
vision_return_embed_list=vision_return_embed_list,
image_grid_thw=image_grid_thw,
sample_indices=sample_indices,
cu_seqlens=cu_seqlens,
return_pooler_output=return_pooler_output,
use_rope=use_rope,
window_size=window_size,
)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "head.attention" in name or "head.layernorm" in name:
continue
if "head.mlp" in name or "head.probe" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)):
param = params_dict[scale_name]
weight_loader = getattr(
param,
"weight_loader",
default_weight_loader,
)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (
param_name,
weight_name,
shard_id,
) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if name.endswith(".bias") and name not in params_dict:
continue
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(
param,
"weight_loader",
default_weight_loader,
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Projector(nn.Module):
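"""Project vision features into the language-model hidden space,
first merging each 2x2 block of patches (`merge_kernel_size`) into
a single token."""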
def __init__(
self,
text_config: PretrainedConfig,
vision_config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.text_config = text_config
self.vision_config = vision_config
self.merge_kernel_size = (2, 2)
self.hidden_size = (self.vision_config.hidden_size *
self.merge_kernel_size[0] *
self.merge_kernel_size[1])
self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size,
eps=1e-05)
self.act = GELUActivation()
self.linear_1 = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_1",
)
self.linear_2 = RowParallelLinear(
self.hidden_size,
self.text_config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_2",
)
def forward(
self,
image_features: Union[torch.Tensor, list[torch.Tensor]],
image_grid_thw: list[tuple[int, int, int]],
) -> Union[torch.Tensor, list[torch.Tensor]]:
m1, m2 = self.merge_kernel_size
if isinstance(image_features, (list, tuple)):
processed_features = list()
for image_feature, image_grid in zip(image_features,
image_grid_thw):
image_feature = self.pre_norm(image_feature)
t, h, w = image_grid
image_feature = rearrange(
image_feature,
"(t h p1 w p2) d -> (t h w) (p1 p2 d)",
t=t,
h=h // m1,
p1=m1,
w=w // m2,
p2=m2,
)
hidden_states, _ = self.linear_1(image_feature)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.linear_2(hidden_states)
processed_features.append(hidden_states)
return processed_features
dims = image_features.shape[:-1]
dim = image_features.shape[-1]
image_features = image_features.view(np.prod(dims), dim)
hidden_states = self.pre_norm(image_features).view(
-1, self.hidden_size)
hidden_states, _ = self.linear_1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.linear_2(hidden_states)
return hidden_states.view(*dims, -1)
def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor]):
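# Pixel values and embeddings are flattened across items, so the
# per-item slice lengths are the grid sizes t * h * w.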
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_grid_sizes = image_grid_thw.prod(-1)
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_grid_thw=MultiModalFieldConfig.batched("video"),
)
class KeyeMultiModalDataParser(MultiModalDataParser):
def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="image",
required_fields={
"image_embeds",
"image_grid_thw",
},
fields_factory=_keye_field_config,
)
return super()._parse_image_data(data)
def _parse_video_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="video",
required_fields={
"video_embeds",
"video_grid_thw",
},
fields_factory=_keye_field_config,
)
return super()._parse_video_data(data)
class KeyeProcessingInfo(BaseProcessingInfo):
def get_max_image_size(self) -> int:
return 9999999  # _MAX_IMAGE_SIZE
def get_max_frame_per_video(self) -> int:
return 16  # _MAX_FRAMES_PER_VIDEO
def get_image_processor(self, **kwargs: object):
return self.get_hf_processor(**kwargs).image_processor
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len),
}
def _get_vision_info(
self,
*,
image_width: int,
image_height: int,
num_frames: int = 1,
do_resize: bool = True,
image_processor,
) -> tuple[ImageSize, int]:
if image_processor is None:
image_processor = self.get_image_processor()
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = 1
if do_resize:
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * merge_size,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
)
preprocessed_size = ImageSize(width=resized_width,
height=resized_height)
else:
preprocessed_size = ImageSize(width=image_width,
height=image_height)
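# Each patch_size x patch_size patch is one grid cell; every
# merge_size x merge_size block of cells becomes one vision token.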
padded_num_frames = num_frames + num_frames % temporal_patch_size
grid_t = max(padded_num_frames // temporal_patch_size, 1)
grid_h = preprocessed_size.height // patch_size
grid_w = preprocessed_size.width // patch_size
num_patches = grid_t * grid_h * grid_w
num_vision_tokens = num_patches // (merge_size**2)
return preprocessed_size, num_vision_tokens
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
image_processor,
) -> int:
_, num_image_tokens = self._get_vision_info(
image_width=image_width,
image_height=image_height,
image_processor=image_processor,
)
return num_image_tokens
def get_num_video_tokens(
self,
*,
image_width: int,
image_height: int,
num_frames: int,
image_processor,
) -> int:
_, num_video_tokens = self._get_vision_info(
image_width=image_width,
image_height=image_height,
num_frames=num_frames,
image_processor=image_processor,
)
return num_video_tokens
def get_image_size_with_most_features(self) -> ImageSize:
max_image_size, _ = self._get_vision_info(
image_width=self.get_max_image_size(),
image_height=self.get_max_image_size(),
image_processor=None,
)
return max_image_size
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
image_processor=None,
)
def _get_max_video_frames(self, max_tokens: int) -> int:
target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0
while True:
next_num_frames = num_frames + 1
next_max_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=next_num_frames,
image_processor=None,
)
if next_max_tokens > max_tokens:
break
num_frames = next_num_frames
return num_frames
def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.get_limit_per_prompt("image")
max_videos = mm_config.get_limit_per_prompt("video")
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
max_frames_per_video = min(
max_total_frames // max(max_videos, 1),
self.get_max_frame_per_video(),
)
return max(max_frames_per_video, 1)
def get_max_video_tokens(self, seq_len: int) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len),
image_processor=None,
)
_I = TypeVar("_I", bound=KeyeProcessingInfo)
class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
hf_processor = self.info.get_hf_processor()
image_token: str = hf_processor.image_token
video_token: str = hf_processor.video_token
return image_token * num_images + video_token * num_videos
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
target_width, target_height = (
self.info.get_image_size_with_most_features())
target_num_frames = self.info.get_num_frames_with_most_features(
seq_len)
mm_data = {
"image":
self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=target_num_frames,
num_videos=num_videos,
),
}
return mm_data
class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]):
...
class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
return KeyeMultiModalDataParser()
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
placeholder = {
"image": vocab[hf_processor.image_token],
"video": vocab[hf_processor.video_token],
}
merge_length = image_processor.merge_size**2
def get_replacement_keye(item_idx: int, modality: str):
out_item = out_mm_kwargs[modality][item_idx]
grid_thw = out_item[f"{modality}_grid_thw"].data
assert isinstance(grid_thw, torch.Tensor)
num_tokens = int(grid_thw.prod()) // merge_length
return [placeholder[modality]] * num_tokens
return [
PromptReplacement(
modality=modality,
target=[placeholder[modality]],
replacement=partial(get_replacement_keye, modality=modality),
) for modality in ("image", "video")
]
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _keye_field_config(hf_inputs)
class BaseKeyeModule(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
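# Remap HF checkpoint prefixes onto vLLM's module tree.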
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
"lm_head.": "language_model.lm_head.",
"model.": "language_model.model.",
})
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
return "<|vision_start|><|image_pad|><|vision_end|>"
if modality.startswith("video"):
return "<|vision_start|><|video_pad|><|vision_end|>"
raise ValueError("Only image or video modality is supported")
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.visual = KeyeSiglipVisionModel(
config.vision_config,
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
self.mlp_AR = self._build_projector(
config,
config.vision_config,
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "mlp_AR"),
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen3ForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@abstractmethod
def _build_projector(self,
text_config: PretrainedConfig,
vision_config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> nn.Module:
raise ValueError("Need projector")
def _process_image_input(self,
image_input: Any) -> tuple[torch.Tensor, ...]:
siglip_position_ids = list()
image_grid_hws = list()
sample_indices = list()
cu_seqlens = [0]
image_grid_thw = image_input["image_grid_thw"]
assert image_grid_thw.ndim == 2
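# Build packed position ids (row-major within each h*w frame) and
# cumulative sequence lengths, one segment per image.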
for idx, thw in enumerate(image_grid_thw):
thw_tuple = tuple(thw.detach().cpu().numpy().tolist())
numel = np.prod(thw_tuple)
image_grid_hws.append(thw_tuple)
image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
siglip_position_ids.append(image_position_ids)
sample_indices.append(torch.full((numel, ), idx,
dtype=torch.int64))
cu_seqlens.append(cu_seqlens[-1] + numel)
if image_input["type"] == "image_embeds":
raise ValueError(
"Image embeddings are not supported for this processing path.")
else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
siglip_position_ids = torch.concat(siglip_position_ids,
dim=0).to(pixel_values.device)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
pixel_values.device)
sample_indices = torch.concat(sample_indices,
dim=0).to(pixel_values.device)
image_embeds = self.visual(
pixel_values=pixel_values,
image_grid_thw=image_grid_hws,
position_ids=siglip_position_ids,
vision_return_embed_list=False,
interpolate_pos_encoding=True,
sample_indices=sample_indices,
cu_seqlens=cu_seqlens,
use_rope=True,
window_size=-1,
)
image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw))
return image_embeds
def _process_video_embeds(
self,
video_type: Literal["video_embeds", "pixel_values_videos"],
video_grid_thw: torch.Tensor,
pixel_values_videos: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, list[torch.Tensor]]:
siglip_position_ids = list()
video_grid_hws = list()
sample_indices = list()
cu_seqlens = [0]
assert video_grid_thw.ndim == 2
for idx, sub_thw in enumerate(video_grid_thw):
thw_tuple = tuple(sub_thw.detach().cpu().numpy().tolist())
numel = np.prod(thw_tuple)
video_grid_hws.append(thw_tuple)
video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
siglip_position_ids.append(video_position_ids)
sample_indices.append(torch.full((numel, ), idx,
dtype=torch.int64))
cu_seqlens.append(cu_seqlens[-1] + numel)
if video_type == "video_embeds":
raise ValueError(
"Video embeddings are not supported for this processing path.")
else:
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(
pixel_values_videos.device)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
pixel_values_videos.device)
sample_indices = torch.concat(sample_indices,
dim=0).to(pixel_values_videos.device)
video_embeds = self.visual(
pixel_values=pixel_values_videos,
image_grid_thw=video_grid_hws,
position_ids=siglip_position_ids,
vision_return_embed_list=True,
interpolate_pos_encoding=True,
sample_indices=sample_indices,
cu_seqlens=cu_seqlens,
use_rope=True,
window_size=-1,
)
video_embeds = self.mlp_AR(video_embeds, video_grid_thw)
return video_embeds
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
for input_key in kwargs:
if (input_key in ("pixel_values", "image_embeds")
and "images" not in modalities):
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if (input_key in ("pixel_values_videos", "video_embeds")
and "videos" not in modalities):
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
return modalities
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return None
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_video_input(video_input)
multimodal_embeddings += video_embeddings
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
[
self.config.image_token_id,
self.config.video_token_id,
],
)
return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,
image_input: Optional[Any] = None,
video_input: Optional[Any] = None,
) -> torch.Tensor:
inputs_embeds = self.get_input_embeddings(input_ids)
if image_input is not None:
image_embeds = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
image_embeds,
placeholder_token_id=self.config.image_token_id,
)
if video_input is not None:
video_embeds = self._process_video_input(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
video_embeds,
placeholder_token_id=self.config.video_token_id,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for Keye-VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,)`.
pixel_values: Pixel values to be fed to a model.
`None` if no images are passed.
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
`None` if no images are passed.
pixel_values_videos: Pixel values of videos to be fed to a model.
`None` if no videos are passed.
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
`None` if no videos are passed.
"""
if intermediate_tensors is not None:
inputs_embeds = None
elif inputs_embeds is None:
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
if image_input is None and video_input is None:
inputs_embeds = None
else:
if uses_mrope(self.config):
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}")
inputs_embeds = self.get_input_embeddings_v0(
input_ids,
image_input=image_input,
video_input=video_input,
)
input_ids = None
hidden_states = self.language_model.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""Get the module prefix in multimodal models."""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="mlp_AR.",
tower_model="visual.",
)
@MULTIMODAL_REGISTRY.register_processor(
KeyeMultiModalProcessor,
info=KeyeProcessingInfo,
dummy_inputs=KeyeDummyInputsBuilder,
)
class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
SupportsLoRA, SupportsPP):
def _build_projector(self,
text_config: PretrainedConfig,
vision_config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> nn.Module:
return Projector(text_config, vision_config, quant_config, prefix)
def _validate_and_reshape_mm_tensor(
self, mm_input: NestedTensors,
name: str) -> Union[torch.Tensor, list[torch.Tensor]]:
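"""Accept a 2D or 5D tensor as-is, flatten a batched 3D tensor to
2D, and pass through a list of all-2D or all-4D tensors; any other
list of tensors is concatenated along dim 0."""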
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if mm_input.ndim == 2:
return mm_input
if mm_input.ndim == 5:
return mm_input
if mm_input.ndim != 3:
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
f"Got ndim: {mm_input.ndim} "
f"(shape={mm_input.shape})")
return mm_input.reshape(-1, mm_input.shape[-1])
elif is_list_of(mm_input, torch.Tensor):
if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
for p in mm_input):
return mm_input
return torch.concat(mm_input)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[KeyeImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
image_grid_thw = kwargs.pop("image_grid_thw", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
return KeyeImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
)
if image_embeds is not None:
image_embeds = self._validate_and_reshape_mm_tensor(
image_embeds, "image embeds")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
return KeyeImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds,
image_grid_thw=image_grid_thw,
)
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[KeyeVideoInputs]:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
video_embeds = kwargs.pop("video_embeds", None)
video_grid_thw = kwargs.pop("video_grid_thw", None)
if pixel_values_videos is None and video_embeds is None:
return None
if pixel_values_videos is not None:
pixel_values_videos = self._validate_and_reshape_mm_tensor(
pixel_values_videos,
"video pixel values",
)
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")
return KeyeVideoPixelInputs(
type="pixel_values_videos",
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
)
if video_embeds is not None:
video_embeds = self._validate_and_reshape_mm_tensor(
video_embeds, "video embeds")
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")
return KeyeVideoEmbeddingInputs(
type="video_embeds",
video_embeds=video_embeds,
video_grid_thw=video_grid_thw,
)
def _process_video_input(
self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]:
video_type = video_input["type"]
video_grid_thw = video_input["video_grid_thw"]
pixel_values_videos = video_input.get("pixel_values_videos", None)
return tuple(
self._process_video_embeds(video_type, video_grid_thw,
pixel_values_videos))