mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-17 15:56:07 +08:00
Migrate Mistral3ImagePixelInputs to TensorSchema (#21945)
Signed-off-by: Benji Beck <benjibeck@meta.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
dfd2382039
commit
c4477f55e5
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
|
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
|
||||||
Union)
|
Union)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -32,6 +32,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
PromptUpdateDetails)
|
PromptUpdateDetails)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||||
SupportsMultiModal, SupportsPP)
|
SupportsMultiModal, SupportsPP)
|
||||||
@ -42,15 +43,23 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
|||||||
from .vision import get_vision_encoder_info
|
from .vision import get_vision_encoder_info
|
||||||
|
|
||||||
|
|
||||||
class Mistral3ImagePixelInputs(TypedDict):
|
class Mistral3ImagePixelInputs(TensorSchema):
|
||||||
type: Literal["pixel_values_pixtral"]
|
"""
|
||||||
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
|
Dimensions:
|
||||||
|
- bn: Batch size * number of images
|
||||||
|
- c: Number of channels (3)
|
||||||
|
- h: Height of each image
|
||||||
|
- w: Width of each image
|
||||||
"""
|
"""
|
||||||
Shape: `(batch_size * num_images, num_channels, height, width)`
|
|
||||||
|
|
||||||
Note that `height` or `width` may be different per batch and image,
|
type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
|
||||||
in which case the data is passed as a list instead of a batched tensor.
|
|
||||||
"""
|
# Note that `height` or `width` may be different per batch and image,
|
||||||
|
# in which case the data is passed as a list instead of a batched tensor.
|
||||||
|
pixel_values: Annotated[
|
||||||
|
Union[torch.Tensor, list[torch.Tensor]],
|
||||||
|
TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class Mistral3PatchMerger(nn.Module):
|
class Mistral3PatchMerger(nn.Module):
|
||||||
@ -456,19 +465,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
|
|||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
self.language_model.make_empty_intermediate_tensors)
|
self.language_model.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
|
||||||
h = w = self.config.vision_config.image_size
|
|
||||||
expected_dims = (3, h, w)
|
|
||||||
actual_dims = tuple(data.shape[1:])
|
|
||||||
|
|
||||||
if actual_dims != expected_dims:
|
|
||||||
expected_expr = ("batch_size", *map(str, expected_dims))
|
|
||||||
raise ValueError(
|
|
||||||
f"The expected shape of pixel values is {expected_expr}. "
|
|
||||||
f"You supplied {tuple(data.shape)}.")
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def _parse_and_validate_image_input(
|
def _parse_and_validate_image_input(
|
||||||
self, **kwargs: object) -> Optional[Mistral3ImagePixelInputs]:
|
self, **kwargs: object) -> Optional[Mistral3ImagePixelInputs]:
|
||||||
pixel_values = kwargs.pop("pixel_values", None)
|
pixel_values = kwargs.pop("pixel_values", None)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user