[Model] Define merge_by_field_config MM interface (R-T) (#26260)

Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Ayush Satyam authored on 2025-10-07 13:40:55 +05:30; committed by GitHub
parent 185d8ed44f
commit de342585ff
3 changed files with 46 additions and 45 deletions
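
Note for readers of this diff: setting merge_by_field_config = True opts a model into the code path where the per-field multimodal keyword tensors arrive already merged along the batch dimension and already on the target device, which is why the manual flatten_bn()/view() reshaping and .to(self.device) transfers are dropped below, and the input classes are re-declared as TensorSchema subclasses with symbolic shape annotations instead of plain TypedDicts. A minimal sketch of that pattern follows; it is not part of the commit, and the class name, field names, and tensor sizes are made up for illustration.

    from typing import Annotated, Literal, Optional

    import torch

    from vllm.utils.tensor_schema import TensorSchema, TensorShape


    class ExampleImagePixelInputs(TensorSchema):
        """
        Dimensions:
            - bn: Batch size * number of images
            - h: Height
            - w: Width
        """

        # Symbolic dimension names ("bn", "h", "w") tie the fields together;
        # literal sizes (the channel count 3) are fixed.
        type: Literal["pixel_values"]
        pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
        num_patches: Annotated[Optional[torch.Tensor], TensorShape("bn")]


    # Instances are built with plain keyword arguments, exactly as the model
    # code in the diff below does; the sizes here are arbitrary.
    example = ExampleImagePixelInputs(
        type="pixel_values",
        pixel_values=torch.zeros(2, 3, 224, 224),
        num_patches=torch.tensor([1, 1]),
    )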

View File

@@ -4,7 +4,7 @@ import math
 from collections.abc import Iterable, Mapping, Sequence
 from itertools import product
 from math import ceil, sqrt
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -44,28 +44,48 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import Step3VisionEncoderConfig
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
     flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
 from .vision import run_dp_sharded_vision_model
 
 
-class Step3VLImagePixelInputs(TypedDict):
+class Step3VLImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+        - bnp: Batch size * number of images * number of patches
+        - hp: Height of patch
+        - wp: Width of patch
+    """
+
     type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    patch_pixel_values: Optional[torch.Tensor]
-    num_patches: list[int]
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
+    patch_pixel_values: Annotated[
+        Optional[torch.Tensor], TensorShape("bnp", 3, "hp", "wp")
+    ]
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
 
 
-class Step3VLImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    image_embeds: torch.Tensor
+class Step3VLImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - f: Image feature size
+        - h: Hidden size (must match the hidden size of language model backbone)
+    """
+
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
 
 
 Step3VLImageInputs = Union[Step3VLImagePixelInputs, Step3VLImageEmbeddingInputs]
@@ -895,6 +915,8 @@ class Step3VisionTransformer(nn.Module):
     dummy_inputs=Step3VLDummyInputsBuilder,
 )
 class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             "model.": "language_model.model.",
@@ -982,41 +1004,22 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
             return None
 
         if pixel_values is not None:
-            pixel_values = flatten_bn(pixel_values, concat=True)
-            if pixel_values.dim() >= 3:
-                pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
-            if patch_pixel_values is not None:
-                patch_pixel_values = flatten_bn(patch_pixel_values, concat=True)
-                patch_pixel_values = patch_pixel_values.view(
-                    -1, *patch_pixel_values.shape[-3:]
-                )
-                # Handle empty patch_pixel_values by setting to None
-                if patch_pixel_values.shape[0] == 0:
-                    patch_pixel_values = None
-            num_patches = flatten_bn(num_patches, concat=True).tolist()
-
             return Step3VLImagePixelInputs(
                 type="pixel_values",
-                pixel_values=pixel_values.to(self.dtype).to(self.device),
-                patch_pixel_values=patch_pixel_values.to(self.dtype).to(self.device)
+                pixel_values=pixel_values.to(self.dtype),
+                patch_pixel_values=patch_pixel_values.to(self.dtype)
                 if patch_pixel_values is not None
                 else None,
                 num_patches=num_patches,
             )
 
         if image_embeds is not None:
-            if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
-                image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
-            else:
-                raise ValueError(
-                    f"Unexpected shape for image_embeds: {image_embeds.shape}"
-                )
             return Step3VLImageEmbeddingInputs(
                 type="image_embeds",
-                image_embeds=image_embeds.to(self.dtype).to(self.device),
+                image_embeds=image_embeds.to(self.dtype),
             )
 
-        return None
+        raise AssertionError("This line should be unreachable.")
 
     def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
         B, P = image_features.shape[:2]

View File

@@ -47,11 +47,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (
-    AutoWeightsLoader,
-    init_vllm_registered_model,
-    maybe_prefix,
-)
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
 from .vision import (
     VisionEncoderInfo,
     get_num_selected_vision_tokens,

View File

@@ -87,12 +87,10 @@ def _terratorch_field_factory(
             if input.type == InputTypeEnum.tensor:
                 fields[input_name] = "image"
 
-        mm_fields_config = {}
-        for field_name, field_modality in fields.items():
-            mm_fields_config[field_name] = MultiModalFieldConfig.shared(
-                batch_size=1, modality=field_modality
-            )
-        return mm_fields_config
+        return {
+            field_name: MultiModalFieldConfig.batched(modality=field_modality)
+            for field_name, field_modality in fields.items()
+        }
 
     return _terratorch_field_config
@@ -192,9 +190,12 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
     ) -> MultiModalInputs:
         if "image" in mm_data:
             image_data = mm_data["image"]
+            image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
         else:
             image_data = mm_data
-            mm_data = {"image": mm_data}
+            image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
+
+        mm_data = {"image": image_data}
 
         mm_items = self._to_mm_items(mm_data)
         tokenization_kwargs = tokenization_kwargs or {}
@@ -226,6 +227,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
     dummy_inputs=TerratorchInputBuilder,
 )
 class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
+    merge_by_field_config = True
     supports_multimodal_raw_input_only = True
     is_pooling_model = True
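
A closing note on the Terratorch hunks: MultiModalFieldConfig.batched treats the leading dimension of each field tensor as the batch index, whereas the previous MultiModalFieldConfig.shared(batch_size=1, ...) shared one copy across the batch, which appears to be why the processor above now adds a batch dimension with unsqueeze(0) before wrapping the data under "image". A small sketch of that convention follows; it is not part of the commit, and the field name, tensor sizes, and import path are assumptions.

    import torch

    from vllm.multimodal.inputs import MultiModalFieldConfig  # import path assumed


    def example_field_config() -> dict[str, MultiModalFieldConfig]:
        # Every field is declared as batched: element i of the leading
        # dimension belongs to item i of the batch.
        return {"pixels": MultiModalFieldConfig.batched(modality="image")}


    # A single item has no batch dimension yet, so lift (C, H, W) to
    # (1, C, H, W) the same way the processor above does with unsqueeze(0).
    single_item = {"pixels": torch.zeros(6, 512, 512)}
    batched_item = {k: v.unsqueeze(0) for k, v in single_item.items()}
    assert batched_item["pixels"].shape == (1, 6, 512, 512)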