mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-29 16:27:26 +08:00
Migrate Pixtral inputs to TensorSchema (#23472)
Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
parent
c55c028998
commit
053278a5dc
@ -5,7 +5,7 @@ import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass, fields
|
||||
from functools import cached_property
|
||||
from typing import Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -48,6 +48,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
||||
cached_tokenizer_from_config)
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
|
||||
@ -68,15 +69,20 @@ except ImportError:
|
||||
PATCH_MERGE = "patch_merge"
|
||||
|
||||
|
||||
class PixtralImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
|
||||
images: Union[torch.Tensor, list[torch.Tensor]]
|
||||
class PixtralImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
Shape: `(batch_size * num_images, num_channels, image_width, image_height)`
|
||||
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- c: Number of channels (3)
|
||||
- h: Height of each image
|
||||
- w: Width of each image
|
||||
|
||||
The result of stacking `ImageEncoding.tokens` from each prompt.
|
||||
"""
|
||||
type: Literal["pixel_values"] = "pixel_values"
|
||||
|
||||
images: Annotated[Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"})]
|
||||
|
||||
|
||||
class PixtralProcessorAdapter:
|
||||
@ -381,10 +387,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
if images is None:
|
||||
return None
|
||||
|
||||
if not isinstance(images, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of images. "
|
||||
f"Got type: {type(images)}")
|
||||
|
||||
return PixtralImagePixelInputs(
|
||||
type="pixel_values",
|
||||
images=flatten_bn(images),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user