Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 22:35:30 +08:00)
Migrate ChameleonImagePixelInputs to TensorSchema (#21657)
Signed-off-by: Benji Beck <benjibeck@meta.com>
parent 3339cba3ff · commit 20950b29fb
@@ -3,7 +3,7 @@
 
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -38,6 +38,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                          PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
                           SupportsQuant)
@@ -48,10 +49,16 @@ from .utils import (flatten_bn, is_pp_missing_parameter,
 logger = init_logger(__name__)
 
 
-class ChameleonImagePixelInputs(TypedDict):
+class ChameleonImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height of each image
+        - w: Width of each image
+    """
     type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
 
 
 class ChameleonProcessingInfo(BaseProcessingInfo):
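For context on the new base class: TensorSchema (imported above from vllm.utils.tensor_schema) checks each Annotated field against its declared TensorShape when an instance is constructed, which is what allows the hand-written shape check removed further down to be dropped. A minimal sketch of the intended behaviour, assuming the class defined above is in scope and that a mismatched literal dimension (the channel count of 3) is rejected at construction time; the exact exception type and message are TensorSchema implementation details:

import torch

# Well-formed input: (bn, 3, h, w) with the channel dimension equal to 3,
# as declared by TensorShape("bn", 3, "h", "w") above.
ok = ChameleonImagePixelInputs(
    type="pixel_values",
    data=torch.zeros(2, 3, 512, 512),
)

# A wrong channel count should fail schema validation on construction.
try:
    ChameleonImagePixelInputs(
        type="pixel_values",
        data=torch.zeros(2, 4, 512, 512),  # 4 channels, schema expects 3
    )
except Exception as exc:  # assumed: TensorSchema raises on a shape mismatch
    print(f"rejected: {exc}")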
@@ -962,19 +969,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        vq_config: ChameleonVQVAEConfig = self.config.vq_config
-        expected_dims = (3, vq_config.resolution, vq_config.resolution)
-        actual_dims = tuple(data.shape[1:])
-
-        if actual_dims != expected_dims:
-            expected_expr = ("batch_size", *map(str, expected_dims))
-            raise ValueError(
-                f"The expected shape of pixel values is {expected_expr}. "
-                f"You supplied {tuple(data.shape)}.")
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -982,16 +976,16 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
         if pixel_values is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        pixel_values = flatten_bn(pixel_values, concat=True)
-
-        return ChameleonImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-        )
+        vq_config: ChameleonVQVAEConfig = self.config.vq_config
+        expected_h = expected_w = vq_config.resolution
+
+        return ChameleonImagePixelInputs(type="pixel_values",
+                                         data=flatten_bn(pixel_values,
+                                                         concat=True),
+                                         resolve_bindings={
+                                             "h": expected_h,
+                                             "w": expected_w
+                                         })
 
     def get_language_model(self) -> torch.nn.Module:
         return self.model
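The resolve_bindings argument pins the symbolic "h" and "w" dimensions to the VQ-VAE resolution, so the schema reproduces the check that the deleted _validate_pixel_values performed by hand against (3, vq_config.resolution, vq_config.resolution). A rough standalone sketch of that effect, with a made-up resolution standing in for self.config.vq_config.resolution:

import torch

expected_h = expected_w = 512  # illustrative; at runtime this is vq_config.resolution

# Accepted: height and width match the bound values, channel dim is 3.
ChameleonImagePixelInputs(
    type="pixel_values",
    data=torch.zeros(2, 3, 512, 512),
    resolve_bindings={"h": expected_h, "w": expected_w},
)

# Under the same bindings, a (bn, 3, 256, 256) tensor is expected to be
# rejected, mirroring the old manual comparison of data.shape[1:] against
# (3, vq_config.resolution, vq_config.resolution).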