mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 20:45:01 +08:00
Migrate MolmoImageInputs to TensorSchema (#22022)
Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
parent
e0b056e443
commit
a482e4e769
@ -5,7 +5,7 @@ import math
|
|||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import cached_property, partial
|
from functools import cached_property, partial
|
||||||
from typing import Optional, TypedDict, Union
|
from typing import Annotated, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -51,6 +51,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
PromptUpdateDetails)
|
PromptUpdateDetails)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||||
SupportsMultiModal, SupportsPP, SupportsQuant)
|
SupportsMultiModal, SupportsPP, SupportsQuant)
|
||||||
@ -70,23 +71,25 @@ IM_END_TOKEN = "<im_end>"
|
|||||||
POOLING_SIZE = 2
|
POOLING_SIZE = 2
|
||||||
|
|
||||||
|
|
||||||
class MolmoImageInputs(TypedDict):
|
class MolmoImageInputs(TensorSchema):
|
||||||
images: Union[torch.Tensor, list[torch.Tensor]]
|
|
||||||
"""Shape: `(batch_size * num_images, num_crops, num_patch, patch_dim)`"""
|
|
||||||
|
|
||||||
image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
|
|
||||||
"""Shape: `(batch_size * num_images, num_crops, num_patch)`"""
|
|
||||||
|
|
||||||
feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
|
||||||
"""
|
"""
|
||||||
A boolean mask indicating which image features correspond
|
Dimensions:
|
||||||
to patch tokens.
|
- bn: Batch size * number of images
|
||||||
|
- nc: Number of crops
|
||||||
Shape: `(batch_size * num_images, num_crops, num_patch)`
|
- np: Number of patches
|
||||||
|
- pd: Patch dimension
|
||||||
"""
|
"""
|
||||||
|
images: Annotated[Union[torch.Tensor, list[torch.Tensor]],
|
||||||
|
TensorShape("bn", "nc", "np", "pd")]
|
||||||
|
|
||||||
num_crops: torch.Tensor
|
image_masks: Annotated[Optional[Union[torch.Tensor, list[torch.Tensor]]],
|
||||||
"""Shape: `(batch_size * num_images)`"""
|
TensorShape("bn", "nc", "np")]
|
||||||
|
|
||||||
|
feat_is_patch: Annotated[Union[torch.Tensor, list[torch.Tensor]],
|
||||||
|
TensorShape("bn", "nc", "np")]
|
||||||
|
# A boolean mask indicating which image features correspond to patch tokens.
|
||||||
|
|
||||||
|
num_crops: Annotated[torch.Tensor, TensorShape("bn")]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -1410,28 +1413,17 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
|||||||
**kwargs: object,
|
**kwargs: object,
|
||||||
) -> Optional[MolmoImageInputs]:
|
) -> Optional[MolmoImageInputs]:
|
||||||
images = kwargs.pop("images", None)
|
images = kwargs.pop("images", None)
|
||||||
|
image_masks = kwargs.pop("image_masks", None)
|
||||||
|
feat_is_patch = kwargs.pop("feat_is_patch", None)
|
||||||
|
num_crops = kwargs.pop("num_crops", None)
|
||||||
|
|
||||||
if images is None:
|
if images is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if not isinstance(images, (torch.Tensor, list)):
|
|
||||||
raise ValueError("Incorrect type of images. "
|
|
||||||
f"Got type: {type(images)}")
|
|
||||||
|
|
||||||
image_masks = kwargs.pop("image_masks", None)
|
|
||||||
if not (image_masks is None or isinstance(image_masks,
|
|
||||||
(torch.Tensor, list))):
|
|
||||||
raise ValueError("Incorrect type of image_masks. "
|
|
||||||
f"Got type: {type(image_masks)}")
|
|
||||||
|
|
||||||
feat_is_patch = kwargs.pop("feat_is_patch", None)
|
|
||||||
if not isinstance(feat_is_patch, (torch.Tensor, list)):
|
|
||||||
raise ValueError("Incorrect type of feat_is_patch. "
|
|
||||||
f"Got type: {type(feat_is_patch)}")
|
|
||||||
|
|
||||||
num_crops = kwargs.pop("num_crops", None)
|
|
||||||
if not isinstance(num_crops, (torch.Tensor, list)):
|
if not isinstance(num_crops, (torch.Tensor, list)):
|
||||||
raise ValueError("Incorrect type of num_crops. "
|
raise ValueError("Incorrect type of num_crops. "
|
||||||
f"Got type: {type(num_crops)}")
|
f"Got type: {type(num_crops)}")
|
||||||
|
num_crops = flatten_bn(num_crops, concat=True)
|
||||||
|
|
||||||
img_patch_id = kwargs.pop("img_patch_id", None)
|
img_patch_id = kwargs.pop("img_patch_id", None)
|
||||||
if not isinstance(img_patch_id, torch.Tensor):
|
if not isinstance(img_patch_id, torch.Tensor):
|
||||||
@ -1439,8 +1431,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
|||||||
f"Got type: {type(img_patch_id)}")
|
f"Got type: {type(img_patch_id)}")
|
||||||
self.img_patch_id = img_patch_id.flatten().unique().item()
|
self.img_patch_id = img_patch_id.flatten().unique().item()
|
||||||
|
|
||||||
num_crops = flatten_bn(num_crops, concat=True)
|
|
||||||
|
|
||||||
return MolmoImageInputs(
|
return MolmoImageInputs(
|
||||||
images=images,
|
images=images,
|
||||||
image_masks=image_masks,
|
image_masks=image_masks,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user