[Model] Define merge_by_field_config MM interface (R-T) (#26260)

Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Ayush Satyam 2025-10-07 13:40:55 +05:30 committed by GitHub
parent 185d8ed44f
commit de342585ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 46 additions and 45 deletions

View File

@ -4,7 +4,7 @@ import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from itertools import product from itertools import product
from math import ceil, sqrt from math import ceil, sqrt
from typing import Any, Literal, Optional, TypedDict, Union from typing import Annotated, Any, Literal, Optional, Union
import numpy as np import numpy as np
import torch import torch
@ -44,28 +44,48 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Step3VisionEncoderConfig from vllm.transformers_utils.configs import Step3VisionEncoderConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
WeightsMapper, WeightsMapper,
flatten_bn,
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
) )
from .vision import run_dp_sharded_vision_model from .vision import run_dp_sharded_vision_model
class Step3VLImagePixelInputs(TensorSchema):
    """
    Schema for raw-pixel image inputs: the full images plus their optional
    cropped patches.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width
        - bnp: Batch size * number of images * number of patches
        - hp: Height of patch
        - wp: Width of patch
    """

    # Default the discriminator, as the sibling embedding schema does, so the
    # tag cannot be forgotten; passing `type=` explicitly still works.
    type: Literal["pixel_values"] = "pixel_values"

    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]

    # Optional: declared as `Optional[...]` because an input may carry no
    # patch tensor at all.
    patch_pixel_values: Annotated[
        Optional[torch.Tensor], TensorShape("bnp", 3, "hp", "wp")
    ]

    # One patch count per image, aligned with the `bn` axis of `pixel_values`.
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
class Step3VLImageEmbeddingInputs(TensorSchema):
    """
    Schema for precomputed image embeddings supplied instead of raw pixels.

    Dimensions:
        - bn: Batch size * number of images
        - f: Image feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"

    # The field is named `image_embeds` (not `data`) so that it matches the
    # keyword used where this schema is constructed (`image_embeds=...`) and
    # the key exposed by the TypedDict this class replaces. With a mismatched
    # name the required tensor field would never be populated.
    image_embeds: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
Step3VLImageInputs = Union[Step3VLImagePixelInputs, Step3VLImageEmbeddingInputs] Step3VLImageInputs = Union[Step3VLImagePixelInputs, Step3VLImageEmbeddingInputs]
@ -895,6 +915,8 @@ class Step3VisionTransformer(nn.Module):
dummy_inputs=Step3VLDummyInputsBuilder, dummy_inputs=Step3VLDummyInputsBuilder,
) )
class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
"model.": "language_model.model.", "model.": "language_model.model.",
@ -982,41 +1004,22 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
return None return None
if pixel_values is not None: if pixel_values is not None:
pixel_values = flatten_bn(pixel_values, concat=True)
if pixel_values.dim() >= 3:
pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
if patch_pixel_values is not None:
patch_pixel_values = flatten_bn(patch_pixel_values, concat=True)
patch_pixel_values = patch_pixel_values.view(
-1, *patch_pixel_values.shape[-3:]
)
# Handle empty patch_pixel_values by setting to None
if patch_pixel_values.shape[0] == 0:
patch_pixel_values = None
num_patches = flatten_bn(num_patches, concat=True).tolist()
return Step3VLImagePixelInputs( return Step3VLImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=pixel_values.to(self.dtype).to(self.device), pixel_values=pixel_values.to(self.dtype),
patch_pixel_values=patch_pixel_values.to(self.dtype).to(self.device) patch_pixel_values=patch_pixel_values.to(self.dtype)
if patch_pixel_values is not None if patch_pixel_values is not None
else None, else None,
num_patches=num_patches, num_patches=num_patches,
) )
if image_embeds is not None: if image_embeds is not None:
if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
else:
raise ValueError(
f"Unexpected shape for image_embeds: {image_embeds.shape}"
)
return Step3VLImageEmbeddingInputs( return Step3VLImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
image_embeds=image_embeds.to(self.dtype).to(self.device), image_embeds=image_embeds.to(self.dtype),
) )
return None
raise AssertionError("This line should be unreachable.")
def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor: def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
B, P = image_features.shape[:2] B, P = image_features.shape[:2]

View File

@ -47,11 +47,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import ( from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
AutoWeightsLoader,
init_vllm_registered_model,
maybe_prefix,
)
from .vision import ( from .vision import (
VisionEncoderInfo, VisionEncoderInfo,
get_num_selected_vision_tokens, get_num_selected_vision_tokens,

View File

@ -87,12 +87,10 @@ def _terratorch_field_factory(
if input.type == InputTypeEnum.tensor: if input.type == InputTypeEnum.tensor:
fields[input_name] = "image" fields[input_name] = "image"
mm_fields_config = {} return {
for field_name, field_modality in fields.items(): field_name: MultiModalFieldConfig.batched(modality=field_modality)
mm_fields_config[field_name] = MultiModalFieldConfig.shared( for field_name, field_modality in fields.items()
batch_size=1, modality=field_modality }
)
return mm_fields_config
return _terratorch_field_config return _terratorch_field_config
@ -192,9 +190,12 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
) -> MultiModalInputs: ) -> MultiModalInputs:
if "image" in mm_data: if "image" in mm_data:
image_data = mm_data["image"] image_data = mm_data["image"]
image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
else: else:
image_data = mm_data image_data = mm_data
mm_data = {"image": mm_data} image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
mm_data = {"image": image_data}
mm_items = self._to_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
tokenization_kwargs = tokenization_kwargs or {} tokenization_kwargs = tokenization_kwargs or {}
@ -226,6 +227,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
dummy_inputs=TerratorchInputBuilder, dummy_inputs=TerratorchInputBuilder,
) )
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal): class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
merge_by_field_config = True
supports_multimodal_raw_input_only = True supports_multimodal_raw_input_only = True
is_pooling_model = True is_pooling_model = True