Migrate Pixtral inputs to TensorSchema (#23472)

Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
Benji Beck 2025-08-23 21:55:53 -07:00 committed by GitHub
parent c55c028998
commit 053278a5dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,7 +5,7 @@ import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields
from functools import cached_property
from typing import Literal, Optional, TypedDict, Union
from typing import Annotated, Literal, Optional, Union
import torch
import torch.nn as nn
@ -48,6 +48,7 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
@ -68,15 +69,20 @@ except ImportError:
PATCH_MERGE = "patch_merge"
class PixtralImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
images: Union[torch.Tensor, list[torch.Tensor]]
class PixtralImagePixelInputs(TensorSchema):
"""
Shape: `(batch_size * num_images, num_channels, image_width, image_height)`
Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height of each image
- w: Width of each image
The result of stacking `ImageEncoding.tokens` from each prompt.
"""
type: Literal["pixel_values"] = "pixel_values"
images: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"})]
class PixtralProcessorAdapter:
@ -381,10 +387,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
if images is None:
return None
if not isinstance(images, (torch.Tensor, list)):
raise ValueError("Incorrect type of images. "
f"Got type: {type(images)}")
return PixtralImagePixelInputs(
type="pixel_values",
images=flatten_bn(images),