From 2c13695951c45749d795d6858ce89f7c580dcb61 Mon Sep 17 00:00:00 2001 From: Yang Date: Thu, 20 Nov 2025 15:39:18 -0800 Subject: [PATCH] org and add imports and fix lint error Signed-off-by: Yang --- vllm/model_executor/models/isaac.py | 483 +++++++++++++++++----------- 1 file changed, 298 insertions(+), 185 deletions(-) diff --git a/vllm/model_executor/models/isaac.py b/vllm/model_executor/models/isaac.py index 786b1fe4e6f1c..5c61e5bf48a70 100644 --- a/vllm/model_executor/models/isaac.py +++ b/vllm/model_executor/models/isaac.py @@ -1,49 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from __future__ import annotations -from collections.abc import Mapping, Sequence, Iterable -from typing import Any, Optional, Union -from typing_extensions import TypedDict, Unpack - import itertools -from enum import Enum -from dataclasses import dataclass - import math +from collections.abc import Iterable, Mapping, Sequence +from dataclasses import dataclass +from enum import Enum +from typing import Any + import numpy as np import PIL.Image import torch import torch.nn as nn import torch.nn.functional as F - from transformers import PretrainedConfig, Qwen3Config from transformers.image_processing_utils import BatchFeature -from transformers.tokenization_utils import TensorType from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig +from transformers.tokenization_utils import TensorType +from typing_extensions import TypedDict, Unpack -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.model_executor.models.interfaces import SupportsMultiModal -from vllm.model_executor.models.utils import ( - WeightsMapper, - AutoWeightsLoader, - _merge_multimodal_embeddings, - maybe_prefix, - init_vllm_registered_model, -) -from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM -from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.multimodal.processing import ( - BaseMultiModalProcessor, - BaseProcessingInfo, - PromptReplacement, -) -from vllm.multimodal.parse import MultiModalDataItems, ImageSize -from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.multimodal.inputs import ( - MultiModalFieldConfig, - MultiModalKwargs, - MultiModalDataDict, -) +from vllm.attention.backends.registry import _Backend from vllm.config import VllmConfig +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, +) from vllm.model_executor.models.interfaces import ( MultiModalEmbeddings, SupportsLoRA, @@ -51,18 +34,34 @@ from vllm.model_executor.models.interfaces import ( SupportsMultiModal, SupportsPP, ) - -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, -) +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM from vllm.model_executor.models.siglip2navit import Siglip2Encoder -from vllm.attention.backends.registry import _Backend -from vllm.model_executor.layers.quantization import QuantizationConfig - -from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + _merge_multimodal_embeddings, + maybe_prefix, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargs, +) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder # ===== TensorStream Compatibility Layer for Isaac MRoPE ===== -# Minimal implementation of TensorStream classes needed for Isaac's 3D positional encoding +# Minimal implementation of TensorStream classes needed for Isaac's 3D positional +# encoding + class ModalityType(Enum): """ @@ -127,37 +126,46 @@ class TextType(ModalityType): @dataclass class Event: """Represents a single modality event with spatial/temporal dimensions.""" + """ - Represents a single data occurrence (with a specific type, time interval, and data payload). + Represents a single data occurrence (with a specific type, time interval, and + data payload). Attributes: - data (Any): The actual data payload (e.g. a torch.Tensor, a string, etc.). - type (ModalityType): The modality type of the data (e.g., VisionType.image). - time (Tuple[float, float]): (start_time, end_time) indicating when this Event occurs. - role (Optional[str]): The role associated with this event (e.g., "user", "agent", "system"). - If None, the event is always included in loss calculation. + data (Any): The actual data payload (e.g. a torch.Tensor, a string, + etc.). + type (ModalityType): The modality type of the data (e.g., + VisionType.image). + time (Tuple[float, float]): (start_time, end_time) indicating when this + Event occurs. + role (Optional[str]): The role associated with this event (e.g., "user", + "agent", "system"). If None, the event is always included in loss + calculation. Example usage: evt = Event(data=torch.zeros((1, 224, 224, 3)), # e.g. a single image frame type=VisionType.image, time=(0.0, 0.04), role="user") - """ + """ # Descriptors modality_type: ModalityType - + # Structure - dims_virtual: list[int] | None = None # virtual/processed dimensions (e.g., pixel-shuffled) + dims_virtual: list[int] | None = ( + None # virtual/processed dimensions (e.g., pixel-shuffled) + ) dims_real: list[int] | None = None # real/actual tensor dimensions idx_range: tuple[int, int] | None = None - + def dims(self, virtual: bool = True) -> list[int] | None: """ Get the dimensions of this event. Args: - virtual: If True (default), return virtual/processed dimensions (e.g., pixel-shuffled). - If False, return real/actual tensor dimensions. + virtual: If True (default), return virtual/processed dimensions + (e.g., pixel-shuffled). If False, return real/actual tensor + dimensions. Returns: Dimensions list or None if not measured. @@ -171,7 +179,9 @@ class Event: if not virtual: assert partial is False and isinstance(self.data, torch.Tensor) return math.prod(self.dims(virtual=False)) - return self.idx_range[1] - self.idx_range[0] if partial else math.prod(self.dims()) + return ( + self.idx_range[1] - self.idx_range[0] if partial else math.prod(self.dims()) + ) @dataclass @@ -215,7 +225,8 @@ class Stream: yield from self.events -# TODO: implement all types of cool indexing which can happen since TensorStream assuems Event.data = Tensor +# TODO: implement all types of cool indexing which can happen since TensorStream +# assumes Event.data = Tensor @dataclass class TensorStream: streams: list[Stream] @@ -254,7 +265,8 @@ def compute_mrope_pos_tensor(ts: TensorStream, n_pos_dims: int = 3) -> torch.Ten cumulative_offset = 0 # running time index for this stream for event in stream: - # --- build coordinate grid for THIS event using itertools (no tensor ops) --- + # --- build coordinate grid for THIS event using itertools + # (no tensor ops) --- dims = (event.dims() or [1]) + [1] * (n_pos_dims - len(event.dims() or [])) # Create ranges for each dimension (similar to old _finalize implementation) @@ -274,26 +286,30 @@ def compute_mrope_pos_tensor(ts: TensorStream, n_pos_dims: int = 3) -> torch.Ten # Convert to tensor and reshape to (B, T, n_pos_dims) B, T = ts.shape - return torch.tensor(all_coords, dtype=torch.long, device=ts.device).reshape(B, T, n_pos_dims) + return torch.tensor(all_coords, dtype=torch.long, device=ts.device).reshape( + B, T, n_pos_dims + ) def modality_mask(ts: TensorStream, modality_type: ModalityType) -> torch.Tensor: """Create boolean mask for specific modality type in the tensor stream.""" B, T = ts.shape mask = torch.zeros((B, T), dtype=torch.bool, device=ts.device) - + for batch_idx, stream in enumerate(ts.streams): seq_idx = 0 for event in stream: if event.modality_type == modality_type: start, end = event.idx_range - mask[batch_idx, seq_idx:seq_idx+(end-start)] = True - seq_idx += (event.idx_range[1] - event.idx_range[0]) - + mask[batch_idx, seq_idx : seq_idx + (end - start)] = True + seq_idx += event.idx_range[1] - event.idx_range[0] + return mask + # ===== End TensorStream Compatibility Layer ===== + class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig): """Vision configuration for Isaac with Pixel Shuffle support. @@ -338,7 +354,9 @@ class Siglip2VariableSequenceEmbeddings(nn.Module): ) -> torch.Tensor: # Prepare positional embeddings grid: (1, embed_dim, h, w) positional_embeddings = ( - self.position_embedding.weight.reshape(self.position_embedding_size, self.position_embedding_size, -1) + self.position_embedding.weight.reshape( + self.position_embedding_size, self.position_embedding_size, -1 + ) .permute(2, 0, 1) .unsqueeze(0) ) @@ -359,12 +377,16 @@ class Siglip2VariableSequenceEmbeddings(nn.Module): align_corners=align_corners, antialias=antialias, ) - # Reshape from (1, embed_dim, height, width) to (height*width, embed_dim) - resized_pos_embed = resized_pos_embed.reshape(self.embed_dim, height * width).transpose(0, 1) + # Reshape from (1, embed_dim, height, width) to + # (height*width, embed_dim) + resized_pos_embed = resized_pos_embed.reshape( + self.embed_dim, height * width + ).transpose(0, 1) else: # Fallback - should never happen in practice resized_pos_embed = positional_embeddings.reshape( - self.embed_dim, self.position_embedding_size * self.position_embedding_size + self.embed_dim, + self.position_embedding_size * self.position_embedding_size, ).transpose(0, 1)[: height * width] pos_embeds_list.append(resized_pos_embed) @@ -372,7 +394,9 @@ class Siglip2VariableSequenceEmbeddings(nn.Module): pos_embeds = torch.cat(pos_embeds_list, dim=0) return pos_embeds - def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor]): + def forward( + self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ): seq_patches, _seq_sizes, _spatial_shapes = packed_seq_patches # Apply patch embeddings @@ -385,7 +409,9 @@ class Siglip2VariableSequenceEmbeddings(nn.Module): # For variable-length attention, we need to reshape to (total_tokens, embed_dim) if batch_size != 1: - raise ValueError("Variable-length attention expects batch_size=1 for packed sequences") + raise ValueError( + "Variable-length attention expects batch_size=1 for packed sequences" + ) patch_embeds = patch_embeds.view(batch_size * patches_per_image, embed_dim) @@ -427,11 +453,13 @@ def create_pixel_shuffle_index_map( # Safety: all spatial dims must be divisible by r # Cannot run under torch compile fullgraph mode hence - if not torch.compiler.is_compiling(): - if not ((token_grids[:, 0] % r == 0).all() and (token_grids[:, 1] % r == 0).all()): - raise AssertionError( - f"Every (H,W) in `token_grids` must be divisible by scale_factor={r}, got {token_grids.tolist()}" - ) + if not torch.compiler.is_compiling() and not ( + (token_grids[:, 0] % r == 0).all() and (token_grids[:, 1] % r == 0).all() + ): + raise AssertionError( + "Every (H,W) in `token_grids` must be divisible by " + f"scale_factor={r}, got {token_grids.tolist()}" + ) gather_chunks: list[torch.Tensor] = [] tok_offset = 0 @@ -467,19 +495,23 @@ def pixel_shuffle_varlen( Args: x (`torch.Tensor`): - Concatenated vision embeddings. Accepts `(seq_len, hidden_size)` or `(1, seq_len, hidden_size)` shapes - produced by stacking image patches. + Concatenated vision embeddings. Accepts `(seq_len, hidden_size)` or + `(1, seq_len, hidden_size)` shapes produced by stacking image + patches. token_grids (`torch.Tensor`): - Integer tensor of shape `(num_images, 2)` whose rows give the `(height, width)` patch grid sizes - corresponding to each image segment inside `x`. + Integer tensor of shape `(num_images, 2)` whose rows give the + `(height, width)` patch grid sizes corresponding to each image + segment inside `x`. scale_factor (`int`, *optional*, defaults to 1): - Spatial down-sampling factor specific to pixel shuffle. Values greater than one merge `scale_factor**2` neighboring patches into a + Spatial down-sampling factor specific to pixel shuffle. Values + greater than one merge `scale_factor**2` neighboring patches into a single embedding channel-group. Returns: - `torch.Tensor`: Pixel-shuffled embeddings with shape matching the input convention: - `(seq_len, hidden_size * scale_factor**2)` when the input was 2D, or `(1, seq_len, hidden_size * scale_factor**2)` - if the singleton batch dimension was present. + `torch.Tensor`: Pixel-shuffled embeddings with shape matching the input + convention: `(seq_len, hidden_size * scale_factor**2)` when the input + was 2D, or `(1, seq_len, hidden_size * scale_factor**2)` if the + singleton batch dimension was present. Raises: ValueError: If more than one batch item is provided. @@ -517,6 +549,7 @@ def pixel_shuffle_varlen( out = out.unsqueeze(0) return out + # ============================================================================ # Configuration # ============================================================================ @@ -550,7 +583,9 @@ def _make_writeable(arr: np.ndarray) -> np.ndarray: def extract_image_pil(image: PIL.Image.Image) -> torch.Tensor | None: if image.width * image.height > MAX_PIXELS: - raise ValueError(f"Image (w={image.width}, h={image.height}) > MAX=`{MAX_PIXELS}`") + raise ValueError( + f"Image (w={image.width}, h={image.height}) > MAX=`{MAX_PIXELS}`" + ) img = image if image.mode == "RGB" else image.convert("RGB") arr = np.asarray(img) arr = _make_writeable(arr) @@ -576,17 +611,22 @@ def get_image_size_for_max_num_patches( patch_size (`int`): Size of the square patch used by the vision encoder. max_num_patches (`int`): - Upper bound on `(height / patch_size) * (width / patch_size)` after resizing. + Upper bound on `(height / patch_size) * (width / patch_size)` after + resizing. min_num_patches (`int`, *optional*): - Lower bound on the number of patches. When provided the image will be scaled up if necessary. + Lower bound on the number of patches. When provided the image will + be scaled up if necessary. eps (`float`, *optional*, defaults to 1e-5): - Convergence tolerance for the internal binary search to determing the target dimensions. + Convergence tolerance for the internal binary search to determine + the target dimensions. pixel_shuffle_scale (`int`, *optional*, defaults to 1): - Additional stride multiplier applied when pixel shuffle later reduces spatial resolution. + Additional stride multiplier applied when pixel shuffle later + reduces spatial resolution. Returns: - `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale` - and respect both the maximum and optional minimum patch-count constraints. + `tuple[int, int]`: Height and width (in pixels) that are multiples of + `patch_size * pixel_shuffle_scale` and respect both the maximum and + optional minimum patch-count constraints. """ def get_scaled_image_size(scale, original_size, patch_size, pixel_shuffle_scale): @@ -610,16 +650,24 @@ def get_image_size_for_max_num_patches( scale_min, scale_max = 1.0, 100.0 while (scale_max - scale_min) >= eps: scale = (scale_min + scale_max) / 2 - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + target_height = get_scaled_image_size( + scale, image_height, patch_size, pixel_shuffle_scale + ) + target_width = get_scaled_image_size( + scale, image_width, patch_size, pixel_shuffle_scale + ) num_patches = (target_height / patch_size) * (target_width / patch_size) if num_patches >= min_num_patches: scale_max = scale else: scale_min = scale scale = scale_max - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + target_height = get_scaled_image_size( + scale, image_height, patch_size, pixel_shuffle_scale + ) + target_width = get_scaled_image_size( + scale, image_width, patch_size, pixel_shuffle_scale + ) return target_height, target_width elif num_patches <= max_num_patches: return adjusted_height, adjusted_width @@ -628,16 +676,24 @@ def get_image_size_for_max_num_patches( scale_min, scale_max = eps / 10, 1.0 while (scale_max - scale_min) >= eps: scale = (scale_min + scale_max) / 2 - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + target_height = get_scaled_image_size( + scale, image_height, patch_size, pixel_shuffle_scale + ) + target_width = get_scaled_image_size( + scale, image_width, patch_size, pixel_shuffle_scale + ) num_patches = (target_height / patch_size) * (target_width / patch_size) if num_patches <= max_num_patches: scale_min = scale else: scale_max = scale scale = scale_min - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + target_height = get_scaled_image_size( + scale, image_height, patch_size, pixel_shuffle_scale + ) + target_width = get_scaled_image_size( + scale, image_width, patch_size, pixel_shuffle_scale + ) return target_height, target_width @@ -653,12 +709,13 @@ def prepare_image_tensor( Args: image (`torch.Tensor`): - Tensor with shape `(..., height, width, 3)` containing RGB values. The tensor is converted to floating - point if needed. + Tensor with shape `(..., height, width, 3)` containing RGB values. + The tensor is converted to floating point if needed. scale (`float`, *optional*, defaults to `VISION_SCALE`): Scalar multiplier applied before normalization. Returns: - `torch.Tensor`: Normalized tensor with the same shape as the input and dtype `torch.float32`. + `torch.Tensor`: Normalized tensor with the same shape as the input and + dtype `torch.float32`. """ if not torch.is_floating_point(image): image = image.float() @@ -683,17 +740,33 @@ def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor: Returns: `torch.Tensor`: - Patch tensor where each position stores the flattened pixels belonging to that patch. + Patch tensor where each position stores the flattened pixels + belonging to that patch. Raises: ValueError: If `height` or `width` is not divisible by `patch_size`. """ num_images, height, width, channels = image.shape if height % patch_size or width % patch_size: - raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.") - patches = image.reshape(num_images, height // patch_size, patch_size, width // patch_size, patch_size, channels) + raise ValueError( + "Dimensions of images " + f"{image.shape} are not divisible by patch_size={patch_size}." + ) + patches = image.reshape( + num_images, + height // patch_size, + patch_size, + width // patch_size, + patch_size, + channels, + ) patches = patches.permute(0, 1, 3, 2, 4, 5) - patches = patches.reshape(num_images, height // patch_size, width // patch_size, channels * patch_size * patch_size) + patches = patches.reshape( + num_images, + height // patch_size, + width // patch_size, + channels * patch_size * patch_size, + ) return patches @@ -708,21 +781,26 @@ def process_vision_for_patches( Args: images (`torch.Tensor`): - Either `(height, width, channels)` for a single image or `(num_images, height, width, channels)` for a - batch. Channels are expected to be RGB. + Either `(height, width, channels)` for a single image or + `(num_images, height, width, channels)` for a batch. Channels are + expected to be RGB. patch_size (`int`): Edge length of square patches; implictly controls resize grid granularity. max_num_patches (`int`): Maximum number of patches allowed after resizing. min_num_patches (`int`, *optional*): - Minimum number of patches. If provided, the routine upsamples images as needed to satisfy the lower bound. + Minimum number of patches. If provided, the routine upsamples images + as needed to satisfy the lower bound. pixel_shuffle_scale (`int`, *optional*, defaults to 1): - pixel shuffle scale factor; influences the target grid that the function produces. + Pixel shuffle scale factor; influences the target grid that the + function produces. Returns: - `tuple[torch.Tensor, list[int]]`: A pair `(patches, dims_virtual)` where `patches` has shape - `(num_images, target_h / patch_size, target_w / patch_size, channels * patch_size**2)` and `dims_virtual` - encodes effective `(images, height, width)` dimensions after optional pixel shuffling. + `tuple[torch.Tensor, list[int]]`: A pair `(patches, dims_virtual)` + where `patches` has shape `(num_images, target_h / patch_size, target_w + / patch_size, channels * patch_size**2)` and `dims_virtual` encodes + effective `(images, height, width)` dimensions after optional pixel + shuffling. """ # Add batch dim if single image if images.dim() == 3: @@ -788,7 +866,7 @@ class IsaacConfig(Qwen3Config): **kwargs, ): super().__init__(**kwargs) - + # EventStreamProcessor parameters (for backward compatibility) self.video_patch_size = vision_patch_size self.vision_max_num_patches = vision_max_num_patches @@ -814,7 +892,6 @@ class IsaacImageProcessorKwargs(TypedDict, total=False): class IsaacImageProcessor: - patch_size = 16 max_num_patches = 6144 min_num_patches = 256 @@ -825,14 +902,18 @@ class IsaacImageProcessor: def __init__(self, kwargs): self.patch_size = kwargs.pop("patch_size", self.patch_size) - self.vision_max_num_patches = kwargs.pop("vision_max_num_patches", self.max_num_patches) - self.vision_min_num_patches = kwargs.pop("vision_min_num_patches", self.min_num_patches) + self.vision_max_num_patches = kwargs.pop( + "vision_max_num_patches", self.max_num_patches + ) + self.vision_min_num_patches = kwargs.pop( + "vision_min_num_patches", self.min_num_patches + ) self.pixel_shuffle_scale = kwargs.pop("pixel_shuffle_scale", 2) def preprocess( self, images: list[torch.Tensor], - return_tensors: Optional[Union[str, TensorType]], + return_tensors: str | TensorType | None, **kwargs: Unpack[IsaacImageProcessorKwargs], ) -> BatchFeature: """Isaac's resize → normalize → patchify → pack.""" @@ -840,9 +921,9 @@ class IsaacImageProcessor: all_pixel_values: list[torch.Tensor] = [] all_image_grids: list[torch.Tensor] = [] - for image in images: + for image in images: image_tensor = extract_image_pil(image) - + patches, dims_virtual = process_vision_for_patches( image_tensor, patch_size=self.patch_size, @@ -874,7 +955,10 @@ class IsaacImageProcessor: final_image_grids = torch.empty(0, 3) return BatchFeature( - data={"pixel_values": final_pixel_values, "image_grid_thw": final_image_grids}, + data={ + "pixel_values": final_pixel_values, + "image_grid_thw": final_image_grids, + }, tensor_type=return_tensors, ) @@ -899,7 +983,7 @@ class IsaacProcessor: image_result = self.image_processor.preprocess(images, **kwargs) result.update(image_result) return BatchFeature(result) - + def apply_chat_template( self, messages: list[dict[str, Any]], @@ -909,7 +993,7 @@ class IsaacProcessor: ) -> Any: # Convert mixed content messages to simple text format processed_messages = [] - + for message in messages: if "content" in message and isinstance(message["content"], list): # Handle mixed content (text + image) @@ -920,23 +1004,25 @@ class IsaacProcessor: elif content_item.get("type") == "image": # Replace image with vision token text_parts.append(self.image_token) - + processed_message = { "role": message.get("role", "user"), - "content": "".join(text_parts) + "content": "".join(text_parts), } processed_messages.append(processed_message) else: # Regular text message processed_messages.append(message) - + return self.tokenizer.apply_chat_template( - processed_messages, tokenize=tokenize, add_generation_prompt=add_generation_prompt, **kwargs + processed_messages, + tokenize=tokenize, + add_generation_prompt=add_generation_prompt, + **kwargs, ) class IsaacProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> IsaacConfig: if hasattr(self.ctx, "get_hf_config"): original_config = self.ctx.get_hf_config() @@ -945,10 +1031,16 @@ class IsaacProcessingInfo(BaseProcessingInfo): # Vision parameters - map from HF names vision_config=getattr(original_config, "vision_config", None), vision_patch_size=getattr(original_config, "video_patch_size", 16), - vision_max_num_patches=getattr(original_config, "vision_max_num_patches", 256), - vision_min_num_patches=getattr(original_config, "vision_min_num_patches", None), + vision_max_num_patches=getattr( + original_config, "vision_max_num_patches", 256 + ), + vision_min_num_patches=getattr( + original_config, "vision_min_num_patches", None + ), pixel_shuffle_scale=getattr(original_config, "pixel_shuffle_scale", 1), - max_sequence_length=getattr(original_config, "max_sequence_length", 16384), + max_sequence_length=getattr( + original_config, "max_sequence_length", 16384 + ), vision_token="<|image_pad|>", ) return IsaacConfig() @@ -975,18 +1067,22 @@ class IsaacProcessingInfo(BaseProcessingInfo): def get_image_processor(self, **kwargs) -> IsaacImageProcessor: return self.get_hf_processor(**kwargs).image_processor - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_mm_max_tokens_per_item( - self, seq_len: int, mm_counts: Mapping[str, int], + self, + seq_len: int, + mm_counts: Mapping[str, int], ) -> Mapping[str, int]: hf_config = self.get_hf_config() - num_vision_tokens = hf_config.vision_max_num_patches // (hf_config.pixel_shuffle_scale**2) + num_vision_tokens = hf_config.vision_max_num_patches // ( + hf_config.pixel_shuffle_scale**2 + ) return {"image": num_vision_tokens} -class IsaacDummyInputsBuilder(BaseDummyInputsBuilder[IsaacProcessingInfo]): +class IsaacDummyInputsBuilder(BaseDummyInputsBuilder[IsaacProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -1017,19 +1113,19 @@ class IsaacDummyInputsBuilder(BaseDummyInputsBuilder[IsaacProcessingInfo]): class IsaacMultiModalProcessor(BaseMultiModalProcessor): - def _get_mm_fields_config( self, hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - # Configure multimodal fields for Isaac model image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) image_grid_sizes = image_grid_thw.prod(-1) return { - "pixel_values": MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes), + "pixel_values": MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes + ), "image_grid_thw": MultiModalFieldConfig.batched("image"), } @@ -1039,24 +1135,23 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor): hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, ) -> Sequence[PromptUpdate]: - - #hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + # hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() placeholder_id = vocab.get("<|image_pad|>", 151655) - - pixel_shuffle_scale = getattr(image_processor, 'pixel_shuffle_scale', 2) - merge_length = pixel_shuffle_scale ** 2 - + + pixel_shuffle_scale = getattr(image_processor, "pixel_shuffle_scale", 2) + merge_length = pixel_shuffle_scale**2 + def get_replacement_isaac(item_idx: int): out_item = out_mm_kwargs["image"][item_idx] grid_thw = out_item["image_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) num_tokens = int(grid_thw.prod()) // merge_length - return [placeholder_id] * num_tokens + return [placeholder_id] * num_tokens return [ PromptReplacement( @@ -1066,6 +1161,7 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor): ) ] + class Siglip2VisionTransformer(nn.Module): def __init__( self, @@ -1107,7 +1203,9 @@ class Siglip2VisionTransformer(nn.Module): # Get embeddings from packed sequence hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids)) - grid_thws = torch.tensor([[1, token_grids[0][0].item(), token_grids[0][1].item()]]) + grid_thws = torch.tensor( + [[1, token_grids[0][0].item(), token_grids[0][1].item()]] + ) last_hidden_state = self.encoder(hidden_states, grid_thws) hidden_states = self.post_layernorm(last_hidden_state) @@ -1123,7 +1221,7 @@ class Siglip2VisionTransformer(nn.Module): # Remove the pseudo batch dimension we added earlier hidden_states = hidden_states.squeeze(0) - #return last_hidden_state + # return last_hidden_state return hidden_states def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -1160,9 +1258,8 @@ class Siglip2VisionTransformer(nn.Module): dummy_inputs=IsaacDummyInputsBuilder, ) class IsaacForConditionalGeneration( - Qwen3ForCausalLM, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE + Qwen3ForCausalLM, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE ): - packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1176,11 +1273,11 @@ class IsaacForConditionalGeneration( } supports_encoder_tp_data = True - + # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ - "model.vision_embedding.": "vision_embedding.", + "model.vision_embedding.": "vision_embedding.", } ) @@ -1188,11 +1285,10 @@ class IsaacForConditionalGeneration( def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|image_pad|>" - + raise ValueError("Only image modality is supported") def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): - config: IsaacConfig = vllm_config.model_config.hf_config head_dim = config.head_dim @@ -1207,18 +1303,22 @@ class IsaacForConditionalGeneration( # Initialize the parent class with updated config super().__init__(vllm_config=vllm_config, prefix=prefix) - + # Create the language model module to match checkpoint structure - self.language_model = nn.ModuleDict({ - "embed_tokens": self.model.embed_tokens, - "layers": self.model.layers, - "norm": self.model.norm - }) - + self.language_model = nn.ModuleDict( + { + "embed_tokens": self.model.embed_tokens, + "layers": self.model.layers, + "norm": self.model.norm, + } + ) + config.vision_config.preserve_original_pe = True config.vision_config.use_rope = False - config.vision_config.hidden_stride = config.vision_config.pixel_shuffle_scale_factor - config.vision_config.window_size = 32*2 + config.vision_config.hidden_stride = ( + config.vision_config.pixel_shuffle_scale_factor + ) + config.vision_config.window_size = 32 * 2 config.vision_config.fullatt_block_indexes = None vision_cfg = config.vision_config if vision_cfg is None: @@ -1226,7 +1326,9 @@ class IsaacForConditionalGeneration( hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2) self.vision_embedding = nn.Sequential( - Siglip2VisionTransformer(vision_cfg, prefix=maybe_prefix(prefix, "vision_embedding")), + Siglip2VisionTransformer( + vision_cfg, prefix=maybe_prefix(prefix, "vision_embedding") + ), nn.Linear( hidden_dim, 4 * hidden_dim, @@ -1250,26 +1352,32 @@ class IsaacForConditionalGeneration( ) -> tuple[torch.Tensor, int]: """Get mrope input positions and delta value.""" - vision_token_id = getattr(self.config, 'image_token_id', 151655) + vision_token_id = getattr(self.config, "image_token_id", 151655) spatial_merge_size = hf_config.vision_config.pixel_shuffle_scale_factor input_tokens_tensor = torch.tensor(input_tokens) - + # Find image token positions - image_positions = torch.where(input_tokens_tensor == vision_token_id)[0].tolist() - - # For text-only inputs, use Isaac's original logic from compute_position_ids_input_ids() + image_positions = torch.where(input_tokens_tensor == vision_token_id)[ + 0 + ].tolist() + + # For text-only inputs, use Isaac's original logic from + # compute_position_ids_input_ids() if len(image_positions) == 0: seq_len = len(input_tokens) - # Create 3D positions where all dimensions get the same 1D temporal progression + # Create 3D positions where all dimensions get the same 1D temporal + # progression position_ids = torch.arange(seq_len, dtype=torch.long) position_ids = position_ids.view(1, -1).expand(1, -1) # [1, seq_len] - position_ids = position_ids.unsqueeze(2).expand(-1, -1, 3) # [1, seq_len, 3] + position_ids = position_ids.unsqueeze(2).expand( + -1, -1, 3 + ) # [1, seq_len, 3] # vLLM expects shape [3, seq_len], so transpose position_ids = position_ids.squeeze(0).transpose(0, 1) # [3, seq_len] - + return position_ids, 0 - + events = [] image_idx = 0 current_pos = 0 @@ -1278,7 +1386,7 @@ class IsaacForConditionalGeneration( for image_pos in image_positions: if image_pos <= last_processed_pos: continue # Skip already processed positions - + # Add any text before this image if image_pos > current_pos: text_tokens = image_pos - current_pos @@ -1288,21 +1396,23 @@ class IsaacForConditionalGeneration( idx_range=(0, text_tokens), ) events.append(text_event) - + # Add image t, h, w = image_grid_thw[image_idx] llm_grid_h, llm_grid_w = h // spatial_merge_size, w // spatial_merge_size image_tokens = t * llm_grid_h * llm_grid_w - + image_event = Event( modality_type=VisionType.image, dims_virtual=[t, llm_grid_h, llm_grid_w], idx_range=(0, image_tokens), ) events.append(image_event) - + current_pos = image_pos + image_tokens - last_processed_pos = current_pos - 1 # Mark up to this position as processed + last_processed_pos = ( + current_pos - 1 + ) # Mark up to this position as processed image_idx += 1 # Add final text segment if any @@ -1314,7 +1424,7 @@ class IsaacForConditionalGeneration( idx_range=(0, text_tokens), ) events.append(text_event) - + stream = Stream(events) tensor_stream = TensorStream([stream]) @@ -1334,8 +1444,7 @@ class IsaacForConditionalGeneration( def get_multimodal_embeddings( self, **kwargs: object - ) -> MultiModalEmbeddings | None: - + ) -> MultiModalEmbeddings | None: pixel_values = kwargs.get("pixel_values") image_grid_thw = kwargs.get("image_grid_thw") @@ -1343,15 +1452,21 @@ class IsaacForConditionalGeneration( return [] # Convert image_grid_thw from [batch, 1, [T, H, W]] to [batch, [H, W]] - spatial_grids = image_grid_thw[:, 0, 1:3] # Extract H, W from [T, H, W] for each image - + spatial_grids = image_grid_thw[ + :, 0, 1:3 + ] # Extract H, W from [T, H, W] for each image + # Process packed sequence patches through vision_embedding module vision_embeddings = self.vision_embedding((pixel_values, spatial_grids)) # Split concatenated embeddings for each image item (following Qwen2-VL pattern) - merge_size = self.config.vision_config.pixel_shuffle_scale_factor # Isaac uses pixel shuffle - sizes = spatial_grids.prod(-1) // (merge_size * merge_size) # H * W / (merge_size^2) - + merge_size = ( + self.config.vision_config.pixel_shuffle_scale_factor + ) # Isaac uses pixel shuffle + sizes = spatial_grids.prod(-1) // ( + merge_size * merge_size + ) # H * W / (merge_size^2) + return vision_embeddings.split(sizes.tolist()) def get_input_embeddings( @@ -1362,13 +1477,11 @@ class IsaacForConditionalGeneration( is_multimodal: torch.Tensor | None = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: - # Get text embeddings from the base language model inputs_embeds = super().get_input_embeddings(input_ids) - + # If we have multimodal embeddings, merge them with text embeddings if multimodal_embeddings is not None and len(multimodal_embeddings) != 0: - inputs_embeds = _merge_multimodal_embeddings( inputs_embeds=inputs_embeds, multimodal_embeddings=multimodal_embeddings, @@ -1379,7 +1492,7 @@ class IsaacForConditionalGeneration( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [] - + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)