vllm-project/vllm (mirror of https://git.datalinker.icu/vllm-project/vllm.git)

commit c4a6119925 (parent: c7c3853e9e)

    fix cr

    Signed-off-by: Yang <lymailforjob@gmail.com>
vllm/model_executor/models/isaac.py

@@ -4,7 +4,7 @@ from __future__ import annotations
 import math
 from collections.abc import Iterable, Iterator, Mapping, Sequence
-from typing import Any
+from typing import Annotated, Any

 import numpy as np
 import PIL.Image
@@ -12,9 +12,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import Qwen3Config
 from transformers.image_processing_utils import BatchFeature
-from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
 from transformers.tokenization_utils import TensorType
 from typing_extensions import TypedDict, Unpack

@@ -67,28 +65,11 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.tokenizers import get_tokenizer
 from vllm.tokenizers.hf import get_cached_tokenizer
-
-
-class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig):
-    """Vision configuration for Isaac with Pixel Shuffle support.
-
-    Extends Siglip2VisionConfig with additional fields for pixel shuffle.
-    """
-
-    model_type = "pixel_shuffle_siglip2"
-    base_config_key = "vision_config"
-
-    def __init__(
-        self,
-        pixel_shuffle_scale_factor: int = 1,
-        num_patches: int = 256,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        # Add our custom fields
-        self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor
-        self.num_patches = num_patches
+from vllm.transformers_utils.configs import (
+    IsaacConfig,
+    PixelShuffleSiglip2VisionConfig,
+)
+from vllm.utils.tensor_schema import TensorSchema, TensorShape


 def create_cumulative_seq_lengths(
@@ -629,58 +610,6 @@ def process_vision_for_patches(
     return patches, dims_virtual
-
-
-class IsaacConfig(Qwen3Config):
-    """Configuration class for Isaac multimodal model."""
-
-    model_type = "isaac"
-    sub_configs = {"vision_config": PixelShuffleSiglip2VisionConfig}
-
-    def __init__(
-        self,
-        vision_config=None,
-        vision_patch_size: int = 16,
-        vision_max_num_patches: int = 256,
-        vision_min_num_patches: int | None = None,
-        pixel_shuffle_scale: int = 1,
-        max_sequence_length: int = 16384,
-        vision_token: str = "<image>",
-        vision_attn_implementation: str | None = None,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        # EventStreamProcessor parameters (for backward compatibility)
-        self.video_patch_size = vision_patch_size
-        self.vision_max_num_patches = vision_max_num_patches
-        self.vision_min_num_patches = vision_min_num_patches
-        self.pixel_shuffle_scale = pixel_shuffle_scale
-
-        # Processing parameters
-        self.max_sequence_length = max_sequence_length
-        self.vision_token = vision_token
-
-        # Handle vision config - PixelShuffleSiglip2VisionConfig instance
-        if isinstance(vision_config, dict):
-            self.vision_config = PixelShuffleSiglip2VisionConfig(**vision_config)
-        elif vision_config is None:
-            self.vision_config = PixelShuffleSiglip2VisionConfig()
-        else:
-            self.vision_config = vision_config
-
-        # Ensure compatibility with pretrained checkpoints
-        self.vision_config.pixel_shuffle_scale_factor = getattr(
-            self.vision_config,
-            "pixel_shuffle_scale_factor",
-            pixel_shuffle_scale,
-        )
-        self.vision_config.num_patches = getattr(
-            self.vision_config,
-            "num_patches",
-            vision_max_num_patches,
-        )
-        self.vision_attn_implementation = vision_attn_implementation


 class IsaacImageProcessorKwargs(TypedDict, total=False):
     patch_size: int
     max_num_patches: int
@@ -914,6 +843,32 @@ class IsaacDummyInputsBuilder(BaseDummyInputsBuilder[IsaacProcessingInfo]):
         }
+
+
+class IsaacImagePixelInputs(TensorSchema):
+    """
+    Schema for validating Isaac image inputs.
+
+    Dimensions:
+        - np: Number of patches
+        - d: Patch dimension
+        - ni: Number of images
+
+    The schema enforces:
+        - pixel_values must be 2D: (num_patches, patch_dim)
+        - image_grid_thw must be 2D: (num_images, 3)
+          where 3 represents [T, H, W]
+    """
+
+    pixel_values: Annotated[
+        torch.Tensor,
+        TensorShape("np", "d"),
+    ]
+
+    image_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("ni", 3),
+    ]


 class IsaacMultiModalProcessor(BaseMultiModalProcessor):
     def _get_mm_fields_config(
         self,
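# --- Illustrative sketch (not part of the diff) ------------------------------
# Rough example of how the new schema behaves, assuming vLLM's TensorSchema
# validates the Annotated TensorShape fields at construction time (as the
# comment in _parse_and_validate_image_input below states). All shapes and
# values here are invented for illustration.
import torch

patches = torch.randn(64, 768)        # (np, d): 64 patches, 768 dims each
grid = torch.tensor([[1, 8, 8]])      # (ni, 3): one image with [T, H, W]

inputs = IsaacImagePixelInputs(pixel_values=patches, image_grid_thw=grid)
pixel_values = inputs["pixel_values"]  # dict-style access, as used in the model

# A 1-D pixel_values tensor, or a grid whose trailing dim is not 3, would fail
# the schema's shape validation here instead of surfacing deep in the vision tower.
# -----------------------------------------------------------------------------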
@@ -1423,19 +1378,21 @@ class IsaacForConditionalGeneration(

     def _parse_and_validate_image_input(
         self, **kwargs: object
-    ) -> dict[str, torch.Tensor] | None:
+    ) -> IsaacImagePixelInputs | None:
         pixel_values = kwargs.get("pixel_values")
         image_grid_thw = kwargs.get("image_grid_thw")
         if pixel_values is None or image_grid_thw is None:
             return None
-        return {
-            "pixel_values": pixel_values,
-            "image_grid_thw": image_grid_thw,
-        }
+
+        # TensorSchema will automatically validate shapes on initialization
+        return IsaacImagePixelInputs(
+            pixel_values=pixel_values,
+            image_grid_thw=image_grid_thw,
+        )

     def _process_image_input(
         self,
-        image_input: dict[str, torch.Tensor],
+        image_input: IsaacImagePixelInputs,
     ) -> tuple[torch.Tensor, ...]:
         pixel_values = image_input["pixel_values"]
         image_grid_thw = image_input["image_grid_thw"]
@@ -1445,8 +1402,6 @@ class IsaacForConditionalGeneration(
         device = next(self.language_model.parameters()).device
         dtype = self.vision_embedding.linear_fc1.weight.dtype
         pixel_values = pixel_values.to(device=device, dtype=dtype)
-        if image_grid_thw.dim() == 3:
-            image_grid_thw = image_grid_thw[0]
         spatial_grids = image_grid_thw[:, 1:3].to(device, dtype=torch.int32)

         vision_embeddings = self.vision_embedding((pixel_values, spatial_grids))
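# --- Illustrative sketch (not part of the diff) ------------------------------
# What the slicing above does, with invented values: each image row is
# [T, H, W]; columns 1:3 keep the spatial (H, W) pair that the vision
# embedding consumes alongside the flattened patches.
import torch

grid = torch.tensor([[1, 8, 8], [1, 4, 12]], dtype=torch.int32)
spatial = grid[:, 1:3]  # tensor([[ 8,  8],
                        #         [ 4, 12]], dtype=torch.int32)
# -----------------------------------------------------------------------------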
vllm/transformers_utils/config.py

@@ -72,6 +72,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
     deepseek_v32="DeepseekV3Config",
     flex_olmo="FlexOlmoConfig",
     hunyuan_vl="HunYuanVLConfig",
+    isaac="IsaacConfig",
     kimi_linear="KimiLinearConfig",
     kimi_vl="KimiVLConfig",
     RefinedWeb="RWConfig",  # For tiiuae/falcon-40b(-instruct)
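# --- Illustrative sketch (not part of the diff) ------------------------------
# Effect of the registry entry, in spirit: a checkpoint whose config.json
# carries "model_type": "isaac" now resolves to vLLM's own IsaacConfig rather
# than falling back to transformers' AutoConfig. The lookup below is a
# simplified stand-in for vLLM's actual resolution logic.
from vllm.transformers_utils.configs import IsaacConfig

model_type = "isaac"               # read from the checkpoint's config.json
registry = {"isaac": IsaacConfig}  # stand-in for _CONFIG_REGISTRY
config_cls = registry[model_type]
config = config_cls()              # then populated from the JSON contents
# -----------------------------------------------------------------------------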
vllm/transformers_utils/configs/__init__.py

@@ -25,6 +25,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
     "HunYuanVLConfig": "vllm.transformers_utils.configs.hunyuan_vl",
     "HunYuanVLTextConfig": "vllm.transformers_utils.configs.hunyuan_vl",
     "HunYuanVLVisionConfig": "vllm.transformers_utils.configs.hunyuan_vl",
+    "IsaacConfig": "vllm.transformers_utils.configs.isaac",
     # RWConfig is for the original tiiuae/falcon-40b(-instruct) and
     # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
     # `FalconConfig` class from the official HuggingFace transformers library.
@@ -41,6 +42,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
     "NemotronHConfig": "vllm.transformers_utils.configs.nemotron_h",
     "Olmo3Config": "vllm.transformers_utils.configs.olmo3",
     "OvisConfig": "vllm.transformers_utils.configs.ovis",
+    "PixelShuffleSiglip2VisionConfig": "vllm.transformers_utils.configs.isaac",
     "RadioConfig": "vllm.transformers_utils.configs.radio",
     "SpeculatorsConfig": "vllm.transformers_utils.configs.speculators.base",
     "UltravoxConfig": "vllm.transformers_utils.configs.ultravox",
@@ -65,6 +67,7 @@ __all__ = [
     "HunYuanVLConfig",
     "HunYuanVLTextConfig",
     "HunYuanVLVisionConfig",
+    "IsaacConfig",
     "RWConfig",
     "JAISConfig",
     "Lfm2MoeConfig",
@@ -78,6 +81,7 @@ __all__ = [
     "NemotronHConfig",
     "Olmo3Config",
     "OvisConfig",
+    "PixelShuffleSiglip2VisionConfig",
     "RadioConfig",
     "SpeculatorsConfig",
     "UltravoxConfig",
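# --- Illustrative sketch (not part of the diff) ------------------------------
# Why both classes are listed in _CLASS_TO_MODULE: the configs package maps
# class names to module paths so each config module is only imported when its
# class is first requested. The importlib lookup below illustrates the idea;
# it is not the package's actual lazy-import implementation.
import importlib

module = importlib.import_module("vllm.transformers_utils.configs.isaac")
IsaacConfig = getattr(module, "IsaacConfig")
# -----------------------------------------------------------------------------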
vllm/transformers_utils/configs/isaac.py (new file, 86 lines)

@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations
+
+from transformers import Qwen3Config
+from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
+
+
+class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig):
+    """Vision configuration for Isaac with Pixel Shuffle support.
+
+    Extends Siglip2VisionConfig with additional fields for pixel shuffle.
+    """
+
+    model_type = "pixel_shuffle_siglip2"
+    base_config_key = "vision_config"
+
+    def __init__(
+        self,
+        pixel_shuffle_scale_factor: int = 1,
+        num_patches: int = 256,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        # Add our custom fields
+        self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor
+        self.num_patches = num_patches
+
+
+class IsaacConfig(Qwen3Config):
+    """Configuration class for Isaac multimodal model."""
+
+    model_type = "isaac"
+    sub_configs = {"vision_config": PixelShuffleSiglip2VisionConfig}
+
+    def __init__(
+        self,
+        vision_config=None,
+        vision_patch_size: int = 16,
+        vision_max_num_patches: int = 256,
+        vision_min_num_patches: int | None = None,
+        pixel_shuffle_scale: int = 1,
+        max_sequence_length: int = 16384,
+        vision_token: str = "<image>",
+        vision_attn_implementation: str | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        # EventStreamProcessor parameters (for backward compatibility)
+        self.video_patch_size = vision_patch_size
+        self.vision_max_num_patches = vision_max_num_patches
+        self.vision_min_num_patches = vision_min_num_patches
+        self.pixel_shuffle_scale = pixel_shuffle_scale
+
+        # Processing parameters
+        self.max_sequence_length = max_sequence_length
+        self.vision_token = vision_token
+
+        # Handle vision config - PixelShuffleSiglip2VisionConfig instance
+        if isinstance(vision_config, dict):
+            self.vision_config = PixelShuffleSiglip2VisionConfig(**vision_config)
+        elif vision_config is None:
+            self.vision_config = PixelShuffleSiglip2VisionConfig()
+        else:
+            self.vision_config = vision_config
+
+        # Ensure compatibility with pretrained checkpoints
+        self.vision_config.pixel_shuffle_scale_factor = getattr(
+            self.vision_config,
+            "pixel_shuffle_scale_factor",
+            pixel_shuffle_scale,
+        )
+        self.vision_config.num_patches = getattr(
+            self.vision_config,
+            "num_patches",
+            vision_max_num_patches,
+        )
+        self.vision_attn_implementation = vision_attn_implementation
+
+
+__all__ = [
+    "IsaacConfig",
+    "PixelShuffleSiglip2VisionConfig",
+]
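# --- Illustrative sketch (not part of the diff) ------------------------------
# Minimal usage of the relocated config classes; all values are invented.
from vllm.transformers_utils.configs.isaac import (
    IsaacConfig,
    PixelShuffleSiglip2VisionConfig,
)

cfg = IsaacConfig(
    vision_config={"pixel_shuffle_scale_factor": 2, "num_patches": 1024},
    vision_patch_size=16,
    pixel_shuffle_scale=2,
)
# A dict vision_config is upcast to the typed sub-config, and the
# checkpoint-compatibility shim leaves explicitly provided fields intact.
assert isinstance(cfg.vision_config, PixelShuffleSiglip2VisionConfig)
assert cfg.vision_config.pixel_shuffle_scale_factor == 2
# -----------------------------------------------------------------------------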