mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 15:44:30 +08:00
[Model] Define merge_by_field_config MM interface (R-T) (#26260)
Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
185d8ed44f
commit
de342585ff
@ -4,7 +4,7 @@ import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from itertools import product
|
||||
from math import ceil, sqrt
|
||||
from typing import Any, Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Any, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -44,28 +44,48 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import Step3VisionEncoderConfig
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
flatten_bn,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
)
|
||||
from .vision import run_dp_sharded_vision_model
|
||||
|
||||
|
||||
class Step3VLImagePixelInputs(TypedDict):
|
||||
class Step3VLImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- c: Number of channels (3)
|
||||
- h: Height
|
||||
- w: Width
|
||||
- bnp: Batch size * number of images * number of patches
|
||||
- hp: Height of patch
|
||||
- wp: Width of patch
|
||||
"""
|
||||
|
||||
type: Literal["pixel_values"]
|
||||
pixel_values: torch.Tensor
|
||||
patch_pixel_values: Optional[torch.Tensor]
|
||||
num_patches: list[int]
|
||||
pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
|
||||
patch_pixel_values: Annotated[
|
||||
Optional[torch.Tensor], TensorShape("bnp", 3, "hp", "wp")
|
||||
]
|
||||
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
|
||||
|
||||
|
||||
class Step3VLImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
image_embeds: torch.Tensor
|
||||
class Step3VLImageEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- f: Image feature size
|
||||
- h: Hidden size (must match the hidden size of language model backbone)
|
||||
"""
|
||||
|
||||
type: Literal["image_embeds"] = "image_embeds"
|
||||
data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
|
||||
|
||||
|
||||
Step3VLImageInputs = Union[Step3VLImagePixelInputs, Step3VLImageEmbeddingInputs]
|
||||
@ -895,6 +915,8 @@ class Step3VisionTransformer(nn.Module):
|
||||
dummy_inputs=Step3VLDummyInputsBuilder,
|
||||
)
|
||||
class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
merge_by_field_config = True
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"model.": "language_model.model.",
|
||||
@ -982,41 +1004,22 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
|
||||
return None
|
||||
|
||||
if pixel_values is not None:
|
||||
pixel_values = flatten_bn(pixel_values, concat=True)
|
||||
if pixel_values.dim() >= 3:
|
||||
pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
|
||||
if patch_pixel_values is not None:
|
||||
patch_pixel_values = flatten_bn(patch_pixel_values, concat=True)
|
||||
patch_pixel_values = patch_pixel_values.view(
|
||||
-1, *patch_pixel_values.shape[-3:]
|
||||
)
|
||||
# Handle empty patch_pixel_values by setting to None
|
||||
if patch_pixel_values.shape[0] == 0:
|
||||
patch_pixel_values = None
|
||||
num_patches = flatten_bn(num_patches, concat=True).tolist()
|
||||
|
||||
return Step3VLImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values=pixel_values.to(self.dtype).to(self.device),
|
||||
patch_pixel_values=patch_pixel_values.to(self.dtype).to(self.device)
|
||||
pixel_values=pixel_values.to(self.dtype),
|
||||
patch_pixel_values=patch_pixel_values.to(self.dtype)
|
||||
if patch_pixel_values is not None
|
||||
else None,
|
||||
num_patches=num_patches,
|
||||
)
|
||||
|
||||
if image_embeds is not None:
|
||||
if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
|
||||
image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected shape for image_embeds: {image_embeds.shape}"
|
||||
)
|
||||
|
||||
return Step3VLImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
image_embeds=image_embeds.to(self.dtype).to(self.device),
|
||||
image_embeds=image_embeds.to(self.dtype),
|
||||
)
|
||||
return None
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
|
||||
B, P = image_features.shape[:2]
|
||||
|
||||
@ -47,11 +47,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .clip import CLIPVisionModel
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
)
|
||||
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
||||
from .vision import (
|
||||
VisionEncoderInfo,
|
||||
get_num_selected_vision_tokens,
|
||||
|
||||
@ -87,12 +87,10 @@ def _terratorch_field_factory(
|
||||
if input.type == InputTypeEnum.tensor:
|
||||
fields[input_name] = "image"
|
||||
|
||||
mm_fields_config = {}
|
||||
for field_name, field_modality in fields.items():
|
||||
mm_fields_config[field_name] = MultiModalFieldConfig.shared(
|
||||
batch_size=1, modality=field_modality
|
||||
)
|
||||
return mm_fields_config
|
||||
return {
|
||||
field_name: MultiModalFieldConfig.batched(modality=field_modality)
|
||||
for field_name, field_modality in fields.items()
|
||||
}
|
||||
|
||||
return _terratorch_field_config
|
||||
|
||||
@ -192,9 +190,12 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
|
||||
) -> MultiModalInputs:
|
||||
if "image" in mm_data:
|
||||
image_data = mm_data["image"]
|
||||
image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
|
||||
else:
|
||||
image_data = mm_data
|
||||
mm_data = {"image": mm_data}
|
||||
image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
|
||||
|
||||
mm_data = {"image": image_data}
|
||||
|
||||
mm_items = self._to_mm_items(mm_data)
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
@ -226,6 +227,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
|
||||
dummy_inputs=TerratorchInputBuilder,
|
||||
)
|
||||
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
merge_by_field_config = True
|
||||
supports_multimodal_raw_input_only = True
|
||||
is_pooling_model = True
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user