diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 5ca2156c08b5..c705a70b93f5 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -677,7 +677,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | | `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ | | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ | -| `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + IE+ + VE+ | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | | `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | ✅︎ | diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index f8ddb5a22b31..1d6d819ff58a 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -576,7 +576,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData: # Intern-S1 def run_interns1(questions: list[str], modality: str) -> ModelRequestData: - model_name = "internlm/Intern-S1" + model_name = "internlm/Intern-S1-mini" engine_args = EngineArgs( model=model_name, diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 51b41f34b2ff..e0d95758a822 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -309,7 +309,7 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData: def load_interns1(question: str, image_urls: list[str]) -> ModelRequestData: - model_name = "internlm/Intern-S1" + model_name = "internlm/Intern-S1-mini" engine_args = EngineArgs( model=model_name, diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py index 0292845f819c..e5caf0eae37d 100644 --- a/vllm/model_executor/models/interns1.py +++ b/vllm/model_executor/models/interns1.py @@ -25,7 +25,7 @@ from vllm.model_executor.models.interns1_vit import InternS1VisionModel from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -39,7 +39,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, +from .utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, maybe_prefix) @@ -304,7 +304,7 @@ class InternS1MultiModalProcessor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: mm_data = dict(mm_data) videos = mm_data.pop("videos", []) images = mm_data.pop("images", []) @@ -342,7 +342,7 @@ class InternS1MultiModalProcessor( image_placeholder, 1) num_patches = [len(item) for item in image_pixel_values] - image_outputs: dict[str, NestedTensors] = { + image_outputs = { "pixel_values": torch.concat(image_pixel_values), "image_num_patches": torch.tensor(num_patches), "image_token_id": torch.tensor(hf_processor.image_token_id), @@ -370,7 +370,7 @@ class InternS1MultiModalProcessor( video_placeholder, 1) num_frames = [len(item) for item in video_pixel_values] - video_outputs: dict[str, NestedTensors] = { + video_outputs = { "pixel_values_videos": torch.concat(video_pixel_values), "video_num_patches": torch.tensor(num_frames), "video_token_id": torch.tensor(video_token_id), @@ -382,16 +382,11 @@ class InternS1MultiModalProcessor( prompt) text_outputs = tokenizer(prompt, **tok_kwargs, return_tensors="pt") - combined_outputs = dict( - **text_outputs, - **image_outputs, - **video_outputs, - ) - return BatchFeature(combined_outputs) + return BatchFeature({**text_outputs, **image_outputs, **video_outputs}) def _get_mm_fields_config( self, - hf_inputs: Mapping[str, NestedTensors], + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: @@ -487,6 +482,7 @@ class InternS1MultiModalProcessor( dummy_inputs=InternS1DummyInputsBuilder) class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + merge_by_field_config = True # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( @@ -561,7 +557,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, prefix=prefix, ) - def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: return InternS1MultiModalProjector(config) def pixel_shuffle(self, x, scale_factor=0.5): @@ -599,13 +595,9 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, return None if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return InternS1ImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) image_token_id = kwargs["image_token_id"] @@ -613,17 +605,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, self.img_context_token_id = image_token_id.flatten().unique().item() if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - - if not isinstance(image_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(image_num_patches)}") - - pixel_values = flatten_bn(pixel_values, concat=True) - image_num_patches = flatten_bn(image_num_patches, concat=True) - h, w = self.config.vision_config.image_size return InternS1ImagePixelInputs( type="pixel_values", @@ -638,7 +619,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, raise AssertionError("This line should be unreachable.") def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[InternS1VideoPixelInputs]: + self, **kwargs: object) -> Optional[InternS1VideoInputs]: pixel_values_flat_video = kwargs.pop("pixel_values_videos", None) video_num_patches = kwargs.pop("video_num_patches", None) video_embeds = kwargs.pop("video_embeds", None) @@ -647,13 +628,9 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, return None if video_embeds is not None: - if not isinstance(video_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of video embeddings. " - f"Got type: {type(video_embeds)}") - - return InternS1ImageEmbeddingInputs( + return InternS1VideoEmbeddingInputs( type="video_embeds", - data=flatten_bn(video_embeds), + data=video_embeds, ) video_token_id = kwargs["video_token_id"] @@ -661,18 +638,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, self.video_context_token_id = video_token_id.flatten().unique().item() if pixel_values_flat_video is not None: - if not isinstance(pixel_values_flat_video, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat_video)}") - - if not isinstance(video_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(video_num_patches)}") - - pixel_values_flat_video = flatten_bn(pixel_values_flat_video, - concat=True) - video_num_patches = flatten_bn(video_num_patches, concat=True) - h, w = self.config.vision_config.image_size return InternS1VideoPixelInputs( type="pixel_values_videos", @@ -686,11 +651,12 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, raise AssertionError("This line should be unreachable.") - def _process_image_input( + def _process_vision_input( self, - image_input: Union[InternS1ImageInputs, InternS1VideoPixelInputs], + image_input: Union[InternS1ImageInputs, InternS1VideoInputs], ) -> tuple[torch.Tensor, ...]: - if image_input["type"] == "image_embeds": + if (image_input["type"] == "image_embeds" + or image_input["type"] == "video_embeds"): return image_input["data"] assert self.vision_tower is not None @@ -753,11 +719,11 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, for modality in modalities: if modality == "images": image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) + vision_embeddings = self._process_vision_input(image_input) multimodal_embeddings += vision_embeddings if modality == "videos": video_input = modalities["videos"] - video_embeddings = self._process_image_input(video_input) + video_embeddings = self._process_vision_input(video_input) multimodal_embeddings += video_embeddings return multimodal_embeddings diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 0c95c49f90b1..1f3224f9ac58 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -17,7 +17,7 @@ import torch import torch.nn as nn import torchvision.transforms as T from PIL import Image -from transformers import BatchEncoding, PretrainedConfig, TensorType +from transformers import BatchFeature, PretrainedConfig, TensorType from vllm.config import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig @@ -28,7 +28,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -42,8 +42,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix IMG_START = '' IMG_END = '' @@ -471,7 +470,7 @@ class BaseInternVLProcessor(ABC): max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, ) - image_inputs: dict[str, NestedTensors] = { + image_inputs = { "pixel_values_flat": torch.cat(pixel_values_lst), "image_num_patches": @@ -502,7 +501,7 @@ class BaseInternVLProcessor(ABC): max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: text, images = [self._make_batch_input(x) for x in (text, images)] text, image_inputs = self._preprocess_image( @@ -515,10 +514,9 @@ class BaseInternVLProcessor(ABC): text_inputs = self.tokenizer(text) - return { - **BatchEncoding(text_inputs, tensor_type=return_tensors), - **image_inputs, - } + combined_outputs = {**text_inputs, **image_inputs} + + return BatchFeature(combined_outputs, tensor_type=return_tensors) class InternVLProcessor(BaseInternVLProcessor): @@ -598,7 +596,7 @@ class InternVLProcessor(BaseInternVLProcessor): videos, dynamic_image_size=dynamic_image_size, ) - video_inputs: dict[str, NestedTensors] = { + video_inputs = { "pixel_values_flat_video": torch.cat(pixel_values_lst_video), "video_num_patches": @@ -622,7 +620,7 @@ class InternVLProcessor(BaseInternVLProcessor): max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: text, images, videos = [ self._make_batch_input(x) for x in (text, images, videos) ] @@ -643,11 +641,9 @@ class InternVLProcessor(BaseInternVLProcessor): text_inputs = self.tokenizer(text) - return { - **BatchEncoding(text_inputs, tensor_type=return_tensors), - **image_inputs, - **video_inputs, - } + combined_outputs = {**text_inputs, **image_inputs, **video_inputs} + + return BatchFeature(combined_outputs, tensor_type=return_tensors) def get_image_repl( self, @@ -773,7 +769,7 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, @@ -793,7 +789,7 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): def _get_mm_fields_config( self, - hf_inputs: Mapping[str, NestedTensors], + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) @@ -948,7 +944,7 @@ class InternVLMultiModalProcessor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: processed_outputs = super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs) @@ -960,7 +956,7 @@ class InternVLMultiModalProcessor( def _get_mm_fields_config( self, - hf_inputs: Mapping[str, NestedTensors], + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_fields = super()._get_mm_fields_config(hf_inputs, @@ -1033,6 +1029,7 @@ class InternVLMultiModalProcessor( dummy_inputs=InternVLDummyInputsBuilder) class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + merge_by_field_config = True supports_encoder_tp_data = True @@ -1126,7 +1123,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, else: return InternVisionPatchModel(config.vision_config) - def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: vit_hidden_size = config.vision_config.hidden_size llm_hidden_size = config.text_config.hidden_size @@ -1175,13 +1172,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, return None if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return InternVLImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) image_token_id = kwargs["image_token_id"] @@ -1189,16 +1182,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, self.img_context_token_id = image_token_id.flatten().unique().item() if pixel_values_flat is not None: - if not isinstance(pixel_values_flat, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat)}") - - if not isinstance(image_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(image_num_patches)}") - - pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) - image_num_patches = flatten_bn(image_num_patches, concat=True) expected_h = expected_w = self.config.vision_config.image_size resolve_bindings = {"h": expected_h, "w": expected_w} @@ -1223,7 +1206,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, if video_embeds is not None: return InternVLVideoEmbeddingInputs( type="video_embeds", - data=flatten_bn(video_embeds), + data=video_embeds, ) video_token_id = kwargs["video_token_id"] @@ -1231,17 +1214,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, self.video_context_token_id = video_token_id.flatten().unique().item() if pixel_values_flat_video is not None: - if not isinstance(pixel_values_flat_video, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat_video)}") - - if not isinstance(video_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(video_num_patches)}") - - pixel_values_flat_video = flatten_bn(pixel_values_flat_video, - concat=True) - video_num_patches = flatten_bn(video_num_patches, concat=True) expected_h = expected_w = self.config.vision_config.image_size resolve_bindings = {"h": expected_h, "w": expected_w} @@ -1254,11 +1226,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, raise AssertionError("This line should be unreachable.") - def _process_image_input( + def _process_vision_input( self, - image_input: Union[InternVLImageInputs, InternVLVideoPixelInputs], + image_input: Union[InternVLImageInputs, InternVLVideoInputs], ) -> tuple[torch.Tensor, ...]: - if image_input["type"] == "image_embeds": + if (image_input["type"] == "image_embeds" + or image_input["type"] == "video_embeds"): return image_input["data"] assert self.vision_model is not None @@ -1326,11 +1299,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, for modality in modalities: if modality == "images": image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) + vision_embeddings = self._process_vision_input(image_input) multimodal_embeddings += vision_embeddings if modality == "videos": video_input = modalities["videos"] - video_embeddings = self._process_image_input(video_input) + video_embeddings = self._process_vision_input(video_input) multimodal_embeddings += video_embeddings return multimodal_embeddings diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 2d0ebdc90277..b1d59f77f59d 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -18,8 +18,7 @@ import torch import torch.nn as nn import torchvision.transforms as T from PIL import Image -from transformers import (BatchEncoding, BatchFeature, PretrainedConfig, - TensorType) +from transformers import BatchFeature, PretrainedConfig, TensorType from vllm.config import VllmConfig from vllm.model_executor.layers.activation import ReLUSquaredActivation @@ -38,8 +37,7 @@ from vllm.model_executor.models.utils import (flatten_bn, maybe_prefix) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, MultiModalKwargsItems, - NestedTensors) + MultiModalKwargs, MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -298,7 +296,7 @@ class BaseNanoNemotronVLProcessor(ABC): else: pixel_values_lst = self._images_to_pixel_values_lst( images, max_num_tiles) - image_inputs: dict[str, NestedTensors] = { + image_inputs = { "pixel_values_flat": torch.cat(pixel_values_lst), "image_num_patches": @@ -326,7 +324,7 @@ class BaseNanoNemotronVLProcessor(ABC): images: Optional[Union[Image.Image, list[Image.Image]]] = None, return_tensors: Optional[Union[str, TensorType]] = None, max_num_tiles: Optional[int] = None, - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: # Use default if not provided if max_num_tiles is None: max_num_tiles = 12 @@ -341,10 +339,9 @@ class BaseNanoNemotronVLProcessor(ABC): text_inputs = self.tokenizer(text, add_special_tokens=False) - return { - **BatchEncoding(text_inputs, tensor_type=return_tensors), - **image_inputs, - } + combined_outputs = {**text_inputs, **image_inputs} + + return BatchFeature(combined_outputs, tensor_type=return_tensors) class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): @@ -420,7 +417,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): dynamic_image_size=dynamic_image_size, ) - video_inputs: dict[str, NestedTensors] = { + video_inputs = { "pixel_values_flat_video": torch.cat(pixel_values_lst_video), "video_num_patches": @@ -443,7 +440,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): return_tensors: Optional[Union[str, TensorType]] = None, max_num_tiles: Optional[int] = None, dynamic_image_size: Optional[bool] = None, - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: # Use default if not provided if max_num_tiles is None: max_num_tiles = 12 @@ -467,11 +464,9 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): text_inputs = self.tokenizer(text, add_special_tokens=False) - return BatchFeature({ - **BatchEncoding(text_inputs, tensor_type=return_tensors), - **image_inputs, - **video_inputs, - }) + combined_outputs = {**text_inputs, **image_inputs, **video_inputs} + + return BatchFeature(combined_outputs, tensor_type=return_tensors) def get_image_repl( self, @@ -625,7 +620,7 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]): mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, @@ -645,7 +640,7 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]): def _get_mm_fields_config( self, - hf_inputs: Mapping[str, NestedTensors], + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) @@ -724,7 +719,7 @@ class NanoNemotronVLMultiModalProcessor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: processed_outputs = super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs) @@ -736,7 +731,7 @@ class NanoNemotronVLMultiModalProcessor( def _get_mm_fields_config( self, - hf_inputs: Mapping[str, NestedTensors], + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_fields = super()._get_mm_fields_config(hf_inputs, diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index 0e7ec8e458cf..e6c4c5b022dc 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -28,7 +28,6 @@ from vllm.model_executor.models.internvl import ( from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode -from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.processing import PromptUpdateDetails from vllm.sequence import IntermediateTensors from vllm.transformers_utils.processor import ( @@ -37,8 +36,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix IMG_START = '' IMG_END = '' @@ -289,7 +287,7 @@ class NemotronVLProcessor(InternVLProcessor): max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, ) - image_inputs: dict[str, NestedTensors] = { + image_inputs = { "pixel_values_flat": torch.cat(pixel_values_lst), "image_num_patches": @@ -344,6 +342,7 @@ class NemotronVLProcessingInfo(BaseInternVLProcessingInfo): dummy_inputs=BaseInternVLDummyInputsBuilder[NemotronVLProcessingInfo]) class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + merge_by_field_config = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -414,7 +413,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, return AutoModel.from_config(config.vision_config, trust_remote_code=True) - def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: vit_hidden_size = config.vit_hidden_size vision_projection_hidden_size = config.projector_hidden_size llm_hidden_size = config.text_config.hidden_size @@ -467,13 +466,9 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, return None if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return InternVLImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) image_token_id = kwargs["image_token_id"] @@ -481,17 +476,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, self.img_context_token_id = image_token_id.flatten().unique().item() if pixel_values_flat is not None: - if not isinstance(pixel_values_flat, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat)}") - - if not isinstance(image_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(image_num_patches)}") - - pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) - image_num_patches = flatten_bn(image_num_patches, concat=True) - return InternVLImagePixelInputs( type="pixel_values", pixel_values_flat=pixel_values_flat, diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index 3bbf4c67604c..0f993b0dc62f 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -159,7 +159,7 @@ class NVLMMultiModalProcessor( dummy_inputs=NVLMDummyInputsBuilder) class NVLM_D_Model(InternVLChatModel): - def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: vit_hidden_size = config.vision_config.hidden_size llm_intermediate_size = config.text_config.intermediate_size llm_hidden_size = config.text_config.hidden_size diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index f03022aa719c..8556c3847041 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -14,7 +14,7 @@ import torch import torch.nn as nn import torchvision.transforms as T from PIL import Image -from transformers import BatchEncoding, PretrainedConfig, TensorType +from transformers import BatchFeature, PretrainedConfig, TensorType from vllm.config import VllmConfig from vllm.model_executor.layers.linear import ReplicatedLinear @@ -25,7 +25,7 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel, from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -37,8 +37,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix IMG_START = '' IMG_END = '' @@ -399,7 +398,7 @@ class SkyworkR1VProcessor: max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: if text is None: text = [] if not isinstance(text, list): @@ -418,7 +417,7 @@ class SkyworkR1VProcessor: max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, ) - image_inputs: dict[str, NestedTensors] = { + image_inputs = { "pixel_values_flat": torch.cat(pixel_values_lst), "image_num_patches": @@ -435,10 +434,9 @@ class SkyworkR1VProcessor: text_inputs = self.tokenizer(text) - return { - **BatchEncoding(text_inputs, tensor_type=return_tensors), - **image_inputs, - } + combined_outputs = {**text_inputs, **image_inputs} + + return BatchFeature(combined_outputs, tensor_type=return_tensors) class SkyworkR1VProcessingInfo(BaseProcessingInfo): @@ -529,7 +527,7 @@ class SkyworkR1VMultiModalProcessor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, @@ -549,7 +547,7 @@ class SkyworkR1VMultiModalProcessor( def _get_mm_fields_config( self, - hf_inputs: Mapping[str, NestedTensors], + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) @@ -617,6 +615,7 @@ class SkyworkR1VMultiModalProcessor( info=SkyworkR1VProcessingInfo, dummy_inputs=SkyworkR1VDummyInputsBuilder) class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -703,7 +702,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): else: return InternVisionPatchModel(config.vision_config) - def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: vit_hidden_size = config.vision_config.hidden_size llm_hidden_size = config.text_config.hidden_size @@ -756,13 +755,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): return None if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return SkyworkR1VImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) image_token_id = kwargs["image_token_id"] @@ -770,17 +765,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): self.img_context_token_id = image_token_id.flatten().unique().item() if pixel_values_flat is not None: - if not isinstance(pixel_values_flat, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat)}") - - if not isinstance(image_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(image_num_patches)}") - - pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) - image_num_patches = flatten_bn(image_num_patches, concat=True) - return SkyworkR1VImagePixelInputs( type="pixel_values", pixel_values_flat=pixel_values_flat,