org and add imports and fix lint error

Signed-off-by: Yang <lymailforjob@gmail.com>
This commit is contained in:
Yang 2025-11-20 15:39:18 -08:00 committed by Yang Liu
parent 0dbe093c56
commit 2c13695951

View File

@ -1,49 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from collections.abc import Mapping, Sequence, Iterable
from typing import Any, Optional, Union
from typing_extensions import TypedDict, Unpack
import itertools
from enum import Enum
from dataclasses import dataclass
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, Qwen3Config
from transformers.image_processing_utils import BatchFeature
from transformers.tokenization_utils import TensorType
from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
from transformers.tokenization_utils import TensorType
from typing_extensions import TypedDict, Unpack
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import (
WeightsMapper,
AutoWeightsLoader,
_merge_multimodal_embeddings,
maybe_prefix,
init_vllm_registered_model,
)
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
)
from vllm.multimodal.parse import MultiModalDataItems, ImageSize
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal.inputs import (
MultiModalFieldConfig,
MultiModalKwargs,
MultiModalDataDict,
)
from vllm.attention.backends.registry import _Backend
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
)
from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
@ -51,18 +34,34 @@ from vllm.model_executor.models.interfaces import (
SupportsMultiModal,
SupportsPP,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
from vllm.model_executor.models.siglip2navit import Siglip2Encoder
from vllm.attention.backends.registry import _Backend
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
WeightsMapper,
_merge_multimodal_embeddings,
maybe_prefix,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargs,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
# ===== TensorStream Compatibility Layer for Isaac MRoPE =====
# Minimal implementation of TensorStream classes needed for Isaac's 3D positional encoding
# Minimal implementation of TensorStream classes needed for Isaac's 3D positional
# encoding
class ModalityType(Enum):
"""
@ -127,37 +126,46 @@ class TextType(ModalityType):
@dataclass
class Event:
"""Represents a single modality event with spatial/temporal dimensions."""
"""
Represents a single data occurrence (with a specific type, time interval, and data payload).
Represents a single data occurrence (with a specific type, time interval, and
data payload).
Attributes:
data (Any): The actual data payload (e.g. a torch.Tensor, a string, etc.).
type (ModalityType): The modality type of the data (e.g., VisionType.image).
time (Tuple[float, float]): (start_time, end_time) indicating when this Event occurs.
role (Optional[str]): The role associated with this event (e.g., "user", "agent", "system").
If None, the event is always included in loss calculation.
data (Any): The actual data payload (e.g. a torch.Tensor, a string,
etc.).
type (ModalityType): The modality type of the data (e.g.,
VisionType.image).
time (Tuple[float, float]): (start_time, end_time) indicating when this
Event occurs.
role (Optional[str]): The role associated with this event (e.g., "user",
"agent", "system"). If None, the event is always included in loss
calculation.
Example usage:
evt = Event(data=torch.zeros((1, 224, 224, 3)), # e.g. a single image frame
type=VisionType.image,
time=(0.0, 0.04),
role="user")
"""
"""
# Descriptors
modality_type: ModalityType
# Structure
dims_virtual: list[int] | None = None # virtual/processed dimensions (e.g., pixel-shuffled)
dims_virtual: list[int] | None = (
None # virtual/processed dimensions (e.g., pixel-shuffled)
)
dims_real: list[int] | None = None # real/actual tensor dimensions
idx_range: tuple[int, int] | None = None
def dims(self, virtual: bool = True) -> list[int] | None:
"""
Get the dimensions of this event.
Args:
virtual: If True (default), return virtual/processed dimensions (e.g., pixel-shuffled).
If False, return real/actual tensor dimensions.
virtual: If True (default), return virtual/processed dimensions
(e.g., pixel-shuffled). If False, return real/actual tensor
dimensions.
Returns:
Dimensions list or None if not measured.
@ -171,7 +179,9 @@ class Event:
if not virtual:
assert partial is False and isinstance(self.data, torch.Tensor)
return math.prod(self.dims(virtual=False))
return self.idx_range[1] - self.idx_range[0] if partial else math.prod(self.dims())
return (
self.idx_range[1] - self.idx_range[0] if partial else math.prod(self.dims())
)
@dataclass
@ -215,7 +225,8 @@ class Stream:
yield from self.events
# TODO: implement all types of cool indexing which can happen since TensorStream assumes Event.data = Tensor
# TODO: implement all types of cool indexing which can happen since TensorStream
# assumes Event.data = Tensor
@dataclass
class TensorStream:
streams: list[Stream]
@ -254,7 +265,8 @@ def compute_mrope_pos_tensor(ts: TensorStream, n_pos_dims: int = 3) -> torch.Ten
cumulative_offset = 0 # running time index for this stream
for event in stream:
# --- build coordinate grid for THIS event using itertools (no tensor ops) ---
# --- build coordinate grid for THIS event using itertools
# (no tensor ops) ---
dims = (event.dims() or [1]) + [1] * (n_pos_dims - len(event.dims() or []))
# Create ranges for each dimension (similar to old _finalize implementation)
@ -274,26 +286,30 @@ def compute_mrope_pos_tensor(ts: TensorStream, n_pos_dims: int = 3) -> torch.Ten
# Convert to tensor and reshape to (B, T, n_pos_dims)
B, T = ts.shape
return torch.tensor(all_coords, dtype=torch.long, device=ts.device).reshape(B, T, n_pos_dims)
return torch.tensor(all_coords, dtype=torch.long, device=ts.device).reshape(
B, T, n_pos_dims
)
def modality_mask(ts: TensorStream, modality_type: ModalityType) -> torch.Tensor:
    """Create a boolean mask marking tokens of ``modality_type`` in the stream.

    Args:
        ts: Packed tensor stream; ``ts.shape`` is ``(B, T)`` and ``ts.streams``
            holds one iterable of events per batch item.
        modality_type: Modality to select, compared against each
            ``event.modality_type``.

    Returns:
        Bool tensor of shape ``(B, T)``, True at positions covered by events
        of the requested modality.
    """
    B, T = ts.shape
    mask = torch.zeros((B, T), dtype=torch.bool, device=ts.device)
    for batch_idx, stream in enumerate(ts.streams):
        seq_idx = 0  # running token offset within this stream
        for event in stream:
            start, end = event.idx_range
            if event.modality_type == modality_type:
                mask[batch_idx, seq_idx : seq_idx + (end - start)] = True
            # Advance past this event's tokens regardless of modality.
            seq_idx += end - start
    return mask
# ===== End TensorStream Compatibility Layer =====
class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig):
"""Vision configuration for Isaac with Pixel Shuffle support.
@ -338,7 +354,9 @@ class Siglip2VariableSequenceEmbeddings(nn.Module):
) -> torch.Tensor:
# Prepare positional embeddings grid: (1, embed_dim, h, w)
positional_embeddings = (
self.position_embedding.weight.reshape(self.position_embedding_size, self.position_embedding_size, -1)
self.position_embedding.weight.reshape(
self.position_embedding_size, self.position_embedding_size, -1
)
.permute(2, 0, 1)
.unsqueeze(0)
)
@ -359,12 +377,16 @@ class Siglip2VariableSequenceEmbeddings(nn.Module):
align_corners=align_corners,
antialias=antialias,
)
# Reshape from (1, embed_dim, height, width) to (height*width, embed_dim)
resized_pos_embed = resized_pos_embed.reshape(self.embed_dim, height * width).transpose(0, 1)
# Reshape from (1, embed_dim, height, width) to
# (height*width, embed_dim)
resized_pos_embed = resized_pos_embed.reshape(
self.embed_dim, height * width
).transpose(0, 1)
else:
# Fallback - should never happen in practice
resized_pos_embed = positional_embeddings.reshape(
self.embed_dim, self.position_embedding_size * self.position_embedding_size
self.embed_dim,
self.position_embedding_size * self.position_embedding_size,
).transpose(0, 1)[: height * width]
pos_embeds_list.append(resized_pos_embed)
@ -372,7 +394,9 @@ class Siglip2VariableSequenceEmbeddings(nn.Module):
pos_embeds = torch.cat(pos_embeds_list, dim=0)
return pos_embeds
def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
def forward(
self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
):
seq_patches, _seq_sizes, _spatial_shapes = packed_seq_patches
# Apply patch embeddings
@ -385,7 +409,9 @@ class Siglip2VariableSequenceEmbeddings(nn.Module):
# For variable-length attention, we need to reshape to (total_tokens, embed_dim)
if batch_size != 1:
raise ValueError("Variable-length attention expects batch_size=1 for packed sequences")
raise ValueError(
"Variable-length attention expects batch_size=1 for packed sequences"
)
patch_embeds = patch_embeds.view(batch_size * patches_per_image, embed_dim)
@ -427,11 +453,13 @@ def create_pixel_shuffle_index_map(
# Safety: all spatial dims must be divisible by r
# Cannot run under torch compile fullgraph mode hence
if not torch.compiler.is_compiling():
if not ((token_grids[:, 0] % r == 0).all() and (token_grids[:, 1] % r == 0).all()):
raise AssertionError(
f"Every (H,W) in `token_grids` must be divisible by scale_factor={r}, got {token_grids.tolist()}"
)
if not torch.compiler.is_compiling() and not (
(token_grids[:, 0] % r == 0).all() and (token_grids[:, 1] % r == 0).all()
):
raise AssertionError(
"Every (H,W) in `token_grids` must be divisible by "
f"scale_factor={r}, got {token_grids.tolist()}"
)
gather_chunks: list[torch.Tensor] = []
tok_offset = 0
@ -467,19 +495,23 @@ def pixel_shuffle_varlen(
Args:
x (`torch.Tensor`):
Concatenated vision embeddings. Accepts `(seq_len, hidden_size)` or `(1, seq_len, hidden_size)` shapes
produced by stacking image patches.
Concatenated vision embeddings. Accepts `(seq_len, hidden_size)` or
`(1, seq_len, hidden_size)` shapes produced by stacking image
patches.
token_grids (`torch.Tensor`):
Integer tensor of shape `(num_images, 2)` whose rows give the `(height, width)` patch grid sizes
corresponding to each image segment inside `x`.
Integer tensor of shape `(num_images, 2)` whose rows give the
`(height, width)` patch grid sizes corresponding to each image
segment inside `x`.
scale_factor (`int`, *optional*, defaults to 1):
Spatial down-sampling factor specific to pixel shuffle. Values greater than one merge `scale_factor**2` neighboring patches into a
Spatial down-sampling factor specific to pixel shuffle. Values
greater than one merge `scale_factor**2` neighboring patches into a
single embedding channel-group.
Returns:
`torch.Tensor`: Pixel-shuffled embeddings with shape matching the input convention:
`(seq_len, hidden_size * scale_factor**2)` when the input was 2D, or `(1, seq_len, hidden_size * scale_factor**2)`
if the singleton batch dimension was present.
`torch.Tensor`: Pixel-shuffled embeddings with shape matching the input
convention: `(seq_len, hidden_size * scale_factor**2)` when the input
was 2D, or `(1, seq_len, hidden_size * scale_factor**2)` if the
singleton batch dimension was present.
Raises:
ValueError: If more than one batch item is provided.
@ -517,6 +549,7 @@ def pixel_shuffle_varlen(
out = out.unsqueeze(0)
return out
# ============================================================================
# Configuration
# ============================================================================
@ -550,7 +583,9 @@ def _make_writeable(arr: np.ndarray) -> np.ndarray:
def extract_image_pil(image: PIL.Image.Image) -> torch.Tensor | None:
if image.width * image.height > MAX_PIXELS:
raise ValueError(f"Image (w={image.width}, h={image.height}) > MAX=`{MAX_PIXELS}`")
raise ValueError(
f"Image (w={image.width}, h={image.height}) > MAX=`{MAX_PIXELS}`"
)
img = image if image.mode == "RGB" else image.convert("RGB")
arr = np.asarray(img)
arr = _make_writeable(arr)
@ -576,17 +611,22 @@ def get_image_size_for_max_num_patches(
patch_size (`int`):
Size of the square patch used by the vision encoder.
max_num_patches (`int`):
Upper bound on `(height / patch_size) * (width / patch_size)` after resizing.
Upper bound on `(height / patch_size) * (width / patch_size)` after
resizing.
min_num_patches (`int`, *optional*):
Lower bound on the number of patches. When provided the image will be scaled up if necessary.
Lower bound on the number of patches. When provided the image will
be scaled up if necessary.
eps (`float`, *optional*, defaults to 1e-5):
Convergence tolerance for the internal binary search to determine the target dimensions.
Convergence tolerance for the internal binary search to determine
the target dimensions.
pixel_shuffle_scale (`int`, *optional*, defaults to 1):
Additional stride multiplier applied when pixel shuffle later reduces spatial resolution.
Additional stride multiplier applied when pixel shuffle later
reduces spatial resolution.
Returns:
`tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale`
and respect both the maximum and optional minimum patch-count constraints.
`tuple[int, int]`: Height and width (in pixels) that are multiples of
`patch_size * pixel_shuffle_scale` and respect both the maximum and
optional minimum patch-count constraints.
"""
def get_scaled_image_size(scale, original_size, patch_size, pixel_shuffle_scale):
@ -610,16 +650,24 @@ def get_image_size_for_max_num_patches(
scale_min, scale_max = 1.0, 100.0
while (scale_max - scale_min) >= eps:
scale = (scale_min + scale_max) / 2
target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
target_height = get_scaled_image_size(
scale, image_height, patch_size, pixel_shuffle_scale
)
target_width = get_scaled_image_size(
scale, image_width, patch_size, pixel_shuffle_scale
)
num_patches = (target_height / patch_size) * (target_width / patch_size)
if num_patches >= min_num_patches:
scale_max = scale
else:
scale_min = scale
scale = scale_max
target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
target_height = get_scaled_image_size(
scale, image_height, patch_size, pixel_shuffle_scale
)
target_width = get_scaled_image_size(
scale, image_width, patch_size, pixel_shuffle_scale
)
return target_height, target_width
elif num_patches <= max_num_patches:
return adjusted_height, adjusted_width
@ -628,16 +676,24 @@ def get_image_size_for_max_num_patches(
scale_min, scale_max = eps / 10, 1.0
while (scale_max - scale_min) >= eps:
scale = (scale_min + scale_max) / 2
target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
target_height = get_scaled_image_size(
scale, image_height, patch_size, pixel_shuffle_scale
)
target_width = get_scaled_image_size(
scale, image_width, patch_size, pixel_shuffle_scale
)
num_patches = (target_height / patch_size) * (target_width / patch_size)
if num_patches <= max_num_patches:
scale_min = scale
else:
scale_max = scale
scale = scale_min
target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
target_height = get_scaled_image_size(
scale, image_height, patch_size, pixel_shuffle_scale
)
target_width = get_scaled_image_size(
scale, image_width, patch_size, pixel_shuffle_scale
)
return target_height, target_width
@ -653,12 +709,13 @@ def prepare_image_tensor(
Args:
image (`torch.Tensor`):
Tensor with shape `(..., height, width, 3)` containing RGB values. The tensor is converted to floating
point if needed.
Tensor with shape `(..., height, width, 3)` containing RGB values.
The tensor is converted to floating point if needed.
scale (`float`, *optional*, defaults to `VISION_SCALE`):
Scalar multiplier applied before normalization.
Returns:
`torch.Tensor`: Normalized tensor with the same shape as the input and dtype `torch.float32`.
`torch.Tensor`: Normalized tensor with the same shape as the input and
dtype `torch.float32`.
"""
if not torch.is_floating_point(image):
image = image.float()
@ -683,17 +740,33 @@ def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor:
Returns:
`torch.Tensor`:
Patch tensor where each position stores the flattened pixels belonging to that patch.
Patch tensor where each position stores the flattened pixels
belonging to that patch.
Raises:
ValueError: If `height` or `width` is not divisible by `patch_size`.
"""
num_images, height, width, channels = image.shape
if height % patch_size or width % patch_size:
raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.")
patches = image.reshape(num_images, height // patch_size, patch_size, width // patch_size, patch_size, channels)
raise ValueError(
"Dimensions of images "
f"{image.shape} are not divisible by patch_size={patch_size}."
)
patches = image.reshape(
num_images,
height // patch_size,
patch_size,
width // patch_size,
patch_size,
channels,
)
patches = patches.permute(0, 1, 3, 2, 4, 5)
patches = patches.reshape(num_images, height // patch_size, width // patch_size, channels * patch_size * patch_size)
patches = patches.reshape(
num_images,
height // patch_size,
width // patch_size,
channels * patch_size * patch_size,
)
return patches
@ -708,21 +781,26 @@ def process_vision_for_patches(
Args:
images (`torch.Tensor`):
Either `(height, width, channels)` for a single image or `(num_images, height, width, channels)` for a
batch. Channels are expected to be RGB.
Either `(height, width, channels)` for a single image or
`(num_images, height, width, channels)` for a batch. Channels are
expected to be RGB.
patch_size (`int`):
Edge length of square patches; implicitly controls resize grid granularity.
max_num_patches (`int`):
Maximum number of patches allowed after resizing.
min_num_patches (`int`, *optional*):
Minimum number of patches. If provided, the routine upsamples images as needed to satisfy the lower bound.
Minimum number of patches. If provided, the routine upsamples images
as needed to satisfy the lower bound.
pixel_shuffle_scale (`int`, *optional*, defaults to 1):
pixel shuffle scale factor; influences the target grid that the function produces.
Pixel shuffle scale factor; influences the target grid that the
function produces.
Returns:
`tuple[torch.Tensor, list[int]]`: A pair `(patches, dims_virtual)` where `patches` has shape
`(num_images, target_h / patch_size, target_w / patch_size, channels * patch_size**2)` and `dims_virtual`
encodes effective `(images, height, width)` dimensions after optional pixel shuffling.
`tuple[torch.Tensor, list[int]]`: A pair `(patches, dims_virtual)`
where `patches` has shape `(num_images, target_h / patch_size, target_w
/ patch_size, channels * patch_size**2)` and `dims_virtual` encodes
effective `(images, height, width)` dimensions after optional pixel
shuffling.
"""
# Add batch dim if single image
if images.dim() == 3:
@ -788,7 +866,7 @@ class IsaacConfig(Qwen3Config):
**kwargs,
):
super().__init__(**kwargs)
# EventStreamProcessor parameters (for backward compatibility)
self.video_patch_size = vision_patch_size
self.vision_max_num_patches = vision_max_num_patches
@ -814,7 +892,6 @@ class IsaacImageProcessorKwargs(TypedDict, total=False):
class IsaacImageProcessor:
patch_size = 16
max_num_patches = 6144
min_num_patches = 256
@ -825,14 +902,18 @@ class IsaacImageProcessor:
def __init__(self, kwargs):
    """Consume Isaac image-processing options from ``kwargs``.

    Each option is popped exactly once, falling back to the class-level
    default when absent. Popping twice would return the default on the
    second call and silently overwrite the user's value.
    """
    self.patch_size = kwargs.pop("patch_size", self.patch_size)
    self.vision_max_num_patches = kwargs.pop(
        "vision_max_num_patches", self.max_num_patches
    )
    self.vision_min_num_patches = kwargs.pop(
        "vision_min_num_patches", self.min_num_patches
    )
    self.pixel_shuffle_scale = kwargs.pop("pixel_shuffle_scale", 2)
def preprocess(
self,
images: list[torch.Tensor],
return_tensors: Optional[Union[str, TensorType]],
return_tensors: str | TensorType | None,
**kwargs: Unpack[IsaacImageProcessorKwargs],
) -> BatchFeature:
"""Isaac's resize → normalize → patchify → pack."""
@ -840,9 +921,9 @@ class IsaacImageProcessor:
all_pixel_values: list[torch.Tensor] = []
all_image_grids: list[torch.Tensor] = []
for image in images:
for image in images:
image_tensor = extract_image_pil(image)
patches, dims_virtual = process_vision_for_patches(
image_tensor,
patch_size=self.patch_size,
@ -874,7 +955,10 @@ class IsaacImageProcessor:
final_image_grids = torch.empty(0, 3)
return BatchFeature(
data={"pixel_values": final_pixel_values, "image_grid_thw": final_image_grids},
data={
"pixel_values": final_pixel_values,
"image_grid_thw": final_image_grids,
},
tensor_type=return_tensors,
)
@ -899,7 +983,7 @@ class IsaacProcessor:
image_result = self.image_processor.preprocess(images, **kwargs)
result.update(image_result)
return BatchFeature(result)
def apply_chat_template(
self,
messages: list[dict[str, Any]],
@ -909,7 +993,7 @@ class IsaacProcessor:
) -> Any:
# Convert mixed content messages to simple text format
processed_messages = []
for message in messages:
if "content" in message and isinstance(message["content"], list):
# Handle mixed content (text + image)
@ -920,23 +1004,25 @@ class IsaacProcessor:
elif content_item.get("type") == "image":
# Replace image with vision token
text_parts.append(self.image_token)
processed_message = {
"role": message.get("role", "user"),
"content": "".join(text_parts)
"content": "".join(text_parts),
}
processed_messages.append(processed_message)
else:
# Regular text message
processed_messages.append(message)
return self.tokenizer.apply_chat_template(
processed_messages, tokenize=tokenize, add_generation_prompt=add_generation_prompt, **kwargs
processed_messages,
tokenize=tokenize,
add_generation_prompt=add_generation_prompt,
**kwargs,
)
class IsaacProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> IsaacConfig:
if hasattr(self.ctx, "get_hf_config"):
original_config = self.ctx.get_hf_config()
@ -945,10 +1031,16 @@ class IsaacProcessingInfo(BaseProcessingInfo):
# Vision parameters - map from HF names
vision_config=getattr(original_config, "vision_config", None),
vision_patch_size=getattr(original_config, "video_patch_size", 16),
vision_max_num_patches=getattr(original_config, "vision_max_num_patches", 256),
vision_min_num_patches=getattr(original_config, "vision_min_num_patches", None),
vision_max_num_patches=getattr(
original_config, "vision_max_num_patches", 256
),
vision_min_num_patches=getattr(
original_config, "vision_min_num_patches", None
),
pixel_shuffle_scale=getattr(original_config, "pixel_shuffle_scale", 1),
max_sequence_length=getattr(original_config, "max_sequence_length", 16384),
max_sequence_length=getattr(
original_config, "max_sequence_length", 16384
),
vision_token="<|image_pad|>",
)
return IsaacConfig()
@ -975,18 +1067,22 @@ class IsaacProcessingInfo(BaseProcessingInfo):
def get_image_processor(self, **kwargs) -> IsaacImageProcessor:
    """Return the image processor attached to the HF processor for these kwargs."""
    hf_processor = self.get_hf_processor(**kwargs)
    return hf_processor.image_processor
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
    """Images are supported with no fixed per-prompt limit (``None`` = unbounded)."""
    return {"image": None}
def get_mm_max_tokens_per_item(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
    """Upper bound on tokens produced by one image.

    Pixel shuffle merges ``pixel_shuffle_scale**2`` patches into one token,
    so the bound is the max patch count divided by that merge factor.
    """
    hf_config = self.get_hf_config()
    num_vision_tokens = hf_config.vision_max_num_patches // (
        hf_config.pixel_shuffle_scale**2
    )
    return {"image": num_vision_tokens}
class IsaacDummyInputsBuilder(BaseDummyInputsBuilder[IsaacProcessingInfo]):
class IsaacDummyInputsBuilder(BaseDummyInputsBuilder[IsaacProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
@ -1017,19 +1113,19 @@ class IsaacDummyInputsBuilder(BaseDummyInputsBuilder[IsaacProcessingInfo]):
class IsaacMultiModalProcessor(BaseMultiModalProcessor):
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    """Describe how Isaac's multimodal tensors map onto per-image items.

    ``pixel_values`` is a flat concatenation split by each image's
    ``T*H*W`` grid size; ``image_grid_thw`` is batched one row per image.
    """
    image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
    image_grid_sizes = image_grid_thw.prod(-1)
    return {
        "pixel_values": MultiModalFieldConfig.flat_from_sizes(
            "image", image_grid_sizes
        ),
        "image_grid_thw": MultiModalFieldConfig.batched("image"),
    }
@ -1039,24 +1135,23 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
#hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
# hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
placeholder_id = vocab.get("<|image_pad|>", 151655)
pixel_shuffle_scale = getattr(image_processor, 'pixel_shuffle_scale', 2)
merge_length = pixel_shuffle_scale ** 2
pixel_shuffle_scale = getattr(image_processor, "pixel_shuffle_scale", 2)
merge_length = pixel_shuffle_scale**2
def get_replacement_isaac(item_idx: int):
out_item = out_mm_kwargs["image"][item_idx]
grid_thw = out_item["image_grid_thw"].data
assert isinstance(grid_thw, torch.Tensor)
num_tokens = int(grid_thw.prod()) // merge_length
return [placeholder_id] * num_tokens
return [placeholder_id] * num_tokens
return [
PromptReplacement(
@ -1066,6 +1161,7 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor):
)
]
class Siglip2VisionTransformer(nn.Module):
def __init__(
self,
@ -1107,7 +1203,9 @@ class Siglip2VisionTransformer(nn.Module):
# Get embeddings from packed sequence
hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids))
grid_thws = torch.tensor([[1, token_grids[0][0].item(), token_grids[0][1].item()]])
grid_thws = torch.tensor(
[[1, token_grids[0][0].item(), token_grids[0][1].item()]]
)
last_hidden_state = self.encoder(hidden_states, grid_thws)
hidden_states = self.post_layernorm(last_hidden_state)
@ -1123,7 +1221,7 @@ class Siglip2VisionTransformer(nn.Module):
# Remove the pseudo batch dimension we added earlier
hidden_states = hidden_states.squeeze(0)
#return last_hidden_state
# return last_hidden_state
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
@ -1160,9 +1258,8 @@ class Siglip2VisionTransformer(nn.Module):
dummy_inputs=IsaacDummyInputsBuilder,
)
class IsaacForConditionalGeneration(
Qwen3ForCausalLM, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
Qwen3ForCausalLM, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -1176,11 +1273,11 @@ class IsaacForConditionalGeneration(
}
supports_encoder_tp_data = True
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.vision_embedding.": "vision_embedding.",
"model.vision_embedding.": "vision_embedding.",
}
)
@ -1188,11 +1285,10 @@ class IsaacForConditionalGeneration(
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
    """Return the placeholder token for the given modality (image only)."""
    if not modality.startswith("image"):
        raise ValueError("Only image modality is supported")
    return "<|image_pad|>"
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
config: IsaacConfig = vllm_config.model_config.hf_config
head_dim = config.head_dim
@ -1207,18 +1303,22 @@ class IsaacForConditionalGeneration(
# Initialize the parent class with updated config
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Create the language model module to match checkpoint structure
self.language_model = nn.ModuleDict({
"embed_tokens": self.model.embed_tokens,
"layers": self.model.layers,
"norm": self.model.norm
})
self.language_model = nn.ModuleDict(
{
"embed_tokens": self.model.embed_tokens,
"layers": self.model.layers,
"norm": self.model.norm,
}
)
config.vision_config.preserve_original_pe = True
config.vision_config.use_rope = False
config.vision_config.hidden_stride = config.vision_config.pixel_shuffle_scale_factor
config.vision_config.window_size = 32*2
config.vision_config.hidden_stride = (
config.vision_config.pixel_shuffle_scale_factor
)
config.vision_config.window_size = 32 * 2
config.vision_config.fullatt_block_indexes = None
vision_cfg = config.vision_config
if vision_cfg is None:
@ -1226,7 +1326,9 @@ class IsaacForConditionalGeneration(
hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2)
self.vision_embedding = nn.Sequential(
Siglip2VisionTransformer(vision_cfg, prefix=maybe_prefix(prefix, "vision_embedding")),
Siglip2VisionTransformer(
vision_cfg, prefix=maybe_prefix(prefix, "vision_embedding")
),
nn.Linear(
hidden_dim,
4 * hidden_dim,
@ -1250,26 +1352,32 @@ class IsaacForConditionalGeneration(
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""
vision_token_id = getattr(self.config, 'image_token_id', 151655)
vision_token_id = getattr(self.config, "image_token_id", 151655)
spatial_merge_size = hf_config.vision_config.pixel_shuffle_scale_factor
input_tokens_tensor = torch.tensor(input_tokens)
# Find image token positions
image_positions = torch.where(input_tokens_tensor == vision_token_id)[0].tolist()
# For text-only inputs, use Isaac's original logic from compute_position_ids_input_ids()
image_positions = torch.where(input_tokens_tensor == vision_token_id)[
0
].tolist()
# For text-only inputs, use Isaac's original logic from
# compute_position_ids_input_ids()
if len(image_positions) == 0:
seq_len = len(input_tokens)
# Create 3D positions where all dimensions get the same 1D temporal progression
# Create 3D positions where all dimensions get the same 1D temporal
# progression
position_ids = torch.arange(seq_len, dtype=torch.long)
position_ids = position_ids.view(1, -1).expand(1, -1) # [1, seq_len]
position_ids = position_ids.unsqueeze(2).expand(-1, -1, 3) # [1, seq_len, 3]
position_ids = position_ids.unsqueeze(2).expand(
-1, -1, 3
) # [1, seq_len, 3]
# vLLM expects shape [3, seq_len], so transpose
position_ids = position_ids.squeeze(0).transpose(0, 1) # [3, seq_len]
return position_ids, 0
events = []
image_idx = 0
current_pos = 0
@ -1278,7 +1386,7 @@ class IsaacForConditionalGeneration(
for image_pos in image_positions:
if image_pos <= last_processed_pos:
continue # Skip already processed positions
# Add any text before this image
if image_pos > current_pos:
text_tokens = image_pos - current_pos
@ -1288,21 +1396,23 @@ class IsaacForConditionalGeneration(
idx_range=(0, text_tokens),
)
events.append(text_event)
# Add image
t, h, w = image_grid_thw[image_idx]
llm_grid_h, llm_grid_w = h // spatial_merge_size, w // spatial_merge_size
image_tokens = t * llm_grid_h * llm_grid_w
image_event = Event(
modality_type=VisionType.image,
dims_virtual=[t, llm_grid_h, llm_grid_w],
idx_range=(0, image_tokens),
)
events.append(image_event)
current_pos = image_pos + image_tokens
last_processed_pos = current_pos - 1 # Mark up to this position as processed
last_processed_pos = (
current_pos - 1
) # Mark up to this position as processed
image_idx += 1
# Add final text segment if any
@ -1314,7 +1424,7 @@ class IsaacForConditionalGeneration(
idx_range=(0, text_tokens),
)
events.append(text_event)
stream = Stream(events)
tensor_stream = TensorStream([stream])
@ -1334,8 +1444,7 @@ class IsaacForConditionalGeneration(
def get_multimodal_embeddings(
self, **kwargs: object
) -> MultiModalEmbeddings | None:
) -> MultiModalEmbeddings | None:
pixel_values = kwargs.get("pixel_values")
image_grid_thw = kwargs.get("image_grid_thw")
@ -1343,15 +1452,21 @@ class IsaacForConditionalGeneration(
return []
# Convert image_grid_thw from [batch, 1, [T, H, W]] to [batch, [H, W]]
spatial_grids = image_grid_thw[:, 0, 1:3] # Extract H, W from [T, H, W] for each image
spatial_grids = image_grid_thw[
:, 0, 1:3
] # Extract H, W from [T, H, W] for each image
# Process packed sequence patches through vision_embedding module
vision_embeddings = self.vision_embedding((pixel_values, spatial_grids))
# Split concatenated embeddings for each image item (following Qwen2-VL pattern)
merge_size = self.config.vision_config.pixel_shuffle_scale_factor # Isaac uses pixel shuffle
sizes = spatial_grids.prod(-1) // (merge_size * merge_size) # H * W / (merge_size^2)
merge_size = (
self.config.vision_config.pixel_shuffle_scale_factor
) # Isaac uses pixel shuffle
sizes = spatial_grids.prod(-1) // (
merge_size * merge_size
) # H * W / (merge_size^2)
return vision_embeddings.split(sizes.tolist())
def get_input_embeddings(
@ -1362,13 +1477,11 @@ class IsaacForConditionalGeneration(
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# Get text embeddings from the base language model
inputs_embeds = super().get_input_embeddings(input_ids)
# If we have multimodal embeddings, merge them with text embeddings
if multimodal_embeddings is not None and len(multimodal_embeddings) != 0:
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
@ -1379,7 +1492,7 @@ class IsaacForConditionalGeneration(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights via AutoWeightsLoader, applying the HF->vLLM mapper."""
    loader = AutoWeightsLoader(self, skip_prefixes=[])
    return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)