[Model] Use merge_by_field_config for MM models (Ovis family) (#26308)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: Isotr0py, 2025-10-07 20:54:22 +08:00 (committed by GitHub)
Parent: 63773a6200
Commit: 08d26a1b7e
5 changed files with 86 additions and 81 deletions
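
This commit moves the Ovis models onto the merge_by_field_config path, so the batched multimodal kwargs are expected to reach the model as flat lists of per-item tensors and the nested flatten_bn calls can be dropped. A minimal sketch of that accounting, using toy tensors and plain torch.cat in place of vLLM's flatten_bn(..., concat=True):

    # Toy sketch (not vLLM code): why one level of flattening now suffices when
    # merge_by_field_config hands the model a flat list of per-image tensors.
    import torch

    # Two images contributing 3 and 5 patches, each patch flattened to 588 values.
    per_image = [torch.randn(3, 588), torch.randn(5, 588)]

    # Old path: the kwarg carried an extra batch nesting, so it was flattened twice,
    # e.g. flatten_bn(flatten_bn(pixel_values), concat=True) in the hunks below.
    nested = [per_image]
    old_flat = torch.cat([t for item in nested for t in item], dim=0)

    # New path: the per-image tensors already arrive as a flat list, so a single
    # flatten_bn(pixel_values, concat=True) (here just torch.cat) is enough.
    new_flat = torch.cat(per_image, dim=0)

    assert torch.equal(old_flat, new_flat)
    patches_per_image = [x.shape[0] for x in per_image]   # [3, 5], kept to split later
    print(new_flat.shape, patches_per_image)              # torch.Size([8, 588]) [3, 5]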


@@ -1140,14 +1140,10 @@ def run_ovis2_5(questions: list[str], modality: str) -> ModelRequestData:
elif modality == "video":
placeholder = "<video>"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
messages = [
[{"role": "user", "content": f"{placeholder}\n{question}"}]
prompts = [
f"<|im_start|>user\n\n{placeholder}\n{question}<|im_end|>\n<|im_start|>assistant\n"
for question in questions
]
prompts = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,


@@ -713,11 +713,9 @@ def load_ovis2_5(question: str, image_urls: list[str]) -> ModelRequestData:
placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
prompt = (
f"<|im_start|>user\n\n{placeholders}\n{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
return ModelRequestData(
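
Both example scripts above now build the prompt string by hand instead of calling tokenizer.apply_chat_template. A small sketch of that formatting step, assuming the Qwen-style <|im_start|>/<|im_end|> markers used in the new example code match the model's chat template:

    # Hand-rolled prompt construction as in the updated examples; purely string
    # formatting, so no tokenizer download is needed. The template markers are
    # taken from the diff above and assumed to match Ovis2.5's chat template.
    def build_prompt(question: str, placeholder: str = "<image>") -> str:
        return (
            f"<|im_start|>user\n\n{placeholder}\n{question}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

    prompts = [build_prompt(q) for q in ["What is in this image?", "Describe the scene."]]
    print(prompts[0])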


@@ -217,17 +217,17 @@ class VisualTokenizer(torch.nn.Module):
class OvisImagePatchInputs(TensorSchema):
"""
Dimensions:
- batch_patches: Batch size * number of patches
- patch_size: patch_size_x * patch_size_y * num_channels
- bnp: Batch size * number of images * number of patches
- h: Height of each patch
- w: Width of each patch
- patch_indicators: Batch size * (number of patches + 1)
- patches_per_image: List of number of total patches for each image
in the batch.
- bn: Batch size * number of images
"""
type: Literal["image_patches"]
flat_data: Annotated[torch.Tensor, TensorShape("batch_patches", "patch_size")]
flat_data: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
patches_per_image: Annotated[list[int], TensorShape("num_patches_per_image")]
patches_per_image: Annotated[list[int], TensorShape("bn")]
# This is used to restore the first two dimensions of `flat_data`.
@@ -366,7 +366,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
self.image_indicators_to_visual_tokens(indicator)
for indicator in image_indicators
]
processed_outputs["indicator_tokens"] = indicator_tokens
processed_outputs["indicator_tokens"] = torch.tensor(indicator_tokens)
return processed_outputs
def _apply_hf_processor_tokens_only(
@@ -414,6 +414,8 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
dummy_inputs=OvisDummyInputsBuilder,
)
class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
@@ -470,14 +472,11 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(pixel_values)}"
)
flat_data = flatten_bn(pixel_values, concat=True)
if flat_data.ndim >= 3:
flat_data = flat_data.flatten(start_dim=1)
return OvisImagePatchInputs(
type="image_patches",
flat_data=flat_data,
patches_per_image=[x.shape[0] for x in flatten_bn(pixel_values)],
indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True),
flat_data=flatten_bn(pixel_values, concat=True),
patches_per_image=[x.shape[0] for x in pixel_values],
indicator_tokens=flatten_bn(indicator_tokens, concat=True),
)
raise AssertionError("This line should be unreachable.")
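
The rewritten OvisImagePatchInputs above documents its shapes with TensorSchema/TensorShape annotations instead of a prose docstring: flat_data stacks every patch from every image in the batch as ("bnp", 3, "h", "w"), and patches_per_image records how to split it apart again. An illustrative sketch of that invariant with toy tensors (not vLLM's validation code):

    import torch

    # Two images with 4 and 2 patches of 3x14x14 (toy values for illustration).
    pixel_values = [torch.randn(4, 3, 14, 14), torch.randn(2, 3, 14, 14)]

    flat_data = torch.cat(pixel_values, dim=0)               # ("bnp", 3, "h", "w") -> [6, 3, 14, 14]
    patches_per_image = [x.shape[0] for x in pixel_values]   # ("bn",) -> [4, 2]

    # Downstream code can restore the per-image grouping from the counts.
    restored = torch.split(flat_data, patches_per_image, dim=0)
    assert all(torch.equal(a, b) for a, b in zip(restored, pixel_values))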


@@ -4,7 +4,7 @@
from collections.abc import Iterable, Mapping
from functools import partial
from typing import Literal, Optional, TypedDict, Union
from typing import Annotated, Literal, Optional, Union
import torch
import torch.nn as nn
@@ -14,7 +14,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.ovis import OvisImagePatchInputs, VisualEmbedding
from vllm.model_executor.models.ovis import VisualEmbedding
from vllm.model_executor.models.siglip2navit import Siglip2NavitModel
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
@@ -37,6 +37,7 @@ from vllm.multimodal.processing import (
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -58,36 +59,38 @@ IMAGE_PAD_TOKEN_ID_MAP = {
}
class OvisVideoPatchInputs(TypedDict):
class Ovis2_5ImagePatchInputs(TensorSchema):
"""
Dimensions:
- bnp: Batch size * number of images * number of patches
- patch_size: patch_size_x * patch_size_y * num_channels
- patch_indicators: Batch size * (number of patches + 1)
- bn: Batch size * number of images
"""
type: Literal["image_patches"]
flat_data: Annotated[torch.Tensor, TensorShape("bnp", "patch_size")]
indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
patches_per_item: Annotated[list[int], TensorShape("bn")]
grids: Annotated[torch.Tensor, TensorShape("bn", 3)]
# This is used to restore the first two dimensions of `flat_data`.
class Ovis2_5VideoPatchInputs(TensorSchema):
"""
Dimensions:
- bnp: Batch size * number of videos * number of patches
- patch_size: patch_size_x * patch_size_y * num_channels
- patch_indicators: Batch size * (number of patches + 1)
- bn: Batch size * number of videos
"""
type: Literal["video_patches"]
flat_data: torch.Tensor
"""
Shape:
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
"""
indicator_tokens: torch.Tensor
"""
Shape:
`(batch_size * (num_patches + 1))`
"""
patches_per_image: list[int]
"""
List of number of total patches for each frame in the video.
This is used to restore the first two dimensions of `flat_data`.
"""
def _ovis2_5_field_config():
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
grids=MultiModalFieldConfig.batched("image"),
indicator_tokens=MultiModalFieldConfig.batched("image"),
video_pixel_values=MultiModalFieldConfig.batched("video"),
video_indicator_tokens=MultiModalFieldConfig.batched("video"),
video_grids=MultiModalFieldConfig.batched("video"),
)
flat_data: Annotated[torch.Tensor, TensorShape("bnp", "patch_size")]
indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
patches_per_item: Annotated[list[int], TensorShape("bn")]
grids: Annotated[torch.Tensor, TensorShape("bn", 3)]
# This is used to restore the first two dimensions of `flat_data`.
class VisualTokenizer(torch.nn.Module):
@@ -380,7 +383,7 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo])
self.visual_indicators_to_visual_tokens(indicator)
for indicator in visual_indicators
]
processed_outputs["video_indicator_tokens"] = indicator_tokens
processed_outputs["video_indicator_tokens"] = torch.tensor(indicator_tokens)
if "images" in mm_data:
visual_indicators = [
hf_processor.construct_visual_indicators((1, 1, 1), False)
@@ -391,7 +394,7 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo])
for indicator in visual_indicators
]
processed_outputs["indicator_tokens"] = indicator_tokens
processed_outputs["indicator_tokens"] = torch.tensor(indicator_tokens)
return processed_outputs
def _apply_hf_processor_tokens_only(
@@ -405,7 +408,14 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo])
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _ovis2_5_field_config()
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
grids=MultiModalFieldConfig.batched("image"),
indicator_tokens=MultiModalFieldConfig.batched("image"),
video_pixel_values=MultiModalFieldConfig.batched("video"),
video_indicator_tokens=MultiModalFieldConfig.batched("video"),
video_grids=MultiModalFieldConfig.batched("video"),
)
def _get_prompt_updates(
self,
@@ -441,6 +451,8 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo])
dummy_inputs=Ovis2_5DummyInputsBuilder,
)
class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -470,7 +482,7 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Optional[OvisImagePatchInputs]:
) -> Optional[Ovis2_5ImagePatchInputs]:
pixel_values = kwargs.pop("pixel_values", None)
indicator_tokens = kwargs.pop("indicator_tokens", None)
grids = kwargs.pop("grids", None)
@@ -489,22 +501,22 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(indicator_tokens)}"
)
return OvisImagePatchInputs(
return Ovis2_5ImagePatchInputs(
type="image_patches",
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
patches_per_image=[
flat_data=flatten_bn(pixel_values, concat=True),
patches_per_item=[
x.shape[0] // (self.config.vit_config.hidden_stride**2)
for x in flatten_bn(pixel_values)
for x in pixel_values
],
indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True),
grids=flatten_bn(flatten_bn(grids), concat=True),
indicator_tokens=flatten_bn(indicator_tokens, concat=True),
grids=flatten_bn(grids, concat=True),
)
raise AssertionError("This line should be unreachable.")
def _parse_and_validate_video_input(
self, **kwargs: object
) -> Optional[OvisImagePatchInputs]:
) -> Optional[Ovis2_5VideoPatchInputs]:
pixel_values = kwargs.pop("video_pixel_values", None)
indicator_tokens = kwargs.pop("video_indicator_tokens", None)
grids = kwargs.pop("video_grids", None)
@@ -523,26 +535,26 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(indicator_tokens)}"
)
return OvisVideoPatchInputs(
return Ovis2_5VideoPatchInputs(
type="video_patches",
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
patches_per_image=[
flat_data=flatten_bn(pixel_values, concat=True),
patches_per_item=[
x.shape[0] // (self.config.vit_config.hidden_stride**2)
for x in flatten_bn(pixel_values)
for x in pixel_values
],
indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True),
grids=flatten_bn(flatten_bn(grids), concat=True),
indicator_tokens=flatten_bn(indicator_tokens, concat=True),
grids=flatten_bn(grids, concat=True),
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self, image_input: Union[OvisImagePatchInputs, OvisVideoPatchInputs]
def _process_visual_input(
self, visual_input: Union[Ovis2_5ImagePatchInputs, Ovis2_5VideoPatchInputs]
) -> MultiModalEmbeddings:
image_patches_flat = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"]
indicator_tokens = image_input["indicator_tokens"]
grid_thws = image_input["grids"]
image_patches_flat = visual_input["flat_data"]
patches_per_image = visual_input["patches_per_item"]
indicator_tokens = visual_input["indicator_tokens"]
grid_thws = visual_input["grids"]
indicator_per_image = list(
map(lambda x: 2 if x > 1 else x + 2, patches_per_image)
@@ -604,11 +616,11 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
vision_embeddings = self._process_visual_input(image_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_image_input(video_input)
video_embeddings = self._process_visual_input(video_input)
multimodal_embeddings += video_embeddings
return multimodal_embeddings
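
In the Ovis2.5 parsers above, patches_per_item divides each item's raw patch count by hidden_stride**2, presumably because the vision tower merges hidden_stride x hidden_stride spatial patches into a single visual token. A toy sketch of that bookkeeping (hidden_stride and patch counts below are made-up illustrative values):

    import torch

    hidden_stride = 2                                             # illustrative value
    pixel_values = [torch.randn(16, 588), torch.randn(64, 588)]   # raw patches per item

    flat_data = torch.cat(pixel_values, dim=0)
    patches_per_item = [x.shape[0] // hidden_stride**2 for x in pixel_values]
    print(flat_data.shape, patches_per_item)                      # torch.Size([80, 588]) [4, 16]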


@@ -408,7 +408,7 @@ class OvisProcessor(ProcessorMixin):
crops.insert(0, image)
pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0)
image_placeholders = self.construct_image_placeholders(grid)
return pixel_values, image_placeholders, grid
return torch.tensor(pixel_values), image_placeholders, torch.tensor(grid)
def batch_decode(self, *args, **kwargs):
"""