mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:25:00 +08:00
Migrate Qwen2 inputs to TensorSchema (#23475)
Signed-off-by: Benji Beck <benjibeck@meta.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
558f0907dc
commit
37a6fa95fd
@ -25,7 +25,7 @@
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from copy import copy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Annotated, Any, Callable, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -41,15 +41,13 @@ from transformers.models.whisper import WhisperFeatureExtractor
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.qwen2_5_vl import (
|
||||
Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs,
|
||||
Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs,
|
||||
Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
|
||||
Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
|
||||
from vllm.model_executor.models.qwen2_audio import (
|
||||
Qwen2AudioFeatureInputs, Qwen2AudioProcessingInfo,
|
||||
_get_feat_extract_output_lengths)
|
||||
Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths)
|
||||
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@ -66,9 +64,9 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
@ -81,6 +79,26 @@ except (ImportError, ModuleNotFoundError):
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- na: Number of audios
|
||||
- nmb: Number of mel bins
|
||||
- msl: Maximum sequence length
|
||||
- tsl: Total sequence length
|
||||
"""
|
||||
type: Literal["audio_features"]
|
||||
input_features: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("nmb", "tsl"),
|
||||
]
|
||||
|
||||
feature_attention_mask: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("na", "msl"),
|
||||
]
|
||||
|
||||
|
||||
def create_qwen2_5_omni_thinker_field_factory(
|
||||
spatial_merge_size: int
|
||||
) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str,
|
||||
@ -536,7 +554,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
||||
return torch.concat(mm_input, dim=dim)
|
||||
|
||||
def _parse_and_validate_audio_input(
|
||||
self, **kwargs: object) -> Optional[Qwen2AudioFeatureInputs]:
|
||||
self, **kwargs: object) -> Optional[Qwen2_5OmniAudioFeatureInputs]:
|
||||
input_audio_features = kwargs.pop('input_audio_features', None)
|
||||
audio_feature_lengths = kwargs.pop('audio_feature_lengths', None)
|
||||
feature_attention_mask = kwargs.pop('feature_attention_mask', None)
|
||||
@ -550,7 +568,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
||||
if not isinstance(input_audio_features, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio input features. "
|
||||
f"Got type: {type(input_audio_features)}")
|
||||
return Qwen2AudioFeatureInputs(
|
||||
return Qwen2_5OmniAudioFeatureInputs(
|
||||
type="audio_features",
|
||||
input_features=input_audio_features,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
feature_attention_mask=feature_attention_mask)
|
||||
@ -633,7 +652,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
||||
|
||||
def _process_audio_input(
|
||||
self,
|
||||
audio_input: Qwen2AudioFeatureInputs,
|
||||
audio_input: Qwen2_5OmniAudioFeatureInputs,
|
||||
audio_hashes: list[str] = None,
|
||||
cached_audio_features: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
@ -660,8 +679,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
||||
feature_lens=audio_feature_lengths,
|
||||
aftercnn_lens=audio_feat_lengths,
|
||||
)
|
||||
audio_features = audio_outputs.last_hidden_state
|
||||
return audio_features.split(audio_output_lengths.tolist())
|
||||
return audio_outputs.last_hidden_state.split(
|
||||
audio_output_lengths.tolist())
|
||||
|
||||
def _process_image_input(
|
||||
self,
|
||||
@ -707,7 +726,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
||||
dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
|
||||
)
|
||||
class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
nn.Module, SupportsMultiModal, SupportsPP,
|
||||
Qwen2_5OmniConditionalGenerationMixin):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
@ -800,15 +819,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""Get module prefix for multimodal models to filter LoRA modules."""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector=[], # No explicit connector in this model
|
||||
tower_model=["visual",
|
||||
"audio_tower"], # Exclude vision and audio towers
|
||||
)
|
||||
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@
|
||||
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable, Mapping
|
||||
from functools import lru_cache, partial
|
||||
from typing import Callable, Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Callable, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -64,6 +64,7 @@ from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import uses_mrope
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP, SupportsQuant)
|
||||
@ -80,84 +81,125 @@ logger = init_logger(__name__)
|
||||
# === Vision Inputs === #
|
||||
|
||||
|
||||
class Qwen2_5_VLImagePixelInputs(TypedDict):
|
||||
class Qwen2_5_VLImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- np: Number of patches
|
||||
- ni: Number of images
|
||||
- cps: Number of channels * patch_size * patch_size
|
||||
|
||||
Historical context:
|
||||
- pixel_values shape: (num_patches, num_channels * patch_size *
|
||||
patch_size)
|
||||
- image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
|
||||
formatnum_channels * patch_size * patch_size
|
||||
"""
|
||||
type: Literal["pixel_values"]
|
||||
pixel_values: torch.Tensor
|
||||
"""Shape:
|
||||
`(num_patches, num_channels * patch_size * patch_size)`
|
||||
|
||||
pixel_values: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("np", "cps"),
|
||||
]
|
||||
|
||||
image_grid_thw: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("ni", 3),
|
||||
]
|
||||
|
||||
|
||||
class Qwen2_5_VLImageEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
|
||||
image_grid_thw: torch.Tensor
|
||||
"""Shape: `(num_images, 3)`
|
||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||
Dimensions:
|
||||
- nf: Number of image features
|
||||
- hs: Hidden size
|
||||
- ni: Number of images
|
||||
|
||||
Historical context:
|
||||
- image_embeds shape: (num_image_features, hidden_size)
|
||||
- num_image_features varies based on the number and resolution of the
|
||||
images.
|
||||
- hidden_size must match the hidden size of language model backbone.
|
||||
- image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
|
||||
format
|
||||
"""
|
||||
|
||||
|
||||
class Qwen2_5_VLImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
image_embeds: torch.Tensor
|
||||
"""Supported types:
|
||||
- list[`torch.Tensor`]: A list of tensors holding all images' features.
|
||||
Each tensor holds an image's features.
|
||||
- `torch.Tensor`: A tensor holding all images' features
|
||||
(concatenation of all images' feature tensors).
|
||||
|
||||
Tensor shape: `(num_image_features, hidden_size)`
|
||||
- `num_image_features` varies based on
|
||||
the number and resolution of the images.
|
||||
- `hidden_size` must match the hidden size of language model backbone.
|
||||
"""
|
||||
image_embeds: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("nf", "hs"),
|
||||
]
|
||||
|
||||
image_grid_thw: torch.Tensor
|
||||
"""Shape: `(num_images, 3)`
|
||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||
"""
|
||||
image_grid_thw: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("ni", 3),
|
||||
]
|
||||
|
||||
|
||||
Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs,
|
||||
Qwen2_5_VLImageEmbeddingInputs]
|
||||
|
||||
|
||||
class Qwen2_5_VLVideoPixelInputs(TypedDict):
|
||||
class Qwen2_5_VLVideoPixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- np: Number of patches
|
||||
- nv: Number of videos
|
||||
- ctps: Number of channels * temporal_patch_size * patch_size *
|
||||
patch_size
|
||||
|
||||
Historical context:
|
||||
- pixel_values_videos shape: (num_patches, num_channels *
|
||||
temporal_patch_size * patch_size * patch_size)
|
||||
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
|
||||
format
|
||||
- second_per_grid_ts: The video time interval (in seconds) for each
|
||||
grid along the temporal dimension in the 3D position IDs. Returned
|
||||
when `videos` is not `None`.
|
||||
"""
|
||||
type: Literal["pixel_values_videos"]
|
||||
pixel_values_videos: torch.Tensor
|
||||
"""Shape:
|
||||
`(num_patches,
|
||||
num_channels * temporal_patch_size * patch_size * patch_size)`
|
||||
|
||||
pixel_values_videos: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("np", "ctps"),
|
||||
]
|
||||
|
||||
video_grid_thw: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("nv", 3),
|
||||
]
|
||||
|
||||
second_per_grid_ts: Annotated[
|
||||
Optional[torch.Tensor],
|
||||
TensorShape("nv"),
|
||||
]
|
||||
|
||||
|
||||
class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
|
||||
video_grid_thw: torch.Tensor
|
||||
"""Shape: `(num_videos, 3)`
|
||||
|
||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||
Dimensions:
|
||||
- nf: Number of video features
|
||||
- hs: Hidden size
|
||||
- nv: Number of videos
|
||||
|
||||
Historical context:
|
||||
- video_embeds shape: (num_video_features, hidden_size)
|
||||
- num_video_features varies based on the number and resolution of the
|
||||
videos.
|
||||
- hidden_size must match the hidden size of language model backbone.
|
||||
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
|
||||
format
|
||||
"""
|
||||
|
||||
second_per_grid_ts: torch.Tensor
|
||||
"""
|
||||
The video time interval (in seconds) for each grid along the temporal
|
||||
dimension in the 3D position IDs. Returned when `videos` is not `None`.
|
||||
"""
|
||||
|
||||
|
||||
class Qwen2_5_VLVideoEmbeddingInputs(TypedDict):
|
||||
type: Literal["video_embeds"]
|
||||
video_embeds: torch.Tensor
|
||||
"""Supported types:
|
||||
- list[`torch.Tensor`]: A list of tensors holding all videos' features.
|
||||
Each tensor holds an video's features.
|
||||
- `torch.Tensor`: A tensor holding all videos' features
|
||||
(concatenation of all videos' feature tensors).
|
||||
|
||||
Tensor shape: `(num_image_features, hidden_size)`
|
||||
- `num_image_features` varies based on
|
||||
the number and resolution of the videos.
|
||||
- `hidden_size` must match the hidden size of language model backbone.
|
||||
"""
|
||||
video_embeds: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("nf", "hs"),
|
||||
]
|
||||
|
||||
video_grid_thw: torch.Tensor
|
||||
"""Shape: `(num_videos, 3)`
|
||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||
"""
|
||||
video_grid_thw: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("nv", 3),
|
||||
]
|
||||
|
||||
|
||||
Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs,
|
||||
@ -936,10 +978,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||
image_grid_thw, "image grid_thw")
|
||||
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
return Qwen2_5_VLImagePixelInputs(type="pixel_values",
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw)
|
||||
@ -950,9 +988,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||
image_grid_thw, "image grid_thw")
|
||||
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
return Qwen2_5_VLImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
image_embeds=image_embeds,
|
||||
@ -973,7 +1008,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
pixel_values_videos, "video pixel values")
|
||||
video_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||
video_grid_thw, "video grid_thw")
|
||||
|
||||
if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2:
|
||||
second_per_grid_ts = second_per_grid_ts.squeeze(-1)
|
||||
return Qwen2_5_VLVideoPixelInputs(
|
||||
type="pixel_values_videos",
|
||||
pixel_values_videos=pixel_values_videos,
|
||||
@ -987,9 +1023,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
video_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||
video_grid_thw, "video grid_thw")
|
||||
|
||||
if not isinstance(video_embeds, torch.Tensor):
|
||||
raise ValueError("Incorrect type of video embeddings. "
|
||||
f"Got type: {type(video_embeds)}")
|
||||
return Qwen2_5_VLVideoEmbeddingInputs(
|
||||
type="video_embeds",
|
||||
video_embeds=video_embeds,
|
||||
|
||||
@ -23,7 +23,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Any, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -47,6 +47,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
@ -54,21 +55,38 @@ from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
|
||||
|
||||
# # === Audio Inputs === #
|
||||
class Qwen2AudioFeatureInputs(TypedDict):
|
||||
type: Literal["audio_features"]
|
||||
input_features: torch.Tensor
|
||||
"""Shape: `(num_audios, num_mel_bins, 3000)`"""
|
||||
|
||||
feature_attention_mask: torch.Tensor
|
||||
"""Shape: `(num_audios, 3000)`"""
|
||||
|
||||
|
||||
class Qwen2AudioEmbeddingInputs(TypedDict):
|
||||
type: Literal["audio_embeds"]
|
||||
audio_embeds: list[torch.Tensor]
|
||||
"""Shape: `(num_audio_features, hidden_size)`
|
||||
`hidden_size` must match the hidden size of language model backbone.
|
||||
class Qwen2AudioFeatureInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- na: Number of audios
|
||||
- nmb: Number of mel bins
|
||||
"""
|
||||
type: Literal["audio_features"]
|
||||
input_features: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("na", "nmb", 3000),
|
||||
]
|
||||
|
||||
feature_attention_mask: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("na", 3000),
|
||||
]
|
||||
|
||||
|
||||
class Qwen2AudioEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bn: Batch size
|
||||
- naf: Number of audio features
|
||||
- hs: Hidden size (must match the hidden size of language model
|
||||
backbone)
|
||||
"""
|
||||
type: Literal["audio_embeds"] = "audio_embeds"
|
||||
|
||||
audio_embeds: Annotated[
|
||||
list[torch.Tensor],
|
||||
TensorShape("bn", "naf", "hs"),
|
||||
]
|
||||
|
||||
|
||||
Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs]
|
||||
|
||||
@ -26,7 +26,7 @@
|
||||
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Any, Callable, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -70,6 +70,7 @@ from vllm.platforms import _Backend, current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import uses_mrope
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
@ -86,78 +87,119 @@ _MAX_FRAMES_PER_VIDEO = 16
|
||||
# === Vision Inputs === #
|
||||
|
||||
|
||||
class Qwen2VLImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
pixel_values: torch.Tensor
|
||||
"""Shape:
|
||||
`(num_patches, num_channels * patch_size * patch_size)`
|
||||
class Qwen2VLImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
|
||||
image_grid_thw: torch.Tensor
|
||||
"""Shape: `(num_images, 3)`
|
||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||
"""
|
||||
|
||||
|
||||
class Qwen2VLImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
image_embeds: torch.Tensor
|
||||
"""Supported types:
|
||||
- list[`torch.Tensor`]: A list of tensors holding all images' features.
|
||||
Each tensor holds an image's features.
|
||||
- `torch.Tensor`: A tensor holding all images' features
|
||||
(concatenation of all images' feature tensors).
|
||||
Dimensions:
|
||||
- np: The total number of patches over each image over each prompt in
|
||||
the batch
|
||||
- ni: Number of images
|
||||
- cps: Number of channels * patch_size * patch_size
|
||||
|
||||
Tensor shape: `(num_image_features, hidden_size)`
|
||||
- `num_image_features` varies based on
|
||||
the number and resolution of the images.
|
||||
- `hidden_size` must match the hidden size of language model backbone.
|
||||
Historical context:
|
||||
- pixel_values shape: (num_patches, num_channels * patch_size *
|
||||
patch_size)
|
||||
- image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
|
||||
format
|
||||
"""
|
||||
type: Literal["pixel_values"]
|
||||
|
||||
image_grid_thw: torch.Tensor
|
||||
"""Shape: `(num_images, 3)`
|
||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||
pixel_values: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("np", "cps"),
|
||||
]
|
||||
|
||||
image_grid_thw: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("ni", 3),
|
||||
]
|
||||
|
||||
|
||||
class Qwen2VLImageEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- nf: Number of image features
|
||||
- hs: Hidden size
|
||||
- ni: Number of images
|
||||
|
||||
Historical context:
|
||||
- image_embeds shape: (num_image_features, hidden_size)
|
||||
- num_image_features varies based on the number and resolution of the
|
||||
images.
|
||||
- hidden_size must match the hidden size of language model backbone.
|
||||
- image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
|
||||
format
|
||||
"""
|
||||
type: Literal["image_embeds"]
|
||||
|
||||
image_embeds: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("nf", "hs"),
|
||||
]
|
||||
|
||||
image_grid_thw: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("ni", 3),
|
||||
]
|
||||
|
||||
|
||||
Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
|
||||
Qwen2VLImageEmbeddingInputs]
|
||||
|
||||
|
||||
class Qwen2VLVideoPixelInputs(TypedDict):
|
||||
type: Literal["pixel_values_videos"]
|
||||
pixel_values_videos: torch.Tensor
|
||||
"""Shape:
|
||||
`(num_patches,
|
||||
num_channels * temporal_patch_size * patch_size * patch_size)`
|
||||
class Qwen2VLVideoPixelInputs(TensorSchema):
|
||||
"""
|
||||
|
||||
video_grid_thw: torch.Tensor
|
||||
"""Shape: `(num_videos, 3)`
|
||||
|
||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||
"""
|
||||
|
||||
|
||||
class Qwen2VLVideoEmbeddingInputs(TypedDict):
|
||||
type: Literal["video_embeds"]
|
||||
video_embeds: torch.Tensor
|
||||
"""Supported types:
|
||||
- list[`torch.Tensor`]: A list of tensors holding all videos' features.
|
||||
Each tensor holds an video's features.
|
||||
- `torch.Tensor`: A tensor holding all videos' features
|
||||
(concatenation of all videos' feature tensors).
|
||||
Dimensions:
|
||||
- np: The total number of patches over each video over each prompt in
|
||||
the batch
|
||||
- ctps: Number of channels * temporal_patch_size * patch_size *
|
||||
patch_size
|
||||
- nv: Number of videos
|
||||
|
||||
Tensor shape: `(num_image_features, hidden_size)`
|
||||
- `num_image_features` varies based on
|
||||
the number and resolution of the videos.
|
||||
- `hidden_size` must match the hidden size of language model backbone.
|
||||
Historical context:
|
||||
- pixel_values_videos shape: (num_patches, num_channels *
|
||||
temporal_patch_size * patch_size * patch_size)
|
||||
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
|
||||
format
|
||||
"""
|
||||
type: Literal["pixel_values_videos"]
|
||||
|
||||
video_grid_thw: torch.Tensor
|
||||
"""Shape: `(num_videos, 3)`
|
||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||
pixel_values_videos: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("np", "ctps"),
|
||||
]
|
||||
|
||||
video_grid_thw: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("nv", 3),
|
||||
]
|
||||
|
||||
|
||||
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- nf: Number of video features
|
||||
- hs: Hidden size
|
||||
- nv: Number of videos
|
||||
|
||||
Historical context:
|
||||
- video_embeds shape: (num_video_features, hidden_size)
|
||||
- num_video_features varies based on the number and resolution of the
|
||||
videos.
|
||||
- hidden_size must match the hidden size of language model backbone.
|
||||
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
|
||||
format
|
||||
"""
|
||||
type: Literal["video_embeds"]
|
||||
|
||||
video_embeds: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("nf", "hs"),
|
||||
]
|
||||
|
||||
video_grid_thw: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("nv", 3),
|
||||
]
|
||||
|
||||
|
||||
Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs,
|
||||
@ -1126,10 +1168,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||
image_grid_thw, "image grid_thw")
|
||||
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
return Qwen2VLImagePixelInputs(type="pixel_values",
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw)
|
||||
@ -1140,9 +1178,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||
image_grid_thw, "image grid_thw")
|
||||
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
return Qwen2VLImageEmbeddingInputs(type="image_embeds",
|
||||
image_embeds=image_embeds,
|
||||
image_grid_thw=image_grid_thw)
|
||||
@ -1174,9 +1209,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
video_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||
video_grid_thw, "video grid_thw")
|
||||
|
||||
if not isinstance(video_embeds, torch.Tensor):
|
||||
raise ValueError("Incorrect type of video embeddings. "
|
||||
f"Got type: {type(video_embeds)}")
|
||||
return Qwen2VLVideoEmbeddingInputs(type="video_embeds",
|
||||
video_embeds=video_embeds,
|
||||
video_grid_thw=video_grid_thw)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user