# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence from functools import partial from typing import Annotated, Any, Literal, TypeAlias, TypeVar import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from transformers import PretrainedConfig from transformers.activations import GELUActivation from transformers.feature_extraction_utils import BatchFeature from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.utils import torch_int from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( maybe_get_vit_flash_attn_backend, ) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, ) from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( ImageItem, ModalityData, MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, VideoItem, ) from vllm.multimodal.parse import ( DictEmbeddingItems, ImageSize, ModalityDataItems, MultiModalDataItems, MultiModalDataParser, ) from vllm.multimodal.processing import ( BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, ) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import 
current_platform from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import ( MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE, SupportsMultiModal, SupportsPP, ) from .siglip import SiglipMLP from .utils import ( AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, is_pp_missing_parameter, maybe_prefix, ) from .vision import get_vit_attn_backend logger = init_logger(__name__) def smart_resize( height: int, width: int, factor: int, min_pixels: int, max_pixels: int, ): if height < factor: logger.warning( "smart_resize: height=%s < factor=%s, reset height=factor", height, factor, ) width = round((width * factor) / height) height = factor if width < factor: logger.warning( "smart_resize: width=%s < factor=%s, reset width=factor", width, factor, ) height = round((height * factor) / width) width = factor if max(height, width) / min(height, width) > 200: raise ValueError( "absolute aspect ratio must be smaller than 200, got " "{max(height, width) / min(height, width)}" ) h_bar = round(height / factor) * factor w_bar = round(width / factor) * factor if h_bar * w_bar > max_pixels: beta = math.sqrt((height * width) / max_pixels) h_bar = math.floor(height / beta / factor) * factor w_bar = math.floor(width / beta / factor) * factor elif h_bar * w_bar < min_pixels: beta = math.sqrt(min_pixels / (height * width)) h_bar = math.ceil(height * beta / factor) * factor w_bar = math.ceil(width * beta / factor) * factor return h_bar, w_bar class KeyeImagePixelInputs(TensorSchema): """ Dimensions: - bnp: Batch size * Number of patches - c: Number of channels - ps: Patch size - ni: Number of images - g: Grid dimensions (3 for t, h, w) """ type: Literal["pixel_values"] pixel_values: Annotated[ torch.Tensor, TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"}) ] image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] class KeyeImageEmbeddingInputs(TensorSchema): """ Dimensions: - nf: Number of image 
features - hs: Hidden size (must match the hidden size of language model backbone) - ni: Number of images - g: Grid dimensions (3 for t, h, w) """ type: Literal["image_embeds"] image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] KeyeImageInputs: TypeAlias = KeyeImagePixelInputs | KeyeImageEmbeddingInputs class KeyeVideoPixelInputs(TensorSchema): """ Dimensions: - bnp: Batch size * Number of patches - c: Number of channels - ps: Patch size - ni: Number of images - g: Grid dimensions (3 for t, h, w) """ type: Literal["pixel_values_videos"] pixel_values_videos: Annotated[ torch.Tensor, TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"}) ] video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] class KeyeVideoEmbeddingInputs(TensorSchema): """ Dimensions: - nf: Number of video features - hs: Hidden size (must match the hidden size of language model backbone) - nv: Number of videos - g: Grid dimensions (3 for t, h, w) """ type: Literal["video_embeds"] video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] KeyeVideoInputs: TypeAlias = KeyeVideoPixelInputs | KeyeVideoEmbeddingInputs class KeyeVisionEmbeddings(nn.Module): def __init__(self, config: PretrainedConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size self.patch_embedding = Conv2dLayer( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, padding="valid", ) self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches self.cache_position_embedding = dict() self.cache_position_count = dict() self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.packing_position_embedding = nn.Embedding(32768, self.embed_dim) self.register_buffer( 
"position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False, ) def interpolate_pos_encoding( self, embeddings: torch.Tensor, height: int, width: int, is_after_patchify: bool = False, ) -> torch.Tensor: num_positions = self.position_embedding.weight.shape[0] patch_pos_embed = self.position_embedding.weight.unsqueeze(0) dim = embeddings.shape[-1] if is_after_patchify: new_height = height new_width = width else: new_height = height // self.patch_size new_width = width // self.patch_size sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape( 1, sqrt_num_positions, sqrt_num_positions, dim ) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, size=(new_height, new_width), mode="bilinear", align_corners=False, ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return patch_pos_embed def fetch_position_embedding_lfu_cache(self, embeddings, h, w, max_cache: int = 20): grid = (h, w) if grid in self.cache_position_embedding: self.cache_position_count[grid] += 1 return self.cache_position_embedding[grid] if len(self.cache_position_embedding) >= max_cache: min_hit_grid = min( self.cache_position_count, key=self.cache_position_count.get, ) self.cache_position_count.pop(min_hit_grid) self.cache_position_embedding.pop(min_hit_grid) position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True) self.cache_position_count[grid] = 1 self.cache_position_embedding[grid] = position_embedding return position_embedding def forward( self, pixel_values: torch.FloatTensor, position_ids: torch.Tensor | None = None, image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, interpolate_pos_encoding=False, ) -> torch.Tensor: if pixel_values.dim() == 4: pixel_values = pixel_values.unsqueeze(0) if pixel_values.dim() == 5: if position_ids is None: raise ValueError( "position_ids cannot be None when 
pixel_values.dim() is 5." ) ( batch_size, squence_len, channel, height, width, ) = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w") patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) embeddings = patch_embeds.flatten(-2).squeeze(-1) if interpolate_pos_encoding and image_grid_thw is not None: start = 0 tmp_embeddings = list() for image_grid in image_grid_thw: t, h, w = image_grid end = start + t * h * w image_embeddings = embeddings[start:end, :] position_embedding = ( self.interpolate_pos_encoding(image_embeddings, h, w, True) .squeeze(0) .repeat(t, 1) ) image_embeddings = image_embeddings + position_embedding tmp_embeddings.append(image_embeddings) start = end embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0) else: embeddings = embeddings + self.packing_position_embedding(position_ids) return embeddings else: raise ValueError( "Unsupported pixel_values dimension:" f" {pixel_values.dim()}. Expected 4 or 5." 
)


def apply_rotary_pos_emb_flashatt(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to ``q`` and ``k``.

    Only the first half of the last dimension of ``cos``/``sin`` is kept
    (``chunk(2, dim=-1)[0]``); presumably the caller passes tables whose two
    halves are identical (e.g. built with ``repeat(1, 2)``) — TODO confirm.

    The rotary kernel is picked on every call: vLLM's flash-attn rotary op on
    CUDA, flash_attn's Triton ``apply_rotary`` on ROCm, and the PyTorch
    ``apply_rotary_emb_torch`` (neox style) on any other platform. The
    rotation runs in float32 and results are cast back to the dtypes of ``q``
    and ``k``.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    cos = cos.chunk(2, dim=-1)[0].contiguous()
    sin = sin.chunk(2, dim=-1)[0].contiguous()

    if current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
    elif current_platform.is_rocm():
        from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
    else:
        # For other platforms, use PyTorch fallback
        from vllm.model_executor.layers.rotary_embedding.common import (
            apply_rotary_emb_torch,
        )

        apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True)

    q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
    k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
    return q_embed, k_embed


class KeyeSiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Tensor-parallel self-attention for the Keye vision tower: a fused
    ``qkv_proj`` (QKVParallelLinear), an ``out_proj`` (RowParallelLinear) and
    a varlen attention call segmented by ``cu_seqlens``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()
        self.config = config

        hidden_size = config.hidden_size
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        # Query heads must split evenly across tensor-parallel ranks.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        # KV heads equal query heads here (no grouped-query attention).
        self.total_num_kv_heads = config.num_attention_heads
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        # With fewer KV heads than ranks each rank still holds one KV head
        # (presumably replicated by QKVParallelLinear — confirm).
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scale = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
prefix=f"{prefix}.qkv_proj", ) self.out_proj = RowParallelLinear( input_size=hidden_size, output_size=hidden_size, quant_config=quant_config, prefix=f"{prefix}.out_proj", ) # Detect attention implementation. self.attn_backend = get_vit_attn_backend( head_size=self.head_dim, dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) self.attn_backend, self.flash_attn_varlen_func = ( maybe_get_vit_flash_attn_backend( self.attn_backend, attn_backend_override=attn_backend_override, ) ) if self.attn_backend not in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Keye-VL does not support {self.attn_backend} backend now." ) self.is_flash_attn_backend = self.attn_backend in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.ROCM_AITER_FA, } def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, output_attentions: bool | None = False, cu_seqlens: list[torch.Tensor] | None = None, rope_emb: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split( [self.q_size, self.kv_size, self.kv_size], dim=-1, ) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() batch_size = q.shape[0] if rope_emb is None: q = q.view(*q.shape[:-1], self.num_heads, self.head_dim) k = k.view( *k.shape[:-1], self.num_kv_heads, self.head_dim, ) v = v.view( *v.shape[:-1], self.num_kv_heads, self.head_dim, ) else: if cu_seqlens is None: raise ValueError("cu_seqlens cannot be None when rope_emb is not None.") cos, sin = rope_emb q = q.view(*q.shape[:-1], self.num_heads, self.head_dim) k = k.view( *k.shape[:-1], self.num_kv_heads, self.head_dim, ) q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin) v = v.view( *v.shape[:-1], self.num_kv_heads, self.head_dim, ) if self.is_flash_attn_backend: q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) output = self.flash_attn_varlen_func( q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, causal=False, softmax_scale=self.scale, ) context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: outputs = [] for i in range(1, len(cu_seqlens)): start_idx = cu_seqlens[i - 1] end_idx = cu_seqlens[i] q_i = q[:, start_idx:end_idx] k_i = k[:, start_idx:end_idx] v_i = v[:, start_idx:end_idx] q_i, k_i, v_i = ( rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i) ) output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous() output, _ = self.out_proj(context_layer) return output class SigLIPRotaryEmbedding(nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim self.theta = theta self.rope_init() def rope_init(self): inv_freq = 1.0 / ( self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim) ) self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, seqlen: int) -> torch.Tensor: seq = torch.arange( seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype, ) freqs = torch.outer(seq, self.inv_freq) return freqs class KeyeSiglipEncoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.embed_dim = config.hidden_size self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.self_attn = KeyeSiglipAttention( config, quant_config=quant_config, prefix=f"{prefix}.self_attn", 
attn_backend_override=attn_backend_override, ) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( config, quant_config=quant_config, prefix=f"{prefix}.mlp", ) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool | None = False, cu_seqlens: list[torch.Tensor] | None = None, rope_emb: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> tuple[torch.FloatTensor]: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, cu_seqlens=cu_seqlens, rope_emb=rope_emb, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states class KeyeSiglipEncoder(nn.Module): def __init__( self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config embed_dim = config.hidden_size num_heads = config.num_attention_heads head_dim = embed_dim // num_heads self.layers = nn.ModuleList( [ KeyeSiglipEncoderLayer( config, quant_config=quant_config, prefix=f"{prefix}.layers.{layer_idx}", attn_backend_override=attn_backend_override, ) for layer_idx in range(config.num_hidden_layers) ] ) self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2) @staticmethod def flatten_list(image_grid_thw): tmp_image_grid_thw = list() for image_grid in image_grid_thw: if isinstance(image_grid, list): tmp_image_grid_thw.extend(image_grid) else: tmp_image_grid_thw.append(image_grid) return tmp_image_grid_thw def forward( self, inputs_embeds, attention_mask: torch.Tensor | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, cu_seqlens: 
list[torch.Tensor] | None = None, image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, height_position_ids: torch.Tensor | None = None, width_position_ids: torch.Tensor | None = None, use_rope: bool | None = False, window_size: bool | None = -1, vision_or_text: str = "vision", ) -> BaseModelOutput: device = inputs_embeds.device hidden_states = inputs_embeds if use_rope is True: flatten_image_grid_thw = self.flatten_list(image_grid_thw) if width_position_ids is None or height_position_ids is None: split_hids = list() split_wids = list() for t, h, w in flatten_image_grid_thw: image_pids = torch.arange(t * h * w, device=device) % (h * w) sample_hids = image_pids // w sample_wids = image_pids % w split_hids.append(sample_hids) split_wids.append(sample_wids) width_position_ids = torch.concat(split_wids, dim=0) height_position_ids = torch.concat(split_hids, dim=0) pids = torch.stack( [height_position_ids, width_position_ids], dim=-1, ) max_grid_size = pids.max() + 1 rope_emb_max_grid = self.rotary_pos_emb(max_grid_size) rope_emb = rope_emb_max_grid[pids].flatten(1) rope_emb = rope_emb.repeat(1, 2) rope_emb = (rope_emb.cos(), rope_emb.sin()) else: rope_emb = None attn_cu_seqlens = cu_seqlens hidden_states = inputs_embeds assert attention_mask is None for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, attention_mask, output_attentions=output_attentions, cu_seqlens=attn_cu_seqlens, rope_emb=rope_emb, ) return hidden_states class KeyeSiglipVisionTransformer(nn.Module): def __init__( self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config embed_dim = config.hidden_size self.embeddings = KeyeVisionEmbeddings(config) self.encoder = KeyeSiglipEncoder( config, quant_config=quant_config, prefix=f"{prefix}.encoder", attn_backend_override=attn_backend_override, ) 
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) def forward( self, pixel_values, output_attentions: bool | None = None, output_hidden_states: bool | None = None, interpolate_pos_encoding: bool | None = False, attention_mask: torch.Tensor | None = None, sample_indices: torch.Tensor | None = None, image_indices: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, height_position_ids: torch.Tensor | None = None, width_position_ids: torch.Tensor | None = None, cu_seqlens: list[torch.Tensor] | None = None, padding_mask: torch.Tensor | None = None, vision_return_embed_list: bool | None = False, image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, return_pooler_output: bool | None = True, use_rope: bool | None = False, window_size: bool | None = -1, ) -> BaseModelOutputWithPooling: hidden_states = self.embeddings( pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, position_ids=position_ids, image_grid_thw=image_grid_thw, ) last_hidden_state = self.encoder( inputs_embeds=hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, attention_mask=attention_mask, cu_seqlens=cu_seqlens, image_grid_thw=image_grid_thw, use_rope=use_rope, height_position_ids=height_position_ids, width_position_ids=width_position_ids, window_size=window_size, vision_or_text="vision", ) last_hidden_state = self.post_layernorm(last_hidden_state) sample_hidden_state = list() if cu_seqlens is None: raise ValueError( "cu_seqlens cannot be None for " "SiglipVisionTransformer output processing." 
) for i in range(cu_seqlens.shape[0] - 1): start = cu_seqlens[i] end = cu_seqlens[i + 1] tensor = last_hidden_state[:, start:end, :].squeeze(0) sample_hidden_state.append(tensor) return sample_hidden_state class KeyeSiglipVisionModel(nn.Module): config_class = PretrainedConfig main_input_name = "pixel_values" def __init__( self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.vision_model = KeyeSiglipVisionTransformer( config, quant_config=quant_config, prefix=f"{prefix}.vision_model", attn_backend_override=attn_backend_override, ) self.quant_config = quant_config @property def dtype(self) -> torch.dtype: return self.vision_model.embeddings.patch_embedding.weight.dtype @property def device(self) -> torch.device: return self.vision_model.embeddings.patch_embedding.weight.device def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding def forward( self, pixel_values, sample_indices: torch.Tensor | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, interpolate_pos_encoding: bool = False, position_ids: torch.Tensor | None = None, vision_return_embed_list: bool | None = False, image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, cu_seqlens: list[torch.Tensor] | None = None, return_pooler_output: bool | None = True, use_rope: bool | None = False, window_size: bool | None = -1, ) -> BaseModelOutputWithPooling: return self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, position_ids=position_ids, vision_return_embed_list=vision_return_embed_list, image_grid_thw=image_grid_thw, sample_indices=sample_indices, cu_seqlens=cu_seqlens, return_pooler_output=return_pooler_output, use_rope=use_rope, 
window_size=window_size,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this vision model's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are routed into the
        fused ``qkv_proj`` parameter as shards ``"q"``/``"k"``/``"v"``.
        ``rotary_emb.inv_freq`` buffers and every ``head.attention``,
        ``head.layernorm``, ``head.mlp`` and ``head.probe`` tensor are skipped
        (presumably an unused pooling head — confirm). When the quantization
        config maps a name to a cache scale, the tensor is loaded as that
        scale (element 0 of non-scalar tensors).

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (fused param name, checkpoint shard name, shard id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "head.attention" in name or "head.layernorm" in name:
                continue
            if "head.mlp" in name or "head.probe" in name:
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                param = params_dict[scale_name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                # Scales may be stored as 1-element tensors; take a scalar.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for (
                param_name,
                weight_name,
                shard_id,
            ) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # NOTE: a `continue` below moves on to the next mapping entry;
                # since no `break` happens, the for-else branch then runs and
                # repeats the same checks on the rewritten name.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Checkpoint biases with no matching parameter are skipped.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Presumably remaps KV-cache scale names; None means there is
                # no matching parameter, so the tensor is skipped.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Projector(nn.Module):
    """Merge each ``merge_kernel_size`` block of vision tokens and project it
    to the text model's hidden size."""

    def __init__(
        self,
        text_config: PretrainedConfig,
        vision_config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.text_config = text_config
        self.vision_config = vision_config
        self.merge_kernel_size = (2, 2)

        # Width of one merged token: vision hidden size times the kernel area.
        self.hidden_size = (
            self.vision_config.hidden_size *
self.merge_kernel_size[0] * self.merge_kernel_size[1] ) self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, eps=1e-05) self.act = GELUActivation() self.linear_1 = ColumnParallelLinear( self.hidden_size, self.hidden_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.linear_1", ) self.linear_2 = RowParallelLinear( self.hidden_size, self.text_config.hidden_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.linear_2", ) def forward( self, image_features: torch.Tensor | list[torch.Tensor], image_grid_thw: list[tuple[int, int, int]], ) -> torch.Tensor | list[torch.Tensor]: m1, m2 = self.merge_kernel_size if isinstance(image_features, (list, tuple)): processed_features = list() for image_feature, image_grid in zip(image_features, image_grid_thw): image_feature = self.pre_norm(image_feature) t, h, w = image_grid image_feature = rearrange( image_feature, "(t h p1 w p2) d -> (t h w) (p1 p2 d)", t=t, h=h // m1, p1=m1, w=w // m2, p2=m2, ) hidden_states, _ = self.linear_1(image_feature) hidden_states = self.act(hidden_states) hidden_states, _ = self.linear_2(hidden_states) processed_features.append(hidden_states) return processed_features dims = image_features.shape[:-1] dim = image_features.shape[-1] image_features = image_features.view(np.prod(dims), dim) hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size) hidden_states = self.linear_1(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states.view(*dims, -1) def _keye_field_config( hf_inputs: Mapping[str, torch.Tensor], ): image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) image_grid_sizes = image_grid_thw.prod(-1) video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) video_grid_sizes = video_grid_thw.prod(-1) return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes), image_embeds=MultiModalFieldConfig.flat_from_sizes("image", 
image_grid_sizes), image_grid_thw=MultiModalFieldConfig.batched("image"), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( "video", video_grid_sizes ), video_embeds=MultiModalFieldConfig.flat_from_sizes("video", video_grid_sizes), video_grid_thw=MultiModalFieldConfig.batched("video"), ) class KeyeMultiModalDataParser(MultiModalDataParser): def _parse_image_data( self, data: dict[str, torch.Tensor] | ModalityData[ImageItem], ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, modality="image", required_fields={ "image_embeds", "image_grid_thw", }, fields_factory=_keye_field_config, ) return super()._parse_image_data(data) def _parse_video_data( self, data: dict[str, torch.Tensor] | ModalityData[VideoItem], ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, modality="video", required_fields={ "video_embeds", "video_grid_thw", }, fields_factory=_keye_field_config, ) return super()._parse_video_data(data) class KeyeProcessingInfo(BaseProcessingInfo): def get_max_image_size(self) -> int: return 9999999 # _MAX_IMAGE_SIZE def get_max_frame_per_video(self) -> int: return 16 # _MAX_FRAMES_PER_VIDEO def get_image_processor(self, **kwargs: object): return self.get_hf_processor(**kwargs).image_processor def get_supported_mm_limits( self, ) -> Mapping[str, int | None]: return {"image": None, "video": None} def get_mm_max_tokens_per_item( self, seq_len: int, mm_counts: Mapping[str, int], ) -> Mapping[str, int]: return { "image": self.get_max_image_tokens(), "video": self.get_max_video_tokens(seq_len), } def _get_vision_info( self, *, image_width: int, image_height: int, num_frames: int = 1, do_resize: bool = True, image_processor, ) -> tuple[ImageSize, int]: if image_processor is None: image_processor = self.get_image_processor() hf_config = self.get_hf_config() vision_config = hf_config.vision_config patch_size = vision_config.patch_size merge_size = 
vision_config.spatial_merge_size temporal_patch_size = 1 if do_resize: resized_height, resized_width = smart_resize( height=image_height, width=image_width, factor=patch_size * merge_size, min_pixels=image_processor.min_pixels, max_pixels=image_processor.max_pixels, ) preprocessed_size = ImageSize(width=resized_width, height=resized_height) else: preprocessed_size = ImageSize(width=image_width, height=image_height) padded_num_frames = num_frames + num_frames % temporal_patch_size grid_t = max(padded_num_frames // temporal_patch_size, 1) grid_h = preprocessed_size.height // patch_size grid_w = preprocessed_size.width // patch_size num_patches = grid_t * grid_h * grid_w num_vision_tokens = num_patches // (merge_size**2) return preprocessed_size, num_vision_tokens def get_num_image_tokens( self, *, image_width: int, image_height: int, image_processor, ) -> int: _, num_image_tokens = self._get_vision_info( image_width=image_width, image_height=image_height, image_processor=image_processor, ) return num_image_tokens def get_num_video_tokens( self, *, image_width: int, image_height: int, num_frames: int, image_processor, ) -> int: _, num_video_tokens = self._get_vision_info( image_width=image_width, image_height=image_height, num_frames=num_frames, image_processor=image_processor, ) return num_video_tokens def get_image_size_with_most_features( self, ) -> ImageSize: max_image_size, _ = self._get_vision_info( image_width=self.get_max_image_size(), image_height=self.get_max_image_size(), image_processor=None, ) return max_image_size def get_max_image_tokens(self) -> int: target_width, target_height = self.get_image_size_with_most_features() return self.get_num_image_tokens( image_width=target_width, image_height=target_height, image_processor=None, ) def _get_max_video_frames(self, max_tokens: int) -> int: target_width, target_height = self.get_image_size_with_most_features() num_frames = 0 while True: next_num_frames = num_frames + 1 next_max_tokens = 
self.get_num_video_tokens( image_width=target_width, image_height=target_height, num_frames=next_num_frames, image_processor=None, ) if next_max_tokens > max_tokens: break num_frames = next_num_frames return num_frames def get_num_frames_with_most_features(self, seq_len: int) -> int: mm_config = self.ctx.get_mm_config() max_images = mm_config.get_limit_per_prompt("image") max_videos = mm_config.get_limit_per_prompt("video") max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) max_frames_per_video = min( max_total_frames // max(max_videos, 1), self.get_max_frame_per_video(), ) return max(max_frames_per_video, 1) def get_max_video_tokens(self, seq_len: int) -> int: target_width, target_height = self.get_image_size_with_most_features() return self.get_num_video_tokens( image_width=target_width, image_height=target_height, num_frames=self.get_num_frames_with_most_features(seq_len), image_processor=None, ) _I = TypeVar("_I", bound=KeyeProcessingInfo) class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) hf_processor = self.info.get_hf_processor() image_token: str = hf_processor.image_token video_token: str = hf_processor.video_token return image_token * num_images + video_token * num_videos def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) target_width, target_height = self.info.get_image_size_with_most_features() target_num_frames = self.info.get_num_frames_with_most_features(seq_len) image_overrides = mm_options.get("image") if mm_options else None video_overrides = mm_options.get("video") if mm_options else None mm_data = { "image": self._get_dummy_images( 
width=target_width, height=target_height, num_images=num_images, overrides=image_overrides, ), "video": self._get_dummy_videos( width=target_width, height=target_height, num_frames=target_num_frames, num_videos=num_videos, overrides=video_overrides, ), } return mm_data class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]): ... class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: return KeyeMultiModalDataParser() def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() placeholder = { "image": vocab[hf_processor.image_token], "video": vocab[hf_processor.video_token], } merge_length = image_processor.merge_size**2 def get_replacement_keye(item_idx: int, modality: str): out_item = out_mm_kwargs[modality][item_idx] grid_thw = out_item[f"{modality}_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) num_tokens = int(grid_thw.prod()) // merge_length return [placeholder[modality]] * num_tokens return [ PromptReplacement( modality=modality, target=[placeholder[modality]], replacement=partial(get_replacement_keye, modality=modality), ) for modality in ("image", "video") ] def _get_mm_fields_config( self, hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return _keye_field_config(hf_inputs) class BaseKeyeModule(nn.Module): packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", ], "gate_up_proj": [ "gate_proj", "up_proj", ], } hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "lm_head.": "language_model.lm_head.", "model.": "language_model.model.", } ) @classmethod def 
get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|vision_start|><|image_pad|><|vision_end|>" if modality.startswith("video"): return "<|vision_start|><|video_pad|><|vision_end|>" raise ValueError("Only image or video modality is supported") def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: PretrainedConfig = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config attn_backend_override = ( multimodal_config.mm_encoder_attn_backend if multimodal_config is not None else None ) self.visual = KeyeSiglipVisionModel( config.vision_config, quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), attn_backend_override=attn_backend_override, ) self.mlp_AR = self._build_projector( config, config.vision_config, quant_config=quant_config, prefix=maybe_prefix(prefix, "mlp_AR"), ) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model"), architectures=["Qwen3ForCausalLM"], ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors ) @abstractmethod def _build_projector( self, text_config: PretrainedConfig, vision_config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> nn.Module: raise ValueError("Need projector") def _process_image_input(self, image_input: Any) -> tuple[torch.Tensor, ...]: siglip_position_ids = list() image_grid_hws = list() sample_indices = list() cu_seqlens = [0] image_grid_thw = image_input["image_grid_thw"] assert image_grid_thw.ndim == 2 for idx, thaw in enumerate(image_grid_thw): thw_tuple = tuple(thaw.detach().cpu().numpy().tolist()) numel = np.prod(thw_tuple) image_grid_hws.append(thw_tuple) image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) 
siglip_position_ids.append(image_position_ids) sample_indices.append(torch.full((numel,), idx, dtype=torch.int64)) cu_seqlens.append(cu_seqlens[-1] + numel) if image_input["type"] == "image_embeds": raise ValueError( "Image embeddings are not supported for this processing path." ) else: pixel_values = image_input["pixel_values"].type(self.visual.dtype) siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to( pixel_values.device ) cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( pixel_values.device ) sample_indices = torch.concat(sample_indices, dim=0).to(pixel_values.device) image_embeds = self.visual( pixel_values=pixel_values, image_grid_thw=image_grid_hws, position_ids=siglip_position_ids, vision_return_embed_list=False, interpolate_pos_encoding=True, sample_indices=sample_indices, cu_seqlens=cu_seqlens, use_rope=True, window_size=-1, ) image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw)) return image_embeds def _process_video_embeds( self, video_type: Literal["video_embeds", "pixel_values_videos"], video_grid_thw: list[torch.Tensor], pixel_values_videos: torch.Tensor | None = None, ) -> torch.Tensor | list[torch.Tensor]: siglip_position_ids = list() video_grid_hws = list() sample_indices = list() cu_seqlens = [0] assert video_grid_thw.ndim == 2 for idx, sub_thw in enumerate(video_grid_thw): thw_tuple = tuple(sub_thw.detach().cpu().numpy().tolist()) numel = np.prod(thw_tuple) video_grid_hws.append(thw_tuple) video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) siglip_position_ids.append(video_position_ids) sample_indices.append(torch.full((numel,), idx, dtype=torch.int64)) cu_seqlens.append(cu_seqlens[-1] + numel) if video_type == "video_embeds": raise ValueError( "Video embeddings are not supported for this processing path." 
) else: pixel_values_videos = pixel_values_videos.type(self.visual.dtype) siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to( pixel_values_videos.device ) cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( pixel_values_videos.device ) sample_indices = torch.concat(sample_indices, dim=0).to( pixel_values_videos.device ) video_embeds = self.visual( pixel_values=pixel_values_videos, image_grid_thw=video_grid_hws, position_ids=siglip_position_ids, vision_return_embed_list=True, interpolate_pos_encoding=True, sample_indices=sample_indices, cu_seqlens=cu_seqlens, use_rope=True, window_size=-1, ) video_embeds = self.mlp_AR(video_embeds, video_grid_thw) return video_embeds def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: modalities = {} for input_key in kwargs: if ( input_key in ("pixel_values", "image_embeds") and "images" not in modalities ): modalities["images"] = self._parse_and_validate_image_input(**kwargs) if ( input_key in ("pixel_values_videos", "video_embeds") and "videos" not in modalities ): modalities["videos"] = self._parse_and_validate_video_input(**kwargs) return modalities def get_language_model(self) -> torch.nn.Module: return self.language_model def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return None multimodal_embeddings: tuple[torch.Tensor, ...] 
= () for modality in modalities: if modality == "images": image_input = modalities["images"] image_embeddings = self._process_image_input(image_input) multimodal_embeddings += tuple(image_embeddings) if modality == "videos": video_input = modalities["videos"] video_embeddings = self._process_video_input(video_input) multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> torch.Tensor | IntermediateTensors: """Run forward pass for Keye-VL. Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. positions: Flattened (concatenated) position ids corresponding to a batch. **NOTE**: If mrope is enabled (default setting for Qwen2-VL opensource models), the shape will be `(3, seq_len)`, otherwise it will be `(seq_len,)`. intermediate_tensors: Intermediate tensors from prior forward pass. inputs_embeds: Optional tensor of input embeddings. 
""" if intermediate_tensors is not None: inputs_embeds = None hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def get_mm_mapping(self) -> MultiModelKeys: """Get the module prefix in multimodal models.""" return MultiModelKeys.from_string_field( language_model="language_model", connector="mlp_AR.", tower_model="visual.", ) @MULTIMODAL_REGISTRY.register_processor( KeyeMultiModalProcessor, info=KeyeProcessingInfo, dummy_inputs=KeyeDummyInputsBuilder, ) class KeyeForConditionalGeneration( BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE ): def _build_projector( self, text_config: PretrainedConfig, vision_config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> nn.Module: return Projector(text_config, vision_config, quant_config, prefix) def _parse_and_validate_image_input( self, **kwargs: object ) -> KeyeImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) if pixel_values is None and image_embeds is None: return None if pixel_values is not None: return KeyeImagePixelInputs( type="pixel_values", pixel_values=pixel_values, image_grid_thw=image_grid_thw, ) if image_embeds is not None: return KeyeImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, image_grid_thw=image_grid_thw, ) def _parse_and_validate_video_input( self, **kwargs: object ) -> KeyeVideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) 
video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) if pixel_values_videos is None and video_embeds is None: return None if pixel_values_videos is not None: return KeyeVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, ) if video_embeds is not None: return KeyeVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, video_grid_thw=video_grid_thw, ) def _process_video_input( self, video_input: KeyeVideoInputs ) -> tuple[torch.Tensor, ...]: video_type = video_input["type"] video_grid_thw = video_input["video_grid_thw"] pixel_values_videos = video_input.get("pixel_values_videos", None) return tuple( self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos) ) def get_mrope_input_positions( self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: kwargs = MultiModalFeatureSpec.gather_kwargs( mm_features, {"image_grid_thw", "video_grid_thw"}, ) image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0: video_grid_thw = video_grid_thw[0] def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: """ Split grid_thw along the t dimension. Args: grid_thw: shape [N, 3] tensor or nested list of [t, h, w]. Returns: List of [1, h, w] rows, repeated t times for each original row. 
""" if isinstance(grid_thw, list): grid_thw = torch.tensor(grid_thw, dtype=torch.long) if grid_thw.numel() == 0: return [] t, hw = grid_thw[:, 0], grid_thw[:, 1:] ones = torch.ones_like(hw[:, :1]) # [N,1] out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0) return out.tolist() video_grid_thw = split_thw(video_grid_thw) hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id spatial_merge_size = hf_config.vision_config.spatial_merge_size image_nums = len(image_grid_thw) frame_nums = len(video_grid_thw) llm_pos_ids_list: list = [] st = 0 remain_images, remain_frames = image_nums, frame_nums image_index, video_index = 0, 0 for _ in range(image_nums + frame_nums): if remain_images > 0: try: ed_image = input_tokens.index(image_token_id, st) except ValueError: ed_image = len(input_tokens) + 1 else: ed_image = len(input_tokens) + 1 if remain_frames > 0: try: ed_video = input_tokens.index(video_token_id, st) except ValueError: ed_video = len(input_tokens) + 1 else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: t, h, w = video_grid_thw[video_index] video_index += 1 remain_frames -= 1 ed = ed_video llm_grid_t, llm_grid_h, llm_grid_w = ( t, h // spatial_merge_size, w // spatial_merge_size, ) text_len = ed - st st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append( torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx ) t_index = ( ( torch.arange(llm_grid_t) .view(-1, 1) .expand(-1, llm_grid_h * llm_grid_w) ) .long() .flatten() ) h_index = ( torch.arange(llm_grid_h) .view(1, -1, 1) .expand(llm_grid_t, -1, llm_grid_w) .flatten() ) w_index = ( torch.arange(llm_grid_w) .view(1, 1, -1) .expand(llm_grid_t, llm_grid_h, -1) .flatten() ) llm_pos_ids_list.append( torch.stack([t_index, h_index, w_index]) + text_len + st_idx ) st = ed + llm_grid_t * llm_grid_h * llm_grid_w 
if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st llm_pos_ids_list.append( torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx ) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() return llm_positions, mrope_position_delta