Migrate Qwen2 inputs to TensorSchema (#23475)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: Benji Beck
Date: 2025-09-06 20:07:31 -07:00 (committed by GitHub)
Commit: 37a6fa95fd (parent 558f0907dc)
4 changed files with 268 additions and 175 deletions

File 1/4: Qwen2.5-Omni thinker model (Qwen2_5OmniThinkerForConditionalGeneration)

@@ -25,7 +25,7 @@
 from collections.abc import Iterable, Mapping, Sequence
 from copy import copy
 from functools import partial
-from typing import Any, Callable, Optional, Union
+from typing import Annotated, Any, Callable, Literal, Optional, Union

 import torch
 import torch.nn as nn
@@ -41,15 +41,13 @@ from transformers.models.whisper import WhisperFeatureExtractor
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
-from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.qwen2_5_vl import (
     Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs,
     Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs,
     Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
     Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
 from vllm.model_executor.models.qwen2_audio import (
-    Qwen2AudioFeatureInputs, Qwen2AudioProcessingInfo,
-    _get_feat_extract_output_lengths)
+    Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths)
 from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -66,9 +64,9 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

-from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsPP)
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
@@ -81,6 +79,26 @@ except (ImportError, ModuleNotFoundError):

 logger = init_logger(__name__)


+class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
+    """
+    Dimensions:
+        - na: Number of audios
+        - nmb: Number of mel bins
+        - msl: Maximum sequence length
+        - tsl: Total sequence length
+    """
+    type: Literal["audio_features"]
+
+    input_features: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("nmb", "tsl"),
+    ]
+
+    feature_attention_mask: Annotated[
+        torch.Tensor,
+        TensorShape("na", "msl"),
+    ]
+
+
 def create_qwen2_5_omni_thinker_field_factory(
     spatial_merge_size: int
 ) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str,
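Note: a rough usage sketch of the new schema, not part of the diff. Sizes are invented, and it assumes TensorSchema checks each field against its TensorShape annotation when the object is built. It mirrors the constructor call in _parse_and_validate_audio_input further down, including the audio_feature_lengths argument that is passed there even though that field sits outside this hunk:

import torch

# Two audio clips, 128 mel bins, per-clip lengths 300 and 200 frames
# (total sequence length 500). All sizes are illustrative.
audio_input = Qwen2_5OmniAudioFeatureInputs(
    type="audio_features",
    input_features=torch.zeros(128, 500),            # ("nmb", "tsl")
    audio_feature_lengths=torch.tensor([300, 200]),
    feature_attention_mask=torch.ones(2, 300),       # ("na", "msl")
)
# A mask with the wrong rank, e.g. torch.ones(2, 300, 1), should now be
# rejected by the schema instead of surfacing later as a shape error.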
@@ -536,7 +554,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
         return torch.concat(mm_input, dim=dim)

     def _parse_and_validate_audio_input(
-            self, **kwargs: object) -> Optional[Qwen2AudioFeatureInputs]:
+            self, **kwargs: object) -> Optional[Qwen2_5OmniAudioFeatureInputs]:
         input_audio_features = kwargs.pop('input_audio_features', None)
         audio_feature_lengths = kwargs.pop('audio_feature_lengths', None)
         feature_attention_mask = kwargs.pop('feature_attention_mask', None)
@@ -550,7 +568,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
         if not isinstance(input_audio_features, (torch.Tensor, list)):
             raise ValueError("Incorrect type of audio input features. "
                              f"Got type: {type(input_audio_features)}")
-        return Qwen2AudioFeatureInputs(
+        return Qwen2_5OmniAudioFeatureInputs(
+            type="audio_features",
             input_features=input_audio_features,
             audio_feature_lengths=audio_feature_lengths,
             feature_attention_mask=feature_attention_mask)
@@ -633,7 +652,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
     def _process_audio_input(
         self,
-        audio_input: Qwen2AudioFeatureInputs,
+        audio_input: Qwen2_5OmniAudioFeatureInputs,
         audio_hashes: list[str] = None,
         cached_audio_features: torch.Tensor = None,
     ) -> torch.Tensor:
@@ -660,8 +679,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
             feature_lens=audio_feature_lengths,
             aftercnn_lens=audio_feat_lengths,
         )
-        audio_features = audio_outputs.last_hidden_state
-        return audio_features.split(audio_output_lengths.tolist())
+        return audio_outputs.last_hidden_state.split(
+            audio_output_lengths.tolist())

     def _process_image_input(
         self,
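Note: the rewritten return is behaviourally the same as before. torch.Tensor.split with a list of sizes cuts the packed hidden states along dim 0 into one chunk per audio clip; a standalone illustration with invented sizes:

import torch

packed = torch.randn(500, 1280)      # e.g. two clips, 300 + 200 frames
chunks = packed.split([300, 200])    # -> (300, 1280) and (200, 1280)
assert [c.shape[0] for c in chunks] == [300, 200]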
@@ -707,7 +726,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
     dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
 )
 class Qwen2_5OmniThinkerForConditionalGeneration(
-        nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
+        nn.Module, SupportsMultiModal, SupportsPP,
         Qwen2_5OmniConditionalGenerationMixin):

     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -800,15 +819,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_mm_mapping(self) -> MultiModelKeys:
-        """Get module prefix for multimodal models to filter LoRA modules."""
-        return MultiModelKeys.from_string_field(
-            language_model="language_model",
-            connector=[],  # No explicit connector in this model
-            tower_model=["visual",
-                         "audio_tower"],  # Exclude vision and audio towers
-        )
-
     def get_multimodal_embeddings(self,
                                   **kwargs: object) -> MultiModalEmbeddings:

File 2/4: Qwen2.5-VL model (Qwen2_5_VLForConditionalGeneration)

@@ -27,7 +27,7 @@
 """Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping
 from functools import lru_cache, partial
-from typing import Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Callable, Literal, Optional, Union

 import torch
 import torch.nn as nn
@@ -64,6 +64,7 @@ from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP, SupportsQuant)
@@ -80,84 +81,125 @@ logger = init_logger(__name__)

 # === Vision Inputs === #


-class Qwen2_5_VLImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    """Shape:
-    `(num_patches, num_channels * patch_size * patch_size)`
-    """
-
-    image_grid_thw: torch.Tensor
-    """Shape: `(num_images, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
-
-
-class Qwen2_5_VLImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    image_embeds: torch.Tensor
-    """Supported types:
-    - list[`torch.Tensor`]: A list of tensors holding all images' features.
-        Each tensor holds an image's features.
-    - `torch.Tensor`: A tensor holding all images' features
-        (concatenation of all images' feature tensors).
-
-    Tensor shape: `(num_image_features, hidden_size)`
-    - `num_image_features` varies based on
-        the number and resolution of the images.
-    - `hidden_size` must match the hidden size of language model backbone.
-    """
-
-    image_grid_thw: torch.Tensor
-    """Shape: `(num_images, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+class Qwen2_5_VLImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: Number of patches
+        - ni: Number of images
+        - cps: Number of channels * patch_size * patch_size
+
+    Historical context:
+        - pixel_values shape: (num_patches, num_channels * patch_size *
+          patch_size)
+        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
+          format
+    """
+    type: Literal["pixel_values"]
+
+    pixel_values: Annotated[
+        torch.Tensor,
+        TensorShape("np", "cps"),
+    ]
+
+    image_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("ni", 3),
+    ]
+
+
+class Qwen2_5_VLImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - nf: Number of image features
+        - hs: Hidden size
+        - ni: Number of images
+
+    Historical context:
+        - image_embeds shape: (num_image_features, hidden_size)
+        - num_image_features varies based on the number and resolution of the
+          images.
+        - hidden_size must match the hidden size of language model backbone.
+        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
+          format
+    """
+    type: Literal["image_embeds"]
+
+    image_embeds: Annotated[
+        torch.Tensor,
+        TensorShape("nf", "hs"),
+    ]
+
+    image_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("ni", 3),
+    ]


 Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs,
                               Qwen2_5_VLImageEmbeddingInputs]


-class Qwen2_5_VLVideoPixelInputs(TypedDict):
-    type: Literal["pixel_values_videos"]
-    pixel_values_videos: torch.Tensor
-    """Shape:
-    `(num_patches,
-      num_channels * temporal_patch_size * patch_size * patch_size)`
-    """
-
-    video_grid_thw: torch.Tensor
-    """Shape: `(num_videos, 3)`
-
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
-
-    second_per_grid_ts: torch.Tensor
-    """
-    The video time interval (in seconds) for each grid along the temporal
-    dimension in the 3D position IDs. Returned when `videos` is not `None`.
-    """
-
-
-class Qwen2_5_VLVideoEmbeddingInputs(TypedDict):
-    type: Literal["video_embeds"]
-    video_embeds: torch.Tensor
-    """Supported types:
-    - list[`torch.Tensor`]: A list of tensors holding all videos' features.
-        Each tensor holds an video's features.
-    - `torch.Tensor`: A tensor holding all videos' features
-        (concatenation of all videos' feature tensors).
-
-    Tensor shape: `(num_image_features, hidden_size)`
-    - `num_image_features` varies based on
-        the number and resolution of the videos.
-    - `hidden_size` must match the hidden size of language model backbone.
-    """
-
-    video_grid_thw: torch.Tensor
-    """Shape: `(num_videos, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+class Qwen2_5_VLVideoPixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: Number of patches
+        - nv: Number of videos
+        - ctps: Number of channels * temporal_patch_size * patch_size *
+          patch_size
+
+    Historical context:
+        - pixel_values_videos shape: (num_patches, num_channels *
+          temporal_patch_size * patch_size * patch_size)
+        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
+          format
+        - second_per_grid_ts: The video time interval (in seconds) for each
+          grid along the temporal dimension in the 3D position IDs. Returned
+          when `videos` is not `None`.
+    """
+    type: Literal["pixel_values_videos"]
+
+    pixel_values_videos: Annotated[
+        torch.Tensor,
+        TensorShape("np", "ctps"),
+    ]
+
+    video_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("nv", 3),
+    ]
+
+    second_per_grid_ts: Annotated[
+        Optional[torch.Tensor],
+        TensorShape("nv"),
+    ]
+
+
+class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - nf: Number of video features
+        - hs: Hidden size
+        - nv: Number of videos
+
+    Historical context:
+        - video_embeds shape: (num_video_features, hidden_size)
+        - num_video_features varies based on the number and resolution of the
+          videos.
+        - hidden_size must match the hidden size of language model backbone.
+        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
+          format
+    """
+    type: Literal["video_embeds"]
+
+    video_embeds: Annotated[
+        torch.Tensor,
+        TensorShape("nf", "hs"),
+    ]
+
+    video_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("nv", 3),
+    ]


 Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs,
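Note: a short sketch of what the migrated image schema encodes, not part of the diff. Values are illustrative only, and it assumes TensorSchema enforces both the symbolic dimensions and the literal 3 at construction time:

import torch

# Two images, 64 patches in total, 3 * 14 * 14 = 588 values per patch.
image_input = Qwen2_5_VLImagePixelInputs(
    type="pixel_values",
    pixel_values=torch.zeros(64, 588),                     # ("np", "cps")
    image_grid_thw=torch.tensor([[1, 4, 8], [1, 4, 8]]),   # ("ni", 3)
)
# Passing a grid tensor of shape (2, 4) would violate TensorShape("ni", 3)
# and should be rejected by the schema rather than by ad-hoc checks.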
@@ -936,10 +978,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             image_grid_thw = self._validate_and_reshape_mm_tensor(
                 image_grid_thw, "image grid_thw")

-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             return Qwen2_5_VLImagePixelInputs(type="pixel_values",
                                               pixel_values=pixel_values,
                                               image_grid_thw=image_grid_thw)
@@ -950,9 +988,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             image_grid_thw = self._validate_and_reshape_mm_tensor(
                 image_grid_thw, "image grid_thw")

-            if not isinstance(image_embeds, torch.Tensor):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
             return Qwen2_5_VLImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -973,7 +1008,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                 pixel_values_videos, "video pixel values")
             video_grid_thw = self._validate_and_reshape_mm_tensor(
                 video_grid_thw, "video grid_thw")

+            if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2:
+                second_per_grid_ts = second_per_grid_ts.squeeze(-1)
             return Qwen2_5_VLVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -987,9 +1023,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             video_grid_thw = self._validate_and_reshape_mm_tensor(
                 video_grid_thw, "video grid_thw")

-            if not isinstance(video_embeds, torch.Tensor):
-                raise ValueError("Incorrect type of video embeddings. "
-                                 f"Got type: {type(video_embeds)}")
             return Qwen2_5_VLVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,
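Note on the parser changes in this file: the hand-written isinstance checks are dropped because the TensorSchema constructors are now expected to reject malformed inputs themselves, and second_per_grid_ts is normalised from a batched (num_videos, 1) tensor to the (num_videos,) vector that TensorShape("nv") declares. The normalisation itself is plain torch behaviour (illustrative values below):

import torch

second_per_grid_ts = torch.tensor([[0.5], [0.5], [1.0]])  # shape (nv, 1)
if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2:
    second_per_grid_ts = second_per_grid_ts.squeeze(-1)
print(second_per_grid_ts.shape)  # torch.Size([3]), i.e. ("nv",)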

File 3/4: Qwen2-Audio model (Qwen2AudioForConditionalGeneration)

@@ -23,7 +23,7 @@
 # limitations under the License.
 """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union

 import torch
 import torch.nn as nn
@@ -47,6 +47,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, init_vllm_registered_model,
@@ -54,21 +55,38 @@ from .utils import (AutoWeightsLoader, init_vllm_registered_model,

 # # === Audio Inputs === #


-class Qwen2AudioFeatureInputs(TypedDict):
-    type: Literal["audio_features"]
-    input_features: torch.Tensor
-    """Shape: `(num_audios, num_mel_bins, 3000)`"""
-
-    feature_attention_mask: torch.Tensor
-    """Shape: `(num_audios, 3000)`"""
-
-
-class Qwen2AudioEmbeddingInputs(TypedDict):
-    type: Literal["audio_embeds"]
-    audio_embeds: list[torch.Tensor]
-    """Shape: `(num_audio_features, hidden_size)`
-    `hidden_size` must match the hidden size of language model backbone.
-    """
+class Qwen2AudioFeatureInputs(TensorSchema):
+    """
+    Dimensions:
+        - na: Number of audios
+        - nmb: Number of mel bins
+    """
+    type: Literal["audio_features"]
+
+    input_features: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("na", "nmb", 3000),
+    ]
+
+    feature_attention_mask: Annotated[
+        torch.Tensor,
+        TensorShape("na", 3000),
+    ]
+
+
+class Qwen2AudioEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size
+        - naf: Number of audio features
+        - hs: Hidden size (must match the hidden size of language model
+          backbone)
+    """
+    type: Literal["audio_embeds"] = "audio_embeds"
+
+    audio_embeds: Annotated[
+        list[torch.Tensor],
+        TensorShape("bn", "naf", "hs"),
+    ]


 Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs]
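Note: a minimal construction sketch for the migrated audio schema, not part of the diff. Values are invented, and it assumes TensorSchema validates the annotated shapes, including the fixed 3000-frame mel window encoded as a literal, at construction time:

import torch

# Two clips, 128 mel bins, padded to the fixed 3000-frame window.
audio_input = Qwen2AudioFeatureInputs(
    type="audio_features",
    input_features=torch.zeros(2, 128, 3000),      # ("na", "nmb", 3000)
    feature_attention_mask=torch.ones(2, 3000),    # ("na", 3000)
)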

File 4/4: Qwen2-VL model (Qwen2VLForConditionalGeneration)

@@ -26,7 +26,7 @@
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from typing import Any, Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Callable, Literal, Optional, Union

 import torch
 import torch.nn as nn
@@ -70,6 +70,7 @@ from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
@@ -86,78 +87,119 @@ _MAX_FRAMES_PER_VIDEO = 16

 # === Vision Inputs === #


-class Qwen2VLImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    """Shape:
-    `(num_patches, num_channels * patch_size * patch_size)`
-    """
-
-    image_grid_thw: torch.Tensor
-    """Shape: `(num_images, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
-
-
-class Qwen2VLImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    image_embeds: torch.Tensor
-    """Supported types:
-    - list[`torch.Tensor`]: A list of tensors holding all images' features.
-        Each tensor holds an image's features.
-    - `torch.Tensor`: A tensor holding all images' features
-        (concatenation of all images' feature tensors).
-
-    Tensor shape: `(num_image_features, hidden_size)`
-    - `num_image_features` varies based on
-        the number and resolution of the images.
-    - `hidden_size` must match the hidden size of language model backbone.
-    """
-
-    image_grid_thw: torch.Tensor
-    """Shape: `(num_images, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+class Qwen2VLImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: The total number of patches over each image over each prompt in
+          the batch
+        - ni: Number of images
+        - cps: Number of channels * patch_size * patch_size
+
+    Historical context:
+        - pixel_values shape: (num_patches, num_channels * patch_size *
+          patch_size)
+        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
+          format
+    """
+    type: Literal["pixel_values"]
+
+    pixel_values: Annotated[
+        torch.Tensor,
+        TensorShape("np", "cps"),
+    ]
+
+    image_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("ni", 3),
+    ]
+
+
+class Qwen2VLImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - nf: Number of image features
+        - hs: Hidden size
+        - ni: Number of images
+
+    Historical context:
+        - image_embeds shape: (num_image_features, hidden_size)
+        - num_image_features varies based on the number and resolution of the
+          images.
+        - hidden_size must match the hidden size of language model backbone.
+        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
+          format
+    """
+    type: Literal["image_embeds"]
+
+    image_embeds: Annotated[
+        torch.Tensor,
+        TensorShape("nf", "hs"),
+    ]
+
+    image_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("ni", 3),
+    ]


 Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
                            Qwen2VLImageEmbeddingInputs]


-class Qwen2VLVideoPixelInputs(TypedDict):
-    type: Literal["pixel_values_videos"]
-    pixel_values_videos: torch.Tensor
-    """Shape:
-    `(num_patches,
-      num_channels * temporal_patch_size * patch_size * patch_size)`
-    """
-
-    video_grid_thw: torch.Tensor
-    """Shape: `(num_videos, 3)`
-
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
-
-
-class Qwen2VLVideoEmbeddingInputs(TypedDict):
-    type: Literal["video_embeds"]
-    video_embeds: torch.Tensor
-    """Supported types:
-    - list[`torch.Tensor`]: A list of tensors holding all videos' features.
-        Each tensor holds an video's features.
-    - `torch.Tensor`: A tensor holding all videos' features
-        (concatenation of all videos' feature tensors).
-
-    Tensor shape: `(num_image_features, hidden_size)`
-    - `num_image_features` varies based on
-        the number and resolution of the videos.
-    - `hidden_size` must match the hidden size of language model backbone.
-    """
-
-    video_grid_thw: torch.Tensor
-    """Shape: `(num_videos, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+class Qwen2VLVideoPixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: The total number of patches over each video over each prompt in
+          the batch
+        - ctps: Number of channels * temporal_patch_size * patch_size *
+          patch_size
+        - nv: Number of videos
+
+    Historical context:
+        - pixel_values_videos shape: (num_patches, num_channels *
+          temporal_patch_size * patch_size * patch_size)
+        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
+          format
+    """
+    type: Literal["pixel_values_videos"]
+
+    pixel_values_videos: Annotated[
+        torch.Tensor,
+        TensorShape("np", "ctps"),
+    ]
+
+    video_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("nv", 3),
+    ]
+
+
+class Qwen2VLVideoEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - nf: Number of video features
+        - hs: Hidden size
+        - nv: Number of videos
+
+    Historical context:
+        - video_embeds shape: (num_video_features, hidden_size)
+        - num_video_features varies based on the number and resolution of the
+          videos.
+        - hidden_size must match the hidden size of language model backbone.
+        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
+          format
+    """
+    type: Literal["video_embeds"]
+
+    video_embeds: Annotated[
+        torch.Tensor,
+        TensorShape("nf", "hs"),
+    ]
+
+    video_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("nv", 3),
+    ]


 Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs,
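Note: the two-member unions are unchanged, so downstream parsing can keep discriminating on the type literal. A rough sketch of that dispatch (hypothetical helper, not in the diff; it assumes the TensorSchema objects keep the mapping-style item access that code written against the old TypedDicts relies on):

def describe_image_input(image_input: Qwen2VLImageInputs) -> str:
    # Hypothetical helper: branch on the discriminating "type" literal.
    if image_input["type"] == "pixel_values":
        return f"{image_input['pixel_values'].shape[0]} patches"
    return f"{image_input['image_embeds'].shape[0]} embedded image features"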
@@ -1126,10 +1168,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             image_grid_thw = self._validate_and_reshape_mm_tensor(
                 image_grid_thw, "image grid_thw")

-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             return Qwen2VLImagePixelInputs(type="pixel_values",
                                            pixel_values=pixel_values,
                                            image_grid_thw=image_grid_thw)
@@ -1140,9 +1178,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             image_grid_thw = self._validate_and_reshape_mm_tensor(
                 image_grid_thw, "image grid_thw")

-            if not isinstance(image_embeds, torch.Tensor):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
             return Qwen2VLImageEmbeddingInputs(type="image_embeds",
                                                image_embeds=image_embeds,
                                                image_grid_thw=image_grid_thw)
@@ -1174,9 +1209,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             video_grid_thw = self._validate_and_reshape_mm_tensor(
                 video_grid_thw, "video grid_thw")

-            if not isinstance(video_embeds, torch.Tensor):
-                raise ValueError("Incorrect type of video embeddings. "
-                                 f"Got type: {type(video_embeds)}")
             return Qwen2VLVideoEmbeddingInputs(type="video_embeds",
                                                video_embeds=video_embeds,
                                                video_grid_thw=video_grid_thw)