diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py
index 38f052aba3187..28a4a1e8d2596 100644
--- a/vllm/model_executor/models/interns1.py
+++ b/vllm/model_executor/models/interns1.py
@@ -631,8 +631,11 @@ class InternS1ForConditionalGeneration(
             )
 
         image_token_id = kwargs["image_token_id"]
-        assert isinstance(image_token_id, torch.Tensor)
-        self.img_context_token_id = image_token_id.flatten().unique().item()
+        if isinstance(image_token_id, torch.Tensor):
+            image_token_id = image_token_id.flatten().unique().item()
+
+        assert isinstance(image_token_id, int)
+        self.img_context_token_id = image_token_id
 
         if pixel_values is not None:
             h, w = self.config.vision_config.image_size
@@ -665,8 +668,11 @@ class InternS1ForConditionalGeneration(
             )
 
         video_token_id = kwargs["video_token_id"]
-        assert isinstance(video_token_id, torch.Tensor)
-        self.video_context_token_id = video_token_id.flatten().unique().item()
+        if isinstance(video_token_id, torch.Tensor):
+            video_token_id = video_token_id.flatten().unique().item()
+
+        assert isinstance(video_token_id, int)
+        self.video_context_token_id = video_token_id
 
         if pixel_values_flat_video is not None:
             h, w = self.config.vision_config.image_size
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 47429ef1b76e0..28a35595f43aa 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -1232,8 +1232,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
             )
 
         image_token_id = kwargs["image_token_id"]
-        assert isinstance(image_token_id, torch.Tensor)
-        self.img_context_token_id = image_token_id.flatten().unique().item()
+        if isinstance(image_token_id, torch.Tensor):
+            image_token_id = image_token_id.flatten().unique().item()
+
+        assert isinstance(image_token_id, int)
+        self.img_context_token_id = image_token_id
 
         if pixel_values_flat is not None:
             expected_h = expected_w = self.config.vision_config.image_size
@@ -1265,8 +1268,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
             )
 
         video_token_id = kwargs["video_token_id"]
-        assert isinstance(video_token_id, torch.Tensor)
-        self.video_context_token_id = video_token_id.flatten().unique().item()
+        if isinstance(video_token_id, torch.Tensor):
+            video_token_id = video_token_id.flatten().unique().item()
+
+        assert isinstance(video_token_id, int)
+        self.video_context_token_id = video_token_id
 
         if pixel_values_flat_video is not None:
             expected_h = expected_w = self.config.vision_config.image_size
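The same tensor-or-int normalization recurs below in nemotron_vl.py and skyworkr1v.py. For reference, a minimal standalone sketch of the pattern these hunks converge on; the helper name is illustrative and not part of the patch:

import torch

def normalize_token_id(token_id) -> int:
    # The placeholder token id may arrive as a tensor (possibly repeated per
    # prompt item) or already unwrapped as a plain int, depending on batching.
    if isinstance(token_id, torch.Tensor):
        token_id = token_id.flatten().unique().item()
    assert isinstance(token_id, int)
    return token_id

# normalize_token_id(torch.tensor([[92546], [92546]])) == normalize_token_id(92546) == 92546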
diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py
index 2a798672d13c6..322cce79d4cb2 100644
--- a/vllm/model_executor/models/midashenglm.py
+++ b/vllm/model_executor/models/midashenglm.py
@@ -26,7 +26,7 @@ import collections
 import collections.abc
 from collections.abc import Callable, Iterable, Mapping, Sequence
-from typing import Any, TypeAlias, TypedDict, cast
+from typing import Annotated, Any, TypeAlias, cast
 
 import numpy as np
 import torch
@@ -62,6 +62,7 @@ from vllm.multimodal.processing import (
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.midashenglm import DashengConfig
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
@@ -508,11 +509,16 @@ class AudioProjectorSubsample(nn.Module):
 
 # === Audio Inputs === #
 
-class MiDashengLMAudioInputs(TypedDict):
-    input_values: torch.Tensor
-    """Shape: `(num_audios, num_sampling_points)`"""
-    audio_length: torch.Tensor
-    """Shape: `(num_audios, 1)`"""
+class MiDashengLMAudioInputs(TensorSchema):
+    """
+
+    Dimensions:
+        - bn: Batch size * number of audios
+        - p: Number of sampling points
+    """
+
+    input_values: Annotated[torch.Tensor, TensorShape("bn", "p")]
+    audio_length: Annotated[torch.Tensor, TensorShape("bn")]
 
 
 class MiDashengLMProcessingInfo(BaseProcessingInfo):
@@ -676,6 +682,8 @@ class MiDashengLMMultiModalProcessor(
     dummy_inputs=MiDashengLMDummyInputsBuilder,
 )
 class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -728,26 +736,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
             self.decoder.make_empty_intermediate_tensors
         )
 
-    def _validate_and_reshape_mm_tensor(
-        self, mm_input: object, name: str
-    ) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            return mm_input.reshape(-1, *mm_input.shape[2:])
-
-        if name == "input_values":
-            max_length = max(tensor.shape[1] for tensor in mm_input)
-            padded_mm_input = [
-                torch.nn.functional.pad(tensor, (0, max_length - tensor.shape[1]))
-                if tensor.shape[1] < max_length
-                else tensor
-                for tensor in mm_input
-            ]
-            return torch.concat(padded_mm_input)
-
-        return torch.concat(mm_input)
-
     def _parse_and_validate_audio_input(
         self, **kwargs: object
     ) -> MiDashengLMAudioInputs | None:
@@ -756,16 +744,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
         if input_values is None:
             return None
 
-        input_values = self._validate_and_reshape_mm_tensor(
-            input_values, "input_values"
-        )
-        audio_length = self._validate_and_reshape_mm_tensor(
-            audio_length, "audio_length"
-        )
-        if not isinstance(input_values, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of audio input features. "
-                f"Got type: {type(input_values)}"
+
+        if isinstance(input_values, list):
+            input_values = torch.nn.utils.rnn.pad_sequence(
+                input_values,
+                batch_first=True,
             )
 
         return MiDashengLMAudioInputs(
@@ -773,7 +756,10 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
             audio_length=audio_length,
         )
 
-    def _process_audio_input(self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor:
+    def _process_audio_input(
+        self,
+        audio_input: MiDashengLMAudioInputs,
+    ) -> tuple[torch.Tensor, ...]:
         # Process audio through encoder and projector
         input_values = audio_input["input_values"]
         audio_length = audio_input["audio_length"]
@@ -783,17 +769,13 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
         audio_embeddings = audio_embeddings.to(audio_input["input_values"].dtype)
         batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape
 
-        audio_length_np = (
-            audio_length.cpu().numpy()
-            if isinstance(audio_length, torch.Tensor)
-            else audio_length
-        )
         audio_output_lengths = [
             max(1, calculate_mel_frames_dasheng(int(length)))  # at least one frame
-            for length in audio_length_np
+            for length in audio_length.tolist()
         ]
-        audio_output_lengths = torch.tensor(audio_output_lengths).to(
-            audio_embeddings.device
+        audio_output_lengths = torch.tensor(
+            audio_output_lengths,
+            device=audio_embeddings.device,
         )
 
         audio_feature_mask = torch.arange(
@@ -826,14 +808,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> torch.Tensor | IntermediateTensors:
         if intermediate_tensors is not None:
             inputs_embeds = None
-        elif inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                multimodal_embeddings,
-                is_multimodal=input_ids == self.config.audio_token_id,
-            )
-            input_ids = None
 
         return self.decoder.model(
             input_ids,
diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py
index 827c019008ab8..371c9607c5c5b 100644
--- a/vllm/model_executor/models/minicpmo.py
+++ b/vllm/model_executor/models/minicpmo.py
@@ -71,7 +71,7 @@ from .minicpmv import (
     MiniCPMVProcessingInfo,
     _minicpmv_field_config,
 )
-from .utils import AutoWeightsLoader, cast_overflow_tensors, flatten_bn, maybe_prefix
+from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
 
 CPU_DEVICE = torch.device("cpu")
 
@@ -132,15 +132,11 @@ MiniCPMOAudioInputs: TypeAlias = (
 
 
 def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    audio_features = hf_inputs.get("audio_features", torch.empty(0))
-    num_audios = len(audio_features)
-
     return dict(
         **_minicpmv_field_config(hf_inputs),
         audio_features=MultiModalFieldConfig.batched("audio"),
        audio_feature_lens=MultiModalFieldConfig.batched("audio"),
        audio_embeds=MultiModalFieldConfig.batched("audio"),
-        audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
     )
 
 
@@ -332,10 +328,6 @@ class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessing
         ]
         audio_inputs["audio_features"] = unpadded_audio_features
 
-        tokenizer = self.info.get_tokenizer()
-        unk_token_id = tokenizer.get_vocab()["<unk>"]
-        audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
-
         return audio_inputs
 
     def process_mm_inputs(
@@ -436,12 +428,10 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
         attention_mask: torch.Tensor,
     ) -> torch.Tensor:
         residual = hidden_states
-        past_key_values = None
         hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states, attn_weights, past_key_values = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-            past_key_value=past_key_values,
         )
         hidden_states = nn.functional.dropout(
             hidden_states, p=self.dropout, training=self.training
@@ -567,8 +557,6 @@ class MiniCPMO(MiniCPMV2_6):
             vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
         )
 
-        self.audio_token_id = None
-
     def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # Do not use parameters temporarily
         audio_config = self.config.audio_config
@@ -731,43 +719,18 @@ class MiniCPMO(MiniCPMV2_6):
         if audio_features is None and audio_embeds is None:
             return None
 
-        audio_token_id = kwargs.pop("audio_token_id")
-        if audio_token_id is not None:
-            assert isinstance(audio_token_id, torch.Tensor)
-            self.mm_token_ids.add(audio_token_id.flatten().unique().item())
-
         if audio_embeds is not None:
-            if not isinstance(audio_embeds, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of audio_embeds. Got type: {type(audio_embeds)}"
-                )
-
-            audio_embeds_flat = flatten_bn(audio_embeds)
-
             return MiniCPMOAudioEmbeddingInputs(
                 type="audio_embeds",
-                audio_embeds=audio_embeds_flat,
-            )
-
-        if not isinstance(audio_features, (torch.Tensor, list)):
-            raise ValueError(
-                f"Incorrect type of audio_features. Got type: {type(audio_features)}"
+                audio_embeds=audio_embeds,
             )
 
         audio_feature_lens = kwargs.pop("audio_feature_lens")
-        if not isinstance(audio_feature_lens, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of audio_feature_lens. "
-                f"Got type: {type(audio_feature_lens)}"
-            )
-
-        audio_features_flat = flatten_bn(audio_features)
-        audio_feature_lens_flat = flatten_bn(audio_feature_lens)
 
         return MiniCPMOAudioFeatureInputs(
             type="audio_features",
-            audio_features=audio_features_flat,
-            audio_feature_lens=audio_feature_lens_flat,
+            audio_features=audio_features,
+            audio_feature_lens=audio_feature_lens,
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
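The minicpmv.py diff below (like midashenglm.py above and nano_nemotron_vl.py further down) moves the per-modality input definitions onto TensorSchema. A schematic declaration in that style; the class and field here are invented for illustration, and only the TensorSchema/TensorShape usage mirrors the patch:

from typing import Annotated, Literal

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExampleImageInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - h: Height of each image
        - w: Width of each image
    """

    type: Literal["pixel_values"] = "pixel_values"
    # Literal dimensions (here the channel count 3) can be mixed with named ones.
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]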
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 53a25cf988481..173cab3bffc10 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -114,7 +114,7 @@ class MiniCPMVImagePixelInputs(TensorSchema):
 
     type: Literal["pixel_values"] = "pixel_values"
 
-    # Note that the image size may vary, so we pass it as a list instead of a
+    # Note that the patch size may vary, so we pass it as a list instead of a
     # batched tensor.
     pixel_values: Annotated[
         list[torch.Tensor],
@@ -453,12 +453,6 @@ def get_version_by_config(config: PretrainedConfig) -> tuple[int, ...]:
 
 
 def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    pixel_values = hf_inputs.get("pixel_values", torch.empty(0))
-    num_images = len(pixel_values)
-
-    video_pixel_values = hf_inputs.get("video_pixel_values", torch.empty(0))
-    num_videos = len(video_pixel_values)
-
     return dict(
         pixel_values=MultiModalFieldConfig.batched("image"),
         image_sizes=MultiModalFieldConfig.batched("image"),
@@ -468,8 +462,6 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
         video_image_sizes=MultiModalFieldConfig.batched("video"),
         video_tgt_sizes=MultiModalFieldConfig.batched("video"),
         video_embeds=MultiModalFieldConfig.batched("video"),
-        image_token_id=MultiModalFieldConfig.shared("image", num_images),
-        video_token_id=MultiModalFieldConfig.shared("video", num_videos),
     )
 
 
@@ -792,10 +784,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
             out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
         )
 
-        tokenizer = self.info.get_tokenizer()
-        unk_token_id = tokenizer.get_vocab()["<unk>"]
-        image_inputs["image_token_id"] = torch.tensor(unk_token_id)
-
         return image_inputs
 
     def process_videos(
@@ -831,10 +819,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
 
         video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}
 
-        tokenizer = self.info.get_tokenizer()
-        unk_token_id = tokenizer.get_vocab()["<unk>"]
-        video_inputs["video_token_id"] = torch.tensor(unk_token_id)
-
         return video_inputs
 
     def process_mm_inputs(
@@ -1021,6 +1005,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     instantiated.
     """
 
+    merge_by_field_config = True
+
     supports_encoder_tp_data = True
 
     @classmethod
@@ -1066,7 +1052,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
             prefix=maybe_prefix(prefix, "resampler"),
         )
 
-        self.mm_token_ids = set[int]()
         self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors
 
     def _parse_and_validate_vision_input(
@@ -1080,43 +1065,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         if pixel_values is None and image_embeds is None:
             return None
 
-        image_token_id = kwargs.pop("image_token_id")
-        if image_token_id is not None:
-            assert isinstance(image_token_id, torch.Tensor)
-            self.mm_token_ids.add(image_token_id.flatten().unique().item())
-
         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of image_embeds for {modality=}. "
-                    f"Got type: {type(image_embeds)}"
-                )
-
-            image_embeds_flat = flatten_bn(image_embeds)
-
             return MiniCPMVImageEmbeddingInputs(
                 type="image_embeds",
-                image_embeds=image_embeds_flat,
-            )
-
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError(
-                f"Incorrect type of pixel_values for {modality=}. "
-                f"Got type: {type(pixel_values)}"
+                image_embeds=image_embeds,
             )
 
         tgt_sizes = kwargs.pop("tgt_sizes")
-        if not isinstance(tgt_sizes, (torch.Tensor, list)):
-            raise ValueError(
-                f"Incorrect type of tgt_sizes for {modality=}. "
-                f"Got type: {type(tgt_sizes)}"
-            )
-        num_slices = [[len(p) for p in ps] for ps in pixel_values]
-        num_slices_flat = flatten_bn(torch.tensor(num_slices))
-
-        pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
-        tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)
+        num_slices_flat = torch.tensor([len(ps) for ps in pixel_values])
+        pixel_values_flat = flatten_bn(pixel_values)
+        tgt_sizes_flat = flatten_bn(tgt_sizes, concat=True)
 
         return MiniCPMVImagePixelInputs(
             type="pixel_values",
@@ -1142,15 +1101,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
                 input_key in ("video_pixel_values", "video_embeds")
                 and "videos" not in modalities
             ):
-
-                def _image_key(video_key: str):
-                    if video_key == "video_token_id":
-                        return "image_token_id"
-
-                    return video_key.removeprefix("video_")
-
                 modalities["videos"] = self._parse_and_validate_vision_input(
-                    "videos", **{_image_key(k): v for k, v in kwargs.items()}
+                    "videos", **{k.removeprefix("video_"): v for k, v in kwargs.items()}
                 )
 
         return modalities
""" aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)] @@ -725,6 +725,8 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]): class Llama4ForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3 ): + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -798,17 +800,12 @@ class Llama4ForConditionalGeneration( if pixel_values is None: return None - # num_images x num_chunks, channel, image_size, image_size - # TODO: confirm handling for variable lengths - flat_pixel_values = flatten_bn(pixel_values, concat=True) - patches_per_image = flatten_bn(kwargs.pop("patches_per_image")) + patches_per_image = kwargs.pop("patches_per_image") aspect_ratios = kwargs.pop("aspect_ratios") - if aspect_ratios.ndim == 3: - aspect_ratios = aspect_ratios.squeeze(1) return Llama4ImagePatchInputs( type="pixel_values", - flat_data=flat_pixel_values, + pixel_values=pixel_values, patches_per_image=patches_per_image, aspect_ratios=aspect_ratios, ) @@ -817,16 +814,16 @@ class Llama4ForConditionalGeneration( self, image_input: Llama4ImagePatchInputs ) -> MultiModalEmbeddings: assert self.vision_model and self.multi_modal_projector - flat_data = image_input["flat_data"] + pixel_values = image_input["pixel_values"] patches_per_image = image_input["patches_per_image"].tolist() # shard image input if self.use_data_parallel: vision_embeddings_flat = run_dp_sharded_vision_model( - flat_data, self.vision_model + pixel_values, self.vision_model ) else: - vision_embeddings_flat = self.vision_model(flat_data) + vision_embeddings_flat = self.vision_model(pixel_values) vision_embeddings_flat = self.multi_modal_projector(vision_embeddings_flat) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 83e0f282ddf8d..106aaf413e99b 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -75,7 +75,6 @@ from .interfaces import ( from .utils import ( AutoWeightsLoader, WeightsMapper, - flatten_bn, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -97,28 +96,19 @@ class MolmoImageInputs(TensorSchema): """ Dimensions: - bn: Batch size * number of images - - nc: Number of crops (dynamic) + - bnc: Batch size * number of images * number of crops (dynamic) - np: Number of patches - tp: Token sequence positions - pd: Patch dimension """ - images: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"}), - ] - # Number of crops may vary per batch and image, so pass it as a list. + images: Annotated[torch.Tensor, TensorShape("bnc", "np", "pd")] - image_masks: Annotated[ - torch.Tensor | list[torch.Tensor] | None, - TensorShape("bn", "nc", "np", dynamic_dims={"nc"}), - ] + image_masks: Annotated[torch.Tensor | None, TensorShape("bnc", "np")] + + image_input_idx: Annotated[torch.Tensor, TensorShape("bnc", "tp")] + """An index tensor that maps image features to their corresponding patch tokens.""" - image_input_idx: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}), - ] - # An index tensor that maps image features to their corresponding patch tokens. 
num_crops: Annotated[torch.Tensor, TensorShape("bn")] @@ -1363,6 +1353,8 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): class MolmoForCausalLM( nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant ): + merge_by_field_config = True + hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ # vision backbone mapping @@ -1451,18 +1443,12 @@ class MolmoForCausalLM( if images is None: return None - if not isinstance(num_crops, (torch.Tensor, list)): - raise ValueError( - f"Incorrect type of num_crops. Got type: {type(num_crops)}" - ) - num_crops = flatten_bn(num_crops, concat=True) - img_patch_id = kwargs.pop("img_patch_id", None) - if not isinstance(img_patch_id, torch.Tensor): - raise ValueError( - f"Incorrect type of img_patch_id. Got type: {type(img_patch_id)}" - ) - self.img_patch_id = img_patch_id.flatten().unique().item() + if isinstance(img_patch_id, torch.Tensor): + img_patch_id = img_patch_id.item() + + assert isinstance(img_patch_id, int) + self.img_patch_id = img_patch_id return MolmoImageInputs( images=images, @@ -1481,17 +1467,9 @@ class MolmoForCausalLM( num_crops = image_input["num_crops"] # Call the vision backbone on the whole batch at once - images_flat = flatten_bn(images, concat=True) - image_masks_flat = ( - None if image_masks is None else flatten_bn(image_masks, concat=True) - ) - image_input_idx_flat = flatten_bn(image_input_idx, concat=True) - - image_features_flat = self.vision_backbone( - images=images_flat.unsqueeze(0), - image_masks=( - None if image_masks_flat is None else image_masks_flat.unsqueeze(0) - ), + image_features = self.vision_backbone( + images=images.unsqueeze(0), + image_masks=None if image_masks is None else image_masks.unsqueeze(0), ).squeeze(0) # Only the features corresponding to patch tokens are relevant @@ -1499,8 +1477,8 @@ class MolmoForCausalLM( results = [] num_crops_list = num_crops.tolist() for feats, img_idx in zip( - image_features_flat.split(num_crops_list), - image_input_idx_flat.split(num_crops_list), + image_features.split(num_crops_list), + image_input_idx.split(num_crops_list), ): is_valid = img_idx >= 0 valid_img_idx = img_idx[is_valid] diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 56d3c4bb7d107..dfb7cb7fe6bd4 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -11,7 +11,7 @@ import copy import warnings from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal, TypeAlias, TypedDict, TypeVar +from typing import Annotated, Any, Literal, TypeAlias, TypeVar import numpy.typing as npt import torch @@ -40,7 +40,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM from vllm.model_executor.models.radio import RadioModel from vllm.model_executor.models.utils import ( - flatten_bn, init_vllm_registered_model, maybe_prefix, ) @@ -96,31 +95,35 @@ MAX_FRAMES = 16 DEFAULT_NUM_TILES = 12 -class NanoNemotronVLImagePixelInputs(TypedDict): +class NanoNemotronVLImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - bnp: Batch size * number of images * (1 + num_patches) + - c: Number of channels (3) + - h: Height of each image patch + - w: Width of each image patch + """ + type: Literal["pixel_values"] - pixel_values_flat: torch.Tensor + pixel_values_flat: 
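The Molmo change above feeds the already-flattened (total_num_crops, ...) batch straight into the vision backbone and then splits the features back per image. An illustrative sketch with made-up shapes:

import torch

num_crops = torch.tensor([2, 3])             # crops contributed by each image
feats = torch.randn(5, 4, 8)                 # (total_crops, patches, dim)
per_image = feats.split(num_crops.tolist())  # shapes (2, 4, 8) and (3, 4, 8)
assert [t.shape[0] for t in per_image] == [2, 3]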
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 56d3c4bb7d107..dfb7cb7fe6bd4 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -11,7 +11,7 @@ import copy
 import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal, TypeAlias, TypedDict, TypeVar
+from typing import Annotated, Any, Literal, TypeAlias, TypeVar
 
 import numpy.typing as npt
 import torch
@@ -40,7 +40,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
 from vllm.model_executor.models.radio import RadioModel
 from vllm.model_executor.models.utils import (
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
@@ -96,31 +95,35 @@ MAX_FRAMES = 16
 DEFAULT_NUM_TILES = 12
 
 
-class NanoNemotronVLImagePixelInputs(TypedDict):
+class NanoNemotronVLImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - bnp: Batch size * number of images * (1 + num_patches)
+        - c: Number of channels (3)
+        - h: Height of each image patch
+        - w: Width of each image patch
+    """
+
     type: Literal["pixel_values"]
-    pixel_values_flat: torch.Tensor
+    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
+
+
+class NanoNemotronVLImageEmbeddingInputs(TensorSchema):
     """
-    Shape:
-    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
+    Dimensions:
+        - n: Number of images
+        - f: Total image feature size
+        - h: Hidden size (must match the hidden size of language model backbone)
     """
-
-    num_patches: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
-
-
-class NanoNemotronVLImageEmbeddinInputs(TypedDict):
+
     type: Literal["image_embeds"]
-    data: torch.Tensor | list[torch.Tensor]
-    """
-    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
-    or a list of tensors of shape `(total_image_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-    """
+    data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")]
 
 
 NanoNemotronVLImageInputs: TypeAlias = (
-    NanoNemotronVLImagePixelInputs | NanoNemotronVLImageEmbeddinInputs
+    NanoNemotronVLImagePixelInputs | NanoNemotronVLImageEmbeddingInputs
 )
@@ -710,37 +713,12 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
 class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
     """Basic image-only MultiModalProcessor for InternVL-style models."""
 
-    def _call_hf_processor(
-        self,
-        prompt: str,
-        mm_data: Mapping[str, object],
-        mm_kwargs: Mapping[str, object],
-        tok_kwargs: Mapping[str, object],
-    ) -> BatchFeature:
-        processed_outputs = super()._call_hf_processor(
-            prompt=prompt,
-            mm_data=mm_data,
-            mm_kwargs=mm_kwargs,
-            tok_kwargs=tok_kwargs,
-        )
-
-        hf_processor = self.info.get_hf_processor(**mm_kwargs)
-        image_token_id = hf_processor.image_token_id
-
-        # Since there may be extra tokens in the feature placeholders,
-        # we need to pass the image token ID to the model to select the
-        # tokens to merge from the vision encoder outputs
-        processed_outputs["image_token_id"] = torch.tensor(image_token_id)
-
-        return processed_outputs
-
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
-        num_images = len(image_num_patches)
 
         return dict(
             pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
@@ -748,7 +726,6 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
             ),
             image_num_patches=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
-            image_token_id=MultiModalFieldConfig.shared("image", num_images),
         )
 
     def _get_prompt_updates(
@@ -814,25 +791,6 @@ class NanoNemotronVLMultiModalProcessor(
 ):
     """MultiModalProcessor extended for video support"""
 
-    def _call_hf_processor(
-        self,
-        prompt: str,
-        mm_data: Mapping[str, object],
-        mm_kwargs: Mapping[str, object],
-        tok_kwargs: Mapping[str, object],
-    ) -> BatchFeature:
-        processed_outputs = super()._call_hf_processor(
-            prompt, mm_data, mm_kwargs, tok_kwargs
-        )
-
-        hf_processor = self.info.get_hf_processor(**mm_kwargs)
-        if (
-            self.info.supports_video
-            and (video_token_id := hf_processor.video_token_id) is not None
-        ):
-            processed_outputs["video_token_id"] = torch.tensor(video_token_id)
-        return processed_outputs
-
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -841,13 +799,12 @@ class NanoNemotronVLMultiModalProcessor(
         image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)
         if self.info.supports_video:
             video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
-            num_videos = len(video_num_patches)
+
             video_fields = dict(
                 pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
                     "video", video_num_patches
                ),
                video_num_patches=MultiModalFieldConfig.batched("video"),
-                video_token_id=MultiModalFieldConfig.shared("video", num_videos),
             )
         else:
             video_fields = {}
@@ -999,6 +956,8 @@ class NanoNemotronVLDummyInputsBuilder(
 class NemotronH_Nano_VL_V2(
     nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning
 ):
+    merge_by_field_config = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         if modality.startswith("image"):
@@ -1051,8 +1010,6 @@ class NemotronH_Nano_VL_V2(
         )
         self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype)
 
-        self.img_context_token_id = None
-        self.video_context_token_id = None
         self.config = config
         self.model_config = vllm_config.model_config
@@ -1106,37 +1063,12 @@ class NemotronH_Nano_VL_V2(
             return None
 
         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image embeddings. "
-                    f"Got type: {type(image_embeds)}"
-                )
-
-            return NanoNemotronVLImageEmbeddinInputs(
+            return NanoNemotronVLImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )
 
-        image_token_id = kwargs["image_token_id"]
-        assert isinstance(image_token_id, torch.Tensor)
-        self.img_context_token_id = image_token_id.flatten().unique().item()
-
         if pixel_values_flat is not None:
-            if not isinstance(pixel_values_flat, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of pixel values. "
-                    f"Got type: {type(pixel_values_flat)}"
-                )
-
-            if not isinstance(image_num_patches, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image_num_patches. "
-                    f"Got type: {type(image_num_patches)}"
-                )
-
-            pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
-            image_num_patches = flatten_bn(image_num_patches, concat=True)
-
             return NanoNemotronVLImagePixelInputs(
                 type="pixel_values",
                 pixel_values_flat=pixel_values_flat,
@@ -1285,28 +1217,10 @@ class NemotronH_Nano_VL_V2(
         if video_embeds is not None:
             return NanoNemotronVLVideoEmbeddingInputs(
                 type="video_embeds",
-                data=flatten_bn(video_embeds),
+                data=video_embeds,
             )
 
-        video_token_id = kwargs["video_token_id"]
-        assert isinstance(video_token_id, torch.Tensor)
-        self.video_context_token_id = video_token_id.flatten().unique().item()
-
         if pixel_values_flat_video is not None:
-            if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of pixel values. "
-                    f"Got type: {type(pixel_values_flat_video)}"
-                )
-
-            if not isinstance(video_num_patches, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image_num_patches. "
-                    f"Got type: {type(video_num_patches)}"
-                )
-
-            pixel_values_flat_video = flatten_bn(pixel_values_flat_video, concat=True)
-            video_num_patches = flatten_bn(video_num_patches, concat=True)
             expected_h = expected_w = self.config.force_image_size
             resolve_bindings = {"h": expected_h, "w": expected_w}
diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py
index 9e1323f41ee08..42f70ef105a5d 100644
--- a/vllm/model_executor/models/nemotron_vl.py
+++ b/vllm/model_executor/models/nemotron_vl.py
@@ -496,8 +496,11 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
             )
 
         image_token_id = kwargs["image_token_id"]
-        assert isinstance(image_token_id, torch.Tensor)
-        self.img_context_token_id = image_token_id.flatten().unique().item()
+        if isinstance(image_token_id, torch.Tensor):
+            image_token_id = image_token_id.flatten().unique().item()
+
+        assert isinstance(image_token_id, int)
+        self.img_context_token_id = image_token_id
 
         if pixel_values_flat is not None:
             return InternVLImagePixelInputs(
diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py
index 709c17a8e8638..44550ae595d13 100644
--- a/vllm/model_executor/models/skyworkr1v.py
+++ b/vllm/model_executor/models/skyworkr1v.py
@@ -814,8 +814,11 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
             )
 
         image_token_id = kwargs["image_token_id"]
-        assert isinstance(image_token_id, torch.Tensor)
-        self.img_context_token_id = image_token_id.flatten().unique().item()
+        if isinstance(image_token_id, torch.Tensor):
+            image_token_id = image_token_id.flatten().unique().item()
+
+        assert isinstance(image_token_id, int)
+        self.img_context_token_id = image_token_id
 
         if pixel_values_flat is not None:
             return SkyworkR1VImagePixelInputs(
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index ecc1862c42f8c..5b228e6b3aeb3 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -432,7 +432,7 @@ def group_mm_kwargs_by_modality(
 
         if device is not None:
             mm_kwargs_group = json_map_leaves(
-                lambda x: x.to(device=device),
+                lambda x: x.to(device=device) if isinstance(x, torch.Tensor) else x,
                 mm_kwargs_group,
            )
         else: