mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-16 12:25:41 +08:00)
Migrate GLMVImagePixelInputs to TensorSchema (#21679)
Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
parent 88e46c7c8d
commit 304dcdf575
@@ -6,7 +6,7 @@
 """Inference-only CogAgent model compatible with THUDM weights."""
 from argparse import Namespace
 from collections.abc import Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -38,6 +38,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import ChatGLMConfig
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .chatglm import ChatGLMBaseModel, ChatGLMModel
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -45,10 +46,16 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
 from .utils import flatten_bn, merge_multimodal_embeddings
 
 
-class GLMVImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: `(batch_size, num_channels, height, width)`"""
+class GLMVImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: Batch size
+        - c: Number of channels (3)
+        - h: Height of image
+        - w: Width of image
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+    data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")]
 
 
 class EVA2CLIPPatchEmbedding(nn.Module):
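The new class declares the expected layout declaratively: TensorShape("b", 3, "h", "w") pins the channel dimension to 3 and leaves batch size, height, and width symbolic, to be checked when the object is constructed. A minimal usage sketch follows, assuming vllm is installed, that the class lives in vllm.model_executor.models.glm4v, and using 1120 as an illustrative image size (the real value comes from the model config, as the call site further below shows):

import torch

from vllm.model_executor.models.glm4v import GLMVImagePixelInputs

# Two images in the (b, 3, h, w) layout the schema expects.
pixels = torch.zeros(2, 3, 1120, 1120)

# Constructing the schema object checks `data` against
# TensorShape("b", 3, "h", "w"); resolve_bindings fixes the symbolic
# h/w dimensions to concrete sizes, mirroring the diff's call site.
inputs = GLMVImagePixelInputs(
    type="pixel_values",
    data=pixels,
    resolve_bindings={"h": 1120, "w": 1120},
)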
@@ -562,19 +569,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
 
         self.transformer: GLM4VModel
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        h = w = self.config.vision_config["image_size"]
-        expected_dims = (3, h, w)
-        actual_dims = tuple(data.shape[1:])
-
-        if actual_dims != expected_dims:
-            expected_expr = ("batch_size", *map(str, expected_dims))
-            raise ValueError(
-                f"The expected shape of pixel values is {expected_expr}. "
-                f"You supplied {tuple(data.shape)}.")
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[GLMVImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
||||
@ -584,11 +578,14 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
return GLMVImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(
|
||||
flatten_bn(pixel_values, concat=True)),
|
||||
)
|
||||
expected_h = expected_w = self.config.vision_config["image_size"]
|
||||
return GLMVImagePixelInputs(type="pixel_values",
|
||||
data=flatten_bn(pixel_values,
|
||||
concat=True),
|
||||
resolve_bindings={
|
||||
"h": expected_h,
|
||||
"w": expected_w
|
||||
})
|
||||
|
||||
return None
|
||||
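Because h and w are bound to config.vision_config["image_size"] at the call site, a tensor with the wrong spatial size should now be rejected by the schema itself rather than by the deleted _validate_pixel_values helper. A hedged sketch of that failure path; the exact exception type and message come from TensorSchema and are not shown in this diff, hence the deliberately broad except:

import torch

from vllm.model_executor.models.glm4v import GLMVImagePixelInputs

# Wrong spatial size on purpose: 224 instead of the bound 1120.
bad_pixels = torch.zeros(2, 3, 224, 224)

try:
    GLMVImagePixelInputs(
        type="pixel_values",
        data=bad_pixels,
        resolve_bindings={"h": 1120, "w": 1120},
    )
except Exception as exc:
    # Assumption: TensorSchema raises here on the h/w mismatch.
    print(f"rejected: {exc}")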