From 053278a5dc7a81d751f8e63c1ed793062b32cbce Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Sat, 23 Aug 2025 21:55:53 -0700 Subject: [PATCH] Migrate Pixtral inputs to TensorSchema (#23472) Signed-off-by: Benji Beck --- vllm/model_executor/models/pixtral.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index c01074e2122bb..461b9c85d1c22 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -5,7 +5,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields from functools import cached_property -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn @@ -48,6 +48,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import (MistralTokenizer, cached_tokenizer_from_config) +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, @@ -68,15 +69,20 @@ except ImportError: PATCH_MERGE = "patch_merge" -class PixtralImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - - images: Union[torch.Tensor, list[torch.Tensor]] +class PixtralImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, num_channels, image_width, image_height)` - + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image + The result of stacking `ImageEncoding.tokens` from each prompt. """ + type: Literal["pixel_values"] = "pixel_values" + + images: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"})] class PixtralProcessorAdapter: @@ -381,10 +387,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, if images is None: return None - if not isinstance(images, (torch.Tensor, list)): - raise ValueError("Incorrect type of images. " - f"Got type: {type(images)}") - return PixtralImagePixelInputs( type="pixel_values", images=flatten_bn(images),