[Model] Define merge_by_field_config MM interface (R-T) (#26260)
Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 185d8ed44f
commit de342585ff
vllm/model_executor/models/step3_vl.py

@@ -4,7 +4,7 @@ import math
 from collections.abc import Iterable, Mapping, Sequence
 from itertools import product
 from math import ceil, sqrt
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -44,28 +44,48 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import Step3VisionEncoderConfig
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
 from .vision import run_dp_sharded_vision_model
 
 
-class Step3VLImagePixelInputs(TypedDict):
+class Step3VLImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+        - bnp: Batch size * number of images * number of patches
+        - hp: Height of patch
+        - wp: Width of patch
+    """
+
     type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    patch_pixel_values: Optional[torch.Tensor]
-    num_patches: list[int]
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
+    patch_pixel_values: Annotated[
+        Optional[torch.Tensor], TensorShape("bnp", 3, "hp", "wp")
+    ]
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
 
 
-class Step3VLImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    image_embeds: torch.Tensor
+class Step3VLImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - f: Image feature size
+        - h: Hidden size (must match the hidden size of language model backbone)
+    """
+
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
 
 
 Step3VLImageInputs = Union[Step3VLImagePixelInputs, Step3VLImageEmbeddingInputs]
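Note: the TypedDict-to-TensorSchema conversion above moves shape checking into the declaration itself. Each field carries its expected layout in an Annotated[...] hint, and vLLM validates tensors against those TensorShape symbols when the inputs object is constructed. A minimal usage sketch (tensor sizes are illustrative, not taken from the Step3 model config):

    import torch

    # Construction validates each annotated shape; a tensor that does not
    # match ("bn", 3, "h", "w") fails fast instead of deep inside the model.
    inputs = Step3VLImagePixelInputs(
        type="pixel_values",
        pixel_values=torch.zeros(2, 3, 728, 728),  # ("bn", 3, "h", "w")
        patch_pixel_values=None,                   # Optional field may be None
        num_patches=torch.tensor([1, 1]),          # ("bn",)
    )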
@@ -895,6 +915,8 @@ class Step3VisionTransformer(nn.Module):
     dummy_inputs=Step3VLDummyInputsBuilder,
 )
 class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             "model.": "language_model.model.",
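Note on the new flag: merge_by_field_config = True opts the model into the convention where the engine merges multimodal kwargs according to each field's MultiModalFieldConfig before they reach the model. A rough before/after sketch of what that changes (shapes illustrative):

    # Before: each multimodal kwarg arrived as a list of per-prompt tensors
    # that the model merged itself:
    pixel_values = flatten_bn(pixel_values, concat=True)  # list[Tensor] -> (bn, 3, h, w)

    # After merge_by_field_config = True: the engine performs the merge, so
    # the model receives a single (bn, 3, h, w) tensor directly and the
    # flatten_bn calls and import can be deleted, as the hunk below does.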
@@ -982,41 +1004,22 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
             return None
 
         if pixel_values is not None:
-            pixel_values = flatten_bn(pixel_values, concat=True)
-            if pixel_values.dim() >= 3:
-                pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
-            if patch_pixel_values is not None:
-                patch_pixel_values = flatten_bn(patch_pixel_values, concat=True)
-                patch_pixel_values = patch_pixel_values.view(
-                    -1, *patch_pixel_values.shape[-3:]
-                )
-                # Handle empty patch_pixel_values by setting to None
-                if patch_pixel_values.shape[0] == 0:
-                    patch_pixel_values = None
-            num_patches = flatten_bn(num_patches, concat=True).tolist()
-
             return Step3VLImagePixelInputs(
                 type="pixel_values",
-                pixel_values=pixel_values.to(self.dtype).to(self.device),
-                patch_pixel_values=patch_pixel_values.to(self.dtype).to(self.device)
+                pixel_values=pixel_values.to(self.dtype),
+                patch_pixel_values=patch_pixel_values.to(self.dtype)
                 if patch_pixel_values is not None
                 else None,
                 num_patches=num_patches,
             )
 
         if image_embeds is not None:
-            if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
-                image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
-            else:
-                raise ValueError(
-                    f"Unexpected shape for image_embeds: {image_embeds.shape}"
-                )
-
             return Step3VLImageEmbeddingInputs(
                 type="image_embeds",
-                image_embeds=image_embeds.to(self.dtype).to(self.device),
+                data=image_embeds.to(self.dtype),
             )
-        return None
+
+        raise AssertionError("This line should be unreachable.")
 
     def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
         B, P = image_features.shape[:2]
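Two smaller consequences of the migration are visible in this hunk. First, the .to(self.device) calls disappear: under the merged-field convention tensors are expected to arrive on the target device already, leaving only the dtype cast. Second, the manual dim() checks and view() reshapes are gone because the schema constructor now enforces rank and shape. A rough sketch of the failure mode the schema catches (shapes illustrative):

    import torch

    # data must match TensorShape("bn", "f", "h"), i.e. be rank 3; a rank-2
    # tensor is rejected when the schema object is constructed, replacing
    # the hand-written ValueError branch deleted above.
    bad_embeds = torch.zeros(4, 4096)
    Step3VLImageEmbeddingInputs(type="image_embeds", data=bad_embeds)  # raises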
vllm/model_executor/models/tarsier.py

@@ -47,11 +47,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (
-    AutoWeightsLoader,
-    init_vllm_registered_model,
-    maybe_prefix,
-)
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
 from .vision import (
     VisionEncoderInfo,
     get_num_selected_vision_tokens,
vllm/model_executor/models/terratorch.py

@@ -87,12 +87,10 @@ def _terratorch_field_factory(
             if input.type == InputTypeEnum.tensor:
                 fields[input_name] = "image"
 
-        mm_fields_config = {}
-        for field_name, field_modality in fields.items():
-            mm_fields_config[field_name] = MultiModalFieldConfig.shared(
-                batch_size=1, modality=field_modality
-            )
-        return mm_fields_config
+        return {
+            field_name: MultiModalFieldConfig.batched(modality=field_modality)
+            for field_name, field_modality in fields.items()
+        }
 
     return _terratorch_field_config
 
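The switch from MultiModalFieldConfig.shared(batch_size=1, ...) to MultiModalFieldConfig.batched(...) changes how processor output is sliced into per-item chunks: a batched field is split along its leading dimension, one slice per multimodal item, rather than being treated as one tensor shared across a size-1 batch. An illustrative sketch of the slicing, not the library internals:

    import torch

    # A field registered as MultiModalFieldConfig.batched(modality="image")
    # is conceptually split along dim 0, one slice per item.
    t = torch.zeros(2, 6, 512, 512)            # two items stacked on dim 0
    items = [t[i] for i in range(t.shape[0])]  # each item: (6, 512, 512)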
@@ -192,9 +190,12 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
     ) -> MultiModalInputs:
         if "image" in mm_data:
             image_data = mm_data["image"]
+            image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
         else:
             image_data = mm_data
+            image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
-            mm_data = {"image": mm_data}
+
+        mm_data = {"image": image_data}
 
         mm_items = self._to_mm_items(mm_data)
         tokenization_kwargs = tokenization_kwargs or {}
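The added unsqueeze(0) pairs with the batched field config above: each raw Terratorch input tensor gets an explicit leading batch dimension so there is exactly one item to slice off dim 0. For example (shapes illustrative):

    import torch

    # Terratorch passes a dict of raw tensors; unsqueeze(0) gives each one
    # the leading batch dimension that a batched field config slices along.
    image_data = {"pixel_values": torch.zeros(6, 512, 512)}
    image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
    print(image_data["pixel_values"].shape)  # torch.Size([1, 6, 512, 512])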
@@ -226,6 +227,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
     dummy_inputs=TerratorchInputBuilder,
 )
 class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
+    merge_by_field_config = True
     supports_multimodal_raw_input_only = True
     is_pooling_model = True
 