[Model] Use merge_by_field_config for MM models (Ovis family) (#26308)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-10-07 20:54:22 +08:00 committed by GitHub
parent 63773a6200
commit 08d26a1b7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 86 additions and 81 deletions

View File

@ -1140,14 +1140,10 @@ def run_ovis2_5(questions: list[str], modality: str) -> ModelRequestData:
elif modality == "video": elif modality == "video":
placeholder = "<video>" placeholder = "<video>"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) prompts = [
messages = [ f"<|im_start|>user\n\n{placeholder}\n{question}<|im_end|>\n<|im_start|>assistant\n"
[{"role": "user", "content": f"{placeholder}\n{question}"}]
for question in questions for question in questions
] ]
prompts = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,

View File

@ -713,11 +713,9 @@ def load_ovis2_5(question: str, image_urls: list[str]) -> ModelRequestData:
placeholders = "\n".join( placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1) f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
) )
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}] prompt = (
f"<|im_start|>user\n\n{placeholders}\n{question}<|im_end|>\n"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) "<|im_start|>assistant\n"
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
) )
return ModelRequestData( return ModelRequestData(

View File

@ -217,17 +217,17 @@ class VisualTokenizer(torch.nn.Module):
class OvisImagePatchInputs(TensorSchema): class OvisImagePatchInputs(TensorSchema):
""" """
Dimensions: Dimensions:
- batch_patches: Batch size * number of patches - bnp: Batch size * number of images * number of patches
- patch_size: patch_size_x * patch_size_y * num_channels - h: Height of each patch
- w: Width of each patch
- patch_indicators: Batch size * (number of patches + 1) - patch_indicators: Batch size * (number of patches + 1)
- patches_per_image: List of number of total patches for each image - bn: Batch size * number of images
in the batch.
""" """
type: Literal["image_patches"] type: Literal["image_patches"]
flat_data: Annotated[torch.Tensor, TensorShape("batch_patches", "patch_size")] flat_data: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")] indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
patches_per_image: Annotated[list[int], TensorShape("num_patches_per_image")] patches_per_image: Annotated[list[int], TensorShape("bn")]
# This is used to restore the first two dimensions of `flat_data`. # This is used to restore the first two dimensions of `flat_data`.
@ -366,7 +366,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
self.image_indicators_to_visual_tokens(indicator) self.image_indicators_to_visual_tokens(indicator)
for indicator in image_indicators for indicator in image_indicators
] ]
processed_outputs["indicator_tokens"] = indicator_tokens processed_outputs["indicator_tokens"] = torch.tensor(indicator_tokens)
return processed_outputs return processed_outputs
def _apply_hf_processor_tokens_only( def _apply_hf_processor_tokens_only(
@ -414,6 +414,8 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
dummy_inputs=OvisDummyInputsBuilder, dummy_inputs=OvisDummyInputsBuilder,
) )
class Ovis(nn.Module, SupportsMultiModal, SupportsPP): class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"): if modality.startswith("image"):
@ -470,14 +472,11 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(pixel_values)}" f"Got type: {type(pixel_values)}"
) )
flat_data = flatten_bn(pixel_values, concat=True)
if flat_data.ndim >= 3:
flat_data = flat_data.flatten(start_dim=1)
return OvisImagePatchInputs( return OvisImagePatchInputs(
type="image_patches", type="image_patches",
flat_data=flat_data, flat_data=flatten_bn(pixel_values, concat=True),
patches_per_image=[x.shape[0] for x in flatten_bn(pixel_values)], patches_per_image=[x.shape[0] for x in pixel_values],
indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True), indicator_tokens=flatten_bn(indicator_tokens, concat=True),
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")

View File

@ -4,7 +4,7 @@
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from functools import partial from functools import partial
from typing import Literal, Optional, TypedDict, Union from typing import Annotated, Literal, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -14,7 +14,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.ovis import OvisImagePatchInputs, VisualEmbedding from vllm.model_executor.models.ovis import VisualEmbedding
from vllm.model_executor.models.siglip2navit import Siglip2NavitModel from vllm.model_executor.models.siglip2navit import Siglip2NavitModel
from vllm.model_executor.models.utils import ( from vllm.model_executor.models.utils import (
AutoWeightsLoader, AutoWeightsLoader,
@ -37,6 +37,7 @@ from vllm.multimodal.processing import (
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@ -58,36 +59,38 @@ IMAGE_PAD_TOKEN_ID_MAP = {
} }
class OvisVideoPatchInputs(TypedDict): class Ovis2_5ImagePatchInputs(TensorSchema):
"""
Dimensions:
- bnp: Batch size * number of images * number of patches
- patch_size: patch_size_x * patch_size_y * num_channels
- patch_indicators: Batch size * (number of patches + 1)
- bn: Batch size * number of images
"""
type: Literal["image_patches"]
flat_data: Annotated[torch.Tensor, TensorShape("bnp", "patch_size")]
indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
patches_per_item: Annotated[list[int], TensorShape("bn")]
grids: Annotated[torch.Tensor, TensorShape("bn", 3)]
# This is used to restore the first two dimensions of `flat_data`.
class Ovis2_5VideoPatchInputs(TensorSchema):
"""
Dimensions:
- bnp: Batch size * number of videos * number of patches
- patch_size: patch_size_x * patch_size_y * num_channels
- patch_indicators: Batch size * (number of patches + 1)
- bn: Batch size * number of videos
"""
type: Literal["video_patches"] type: Literal["video_patches"]
flat_data: torch.Tensor flat_data: Annotated[torch.Tensor, TensorShape("bnp", "patch_size")]
""" indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
Shape: patches_per_item: Annotated[list[int], TensorShape("bn")]
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)` grids: Annotated[torch.Tensor, TensorShape("bn", 3)]
""" # This is used to restore the first two dimensions of `flat_data`.
indicator_tokens: torch.Tensor
"""
Shape:
`(batch_size * (num_patches + 1))`
"""
patches_per_image: list[int]
"""
List of number of total patches for each frame in the video.
This is used to restore the first two dimensions of `flat_data`.
"""
def _ovis2_5_field_config():
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
grids=MultiModalFieldConfig.batched("image"),
indicator_tokens=MultiModalFieldConfig.batched("image"),
video_pixel_values=MultiModalFieldConfig.batched("video"),
video_indicator_tokens=MultiModalFieldConfig.batched("video"),
video_grids=MultiModalFieldConfig.batched("video"),
)
class VisualTokenizer(torch.nn.Module): class VisualTokenizer(torch.nn.Module):
@ -380,7 +383,7 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo])
self.visual_indicators_to_visual_tokens(indicator) self.visual_indicators_to_visual_tokens(indicator)
for indicator in visual_indicators for indicator in visual_indicators
] ]
processed_outputs["video_indicator_tokens"] = indicator_tokens processed_outputs["video_indicator_tokens"] = torch.tensor(indicator_tokens)
if "images" in mm_data: if "images" in mm_data:
visual_indicators = [ visual_indicators = [
hf_processor.construct_visual_indicators((1, 1, 1), False) hf_processor.construct_visual_indicators((1, 1, 1), False)
@ -391,7 +394,7 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo])
for indicator in visual_indicators for indicator in visual_indicators
] ]
processed_outputs["indicator_tokens"] = indicator_tokens processed_outputs["indicator_tokens"] = torch.tensor(indicator_tokens)
return processed_outputs return processed_outputs
def _apply_hf_processor_tokens_only( def _apply_hf_processor_tokens_only(
@ -405,7 +408,14 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo])
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return _ovis2_5_field_config() return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
grids=MultiModalFieldConfig.batched("image"),
indicator_tokens=MultiModalFieldConfig.batched("image"),
video_pixel_values=MultiModalFieldConfig.batched("video"),
video_indicator_tokens=MultiModalFieldConfig.batched("video"),
video_grids=MultiModalFieldConfig.batched("video"),
)
def _get_prompt_updates( def _get_prompt_updates(
self, self,
@ -441,6 +451,8 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo])
dummy_inputs=Ovis2_5DummyInputsBuilder, dummy_inputs=Ovis2_5DummyInputsBuilder,
) )
class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
@ -470,7 +482,7 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object self, **kwargs: object
) -> Optional[OvisImagePatchInputs]: ) -> Optional[Ovis2_5ImagePatchInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
indicator_tokens = kwargs.pop("indicator_tokens", None) indicator_tokens = kwargs.pop("indicator_tokens", None)
grids = kwargs.pop("grids", None) grids = kwargs.pop("grids", None)
@ -489,22 +501,22 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(indicator_tokens)}" f"Got type: {type(indicator_tokens)}"
) )
return OvisImagePatchInputs( return Ovis2_5ImagePatchInputs(
type="image_patches", type="image_patches",
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), flat_data=flatten_bn(pixel_values, concat=True),
patches_per_image=[ patches_per_item=[
x.shape[0] // (self.config.vit_config.hidden_stride**2) x.shape[0] // (self.config.vit_config.hidden_stride**2)
for x in flatten_bn(pixel_values) for x in pixel_values
], ],
indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True), indicator_tokens=flatten_bn(indicator_tokens, concat=True),
grids=flatten_bn(flatten_bn(grids), concat=True), grids=flatten_bn(grids, concat=True),
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
def _parse_and_validate_video_input( def _parse_and_validate_video_input(
self, **kwargs: object self, **kwargs: object
) -> Optional[OvisImagePatchInputs]: ) -> Optional[Ovis2_5VideoPatchInputs]:
pixel_values = kwargs.pop("video_pixel_values", None) pixel_values = kwargs.pop("video_pixel_values", None)
indicator_tokens = kwargs.pop("video_indicator_tokens", None) indicator_tokens = kwargs.pop("video_indicator_tokens", None)
grids = kwargs.pop("video_grids", None) grids = kwargs.pop("video_grids", None)
@ -523,26 +535,26 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(indicator_tokens)}" f"Got type: {type(indicator_tokens)}"
) )
return OvisVideoPatchInputs( return Ovis2_5VideoPatchInputs(
type="video_patches", type="video_patches",
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), flat_data=flatten_bn(pixel_values, concat=True),
patches_per_image=[ patches_per_item=[
x.shape[0] // (self.config.vit_config.hidden_stride**2) x.shape[0] // (self.config.vit_config.hidden_stride**2)
for x in flatten_bn(pixel_values) for x in pixel_values
], ],
indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True), indicator_tokens=flatten_bn(indicator_tokens, concat=True),
grids=flatten_bn(flatten_bn(grids), concat=True), grids=flatten_bn(grids, concat=True),
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
def _process_image_input( def _process_visual_input(
self, image_input: Union[OvisImagePatchInputs, OvisVideoPatchInputs] self, visual_input: Union[Ovis2_5ImagePatchInputs, Ovis2_5VideoPatchInputs]
) -> MultiModalEmbeddings: ) -> MultiModalEmbeddings:
image_patches_flat = image_input["flat_data"] image_patches_flat = visual_input["flat_data"]
patches_per_image = image_input["patches_per_image"] patches_per_image = visual_input["patches_per_item"]
indicator_tokens = image_input["indicator_tokens"] indicator_tokens = visual_input["indicator_tokens"]
grid_thws = image_input["grids"] grid_thws = visual_input["grids"]
indicator_per_image = list( indicator_per_image = list(
map(lambda x: 2 if x > 1 else x + 2, patches_per_image) map(lambda x: 2 if x > 1 else x + 2, patches_per_image)
@ -604,11 +616,11 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
for modality in modalities: for modality in modalities:
if modality == "images": if modality == "images":
image_input = modalities["images"] image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_visual_input(image_input)
multimodal_embeddings += vision_embeddings multimodal_embeddings += vision_embeddings
if modality == "videos": if modality == "videos":
video_input = modalities["videos"] video_input = modalities["videos"]
video_embeddings = self._process_image_input(video_input) video_embeddings = self._process_visual_input(video_input)
multimodal_embeddings += video_embeddings multimodal_embeddings += video_embeddings
return multimodal_embeddings return multimodal_embeddings

View File

@ -408,7 +408,7 @@ class OvisProcessor(ProcessorMixin):
crops.insert(0, image) crops.insert(0, image)
pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0) pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0)
image_placeholders = self.construct_image_placeholders(grid) image_placeholders = self.construct_image_placeholders(grid)
return pixel_values, image_placeholders, grid return torch.tensor(pixel_values), image_placeholders, torch.tensor(grid)
def batch_decode(self, *args, **kwargs): def batch_decode(self, *args, **kwargs):
""" """