mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 21:35:00 +08:00
Migrate Qwen inputs to TensorSchema (#23473)
Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
parent
5da4f5d857
commit
a69693e38f
@ -11,7 +11,7 @@ import math
|
|||||||
import unicodedata
|
import unicodedata
|
||||||
from collections.abc import Collection, Mapping, Sequence, Set
|
from collections.abc import Collection, Mapping, Sequence, Set
|
||||||
from functools import lru_cache, partial
|
from functools import lru_cache, partial
|
||||||
from typing import Callable, Literal, Optional, TypedDict, Union
|
from typing import Annotated, Callable, Literal, Optional, Union
|
||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
import torch
|
import torch
|
||||||
@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
PromptUpdate, PromptUpdateDetails)
|
PromptUpdate, PromptUpdateDetails)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||||
SupportsMultiModal, SupportsPP)
|
SupportsMultiModal, SupportsPP)
|
||||||
@ -47,26 +48,34 @@ from .qwen import QWenBaseModel, QWenModel
|
|||||||
from .utils import flatten_bn, merge_multimodal_embeddings
|
from .utils import flatten_bn, merge_multimodal_embeddings
|
||||||
|
|
||||||
|
|
||||||
class QwenImagePixelInputs(TypedDict):
|
class QwenImagePixelInputs(TensorSchema):
|
||||||
type: Literal["pixel_values"]
|
|
||||||
data: torch.Tensor
|
|
||||||
"""
|
"""
|
||||||
Shape: `(batch_size * num_images, 3, image_size, image_size)`
|
Dimensions:
|
||||||
|
- bn: Batch size * number of images
|
||||||
|
- c: Number of channels (3)
|
||||||
|
- h: Height
|
||||||
|
- w: Width
|
||||||
|
|
||||||
Note that image_size is the value in the vision config to which we resize
|
Note that image_size is the value in the vision config to which we resize
|
||||||
the image to in the normalization transform. Currently multi-image support
|
the image to in the normalization transform. Currently multi-image support
|
||||||
can only be leveraged by passing image embeddings directly.
|
can only be leveraged by passing image embeddings directly.
|
||||||
"""
|
"""
|
||||||
|
type: Literal["pixel_values"] = "pixel_values"
|
||||||
|
data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
|
||||||
|
|
||||||
|
|
||||||
class QwenImageEmbeddingInputs(TypedDict):
|
class QwenImageEmbeddingInputs(TensorSchema):
|
||||||
type: Literal["image_embeds"]
|
"""
|
||||||
data: torch.Tensor
|
Dimensions:
|
||||||
"""Shape: `(batch_size * num_images, 256, hidden_size)`
|
- bn: Batch size * number of images
|
||||||
|
- ifs: Image feature size (256)
|
||||||
|
- hs: Hidden size
|
||||||
|
|
||||||
`hidden_size` must match the hidden size of the language model backbone
|
`hidden_size` must match the hidden size of the language model backbone
|
||||||
and is stored in the visual config of the model if we have one.
|
and is stored in the visual config of the model if we have one.
|
||||||
"""
|
"""
|
||||||
|
type: Literal["image_embeds"] = "image_embeds"
|
||||||
|
data: Annotated[torch.Tensor, TensorShape("bn", 256, "hs")]
|
||||||
|
|
||||||
|
|
||||||
QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]
|
QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]
|
||||||
@ -697,19 +706,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
|
|||||||
|
|
||||||
self.transformer: QwenVLModel
|
self.transformer: QwenVLModel
|
||||||
|
|
||||||
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
|
||||||
h = w = self.config.visual["image_size"]
|
|
||||||
expected_dims = (3, h, w)
|
|
||||||
actual_dims = tuple(data.shape[1:])
|
|
||||||
|
|
||||||
if actual_dims != expected_dims:
|
|
||||||
expected_expr = ("batch_size", *map(str, expected_dims))
|
|
||||||
raise ValueError(
|
|
||||||
f"The expected shape of pixel values is {expected_expr}. "
|
|
||||||
f"You supplied {tuple(data.shape)}.")
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def _parse_and_validate_image_input(
|
def _parse_and_validate_image_input(
|
||||||
self, **kwargs: object) -> Optional[QwenImageInputs]:
|
self, **kwargs: object) -> Optional[QwenImageInputs]:
|
||||||
pixel_values = kwargs.pop("pixel_values", None)
|
pixel_values = kwargs.pop("pixel_values", None)
|
||||||
@ -720,10 +716,13 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
|
|||||||
raise ValueError("Incorrect type of pixel values. "
|
raise ValueError("Incorrect type of pixel values. "
|
||||||
f"Got type: {type(pixel_values)}")
|
f"Got type: {type(pixel_values)}")
|
||||||
|
|
||||||
|
expected_h = expected_w = self.config.visual["image_size"]
|
||||||
|
resolve_bindings = {"h": expected_h, "w": expected_w}
|
||||||
|
|
||||||
return QwenImagePixelInputs(
|
return QwenImagePixelInputs(
|
||||||
type="pixel_values",
|
type="pixel_values",
|
||||||
data=self._validate_pixel_values(
|
data=flatten_bn(pixel_values, concat=True),
|
||||||
flatten_bn(pixel_values, concat=True)),
|
resolve_bindings=resolve_bindings,
|
||||||
)
|
)
|
||||||
|
|
||||||
if image_embeds is not None:
|
if image_embeds is not None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user