diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 23b8ef89268d7..edf67c860e977 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -21,12 +21,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors # yapf: disable @@ -415,31 +416,31 @@ class AriaProcessingInfo(BaseProcessingInfo): class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token: str = processor.tokenizer.image_token # type: ignore + + return image_token * num_images + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: vision_config = self.info.get_vision_config() max_image_size = vision_config.image_size num_images = mm_counts.get("image", 0) - mm_data = { + return { "image": self._get_dummy_images(width=max_image_size, height=max_image_size, num_images=num_images) } - hf_processor = self.info.get_hf_processor() - image_token: str = hf_processor.tokenizer.image_token # type: ignore - - return ProcessorInputs( - prompt_text=image_token * num_images, - mm_data=mm_data, - ) - class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]): diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index cdec31602503d..8700f24d2bd25 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -20,7 +20,7 @@ from vllm.jsontree import json_map_leaves from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs +from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -28,7 +28,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, MultiModalFieldConfig, PromptReplacement, PromptUpdate, PromptUpdateDetails) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -146,28 +146,29 @@ class AyaVisionProcessingInfo(BaseProcessingInfo): class AyaVisionDummyInputsBuilder( BaseDummyInputsBuilder[AyaVisionProcessingInfo]): - def get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) 
-> ProcessorInputs: + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + processor = self.info.get_hf_processor() image_token = processor.image_token + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) image_size = \ self.info.get_image_size_with_most_features() - mm_data = { + return { "image": self._get_dummy_images(width=image_size.width, height=image_size.height, num_images=num_images) } - return ProcessorInputs( - prompt_text=image_token * num_images, - mm_data=mm_data, - ) class AyaVisionMultiModalProcessor( diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index dde78ee52a3de..a6f00f9997730 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -15,12 +15,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptIndexTargets, PromptInsertion, PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .blip import BlipVisionModel @@ -413,29 +414,27 @@ class Blip2ProcessingInfo(BaseProcessingInfo): class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: hf_config = self.info.get_hf_config() vision_config = hf_config.vision_config max_image_size = vision_config.image_size num_images = mm_counts.get("image", 0) - mm_data = { + return { "image": self._get_dummy_images(width=max_image_size, height=max_image_size, num_images=num_images) } - return ProcessorInputs( - prompt_text="", - mm_data=mm_data, - ) - class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index fb2f4b677c5af..0ad5e89df2e25 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -30,12 +30,13 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from 
vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, @@ -72,28 +73,31 @@ class ChameleonProcessingInfo(BaseProcessingInfo): class ChameleonDummyInputsBuilder( BaseDummyInputsBuilder[ChameleonProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: config = self.info.get_hf_config() width = height = config.vq_config.resolution num_images = mm_counts.get("image", 0) - mm_data = { + return { "image": self._get_dummy_images(width=width, height=height, num_images=num_images) } - return ProcessorInputs( - prompt_text="<image>" * num_images, - mm_data=mm_data, - ) - class ChameleonMultiModalProcessor( BaseMultiModalProcessor[ChameleonProcessingInfo]): diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 951185bc9bd01..c3dbadb292769 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -19,14 +19,14 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, - NestedTensors) +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, MlpProjectorConfig, @@ -172,29 +172,30 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo): class DeepseekVL2DummyInputsBuilder( BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - hf_processor = self.info.get_hf_processor() - image_token: str = hf_processor.image_token max_image_size = self.info.get_image_size_with_most_features() - mm_data = { + return { "image": self._get_dummy_images(width=max_image_size.width, height=max_image_size.height, num_images=num_images) } - return ProcessorInputs( - prompt_text=image_token * num_images, - mm_data=mm_data, - ) - class DeepseekVL2MultiModalProcessor( BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]): diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 
56572bd59a35c..359cc7f377310 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -21,13 +21,14 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, BartScaledWordEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs -from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, PromptIndexTargets, PromptInsertion, PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, @@ -772,27 +773,25 @@ class Florence2ProcessingInfo(BaseProcessingInfo): class Florence2DummyInputsBuilder( BaseDummyInputsBuilder[Florence2ProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) target_width = target_height = self.info.get_hf_config().projection_dim - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } - return ProcessorInputs( - prompt_text="", - mm_data=mm_data, - ) - class Florence2MultiModalProcessor( EncDecMultiModalProcessor[Florence2ProcessingInfo]): diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 5fc6bb846388f..27cd8d0986a55 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -31,13 +31,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -125,27 +126,25 @@ class FuyuProcessingInfo(BaseProcessingInfo): class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: target_width, target_height = \ self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) - mm_data = { + return { "image": 
self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } - return ProcessorInputs( - prompt_text="", - mm_data=mm_data, - ) - class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 34d856f4b2037..e5a3d6762fff2 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -15,8 +15,9 @@ from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import MultiModalFieldConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) # yapf: disable @@ -28,7 +29,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, find_mm_placeholders, replace_token_matches) # yapf: enable -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .interfaces import (MultiModalEmbeddings, SupportsLoRA, @@ -224,31 +225,31 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): - def get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + processor = self.info.get_hf_processor() image_token = processor.boi_token + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) target_width, target_height = \ self.info.get_image_size_with_most_features() - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } - return ProcessorInputs( - prompt_text=image_token * num_images, - mm_data=mm_data, - ) - class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 02954eecc42cd..4e13716719ace 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -12,7 +12,7 @@ from torch import nn from torch.nn import LayerNorm from torchvision import transforms from torchvision.transforms import InterpolationMode -from transformers import PreTrainedTokenizer, TensorType +from transformers import BatchFeature, PreTrainedTokenizer, TensorType from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput @@ -28,13 +28,13 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) from 
vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, BatchFeature, - MultiModalFieldConfig, - PromptReplacement, PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig @@ -447,31 +447,31 @@ class GLM4VProcessingInfo(BaseProcessingInfo): class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>" + + return base_text * num_images + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: hf_config = self.info.get_hf_config() vision_config = hf_config.vision_config target_width = target_height = vision_config["image_size"] num_images = mm_counts.get("image", 0) - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } - base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>" - - return ProcessorInputs( - prompt_text=base_text * num_images, - mm_data=mm_data, - ) - class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 655db1c856346..c31870461b4c2 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -32,18 +32,18 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) from vllm.multimodal.parse import ImageProcessorItems, ImageSize # yapf conflicts with isort for this block # yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, - MultiModalDataItems, - MultiModalFieldConfig, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) + MultiModalDataItems, PromptReplacement, + PromptUpdate, PromptUpdateDetails) # yapf: enable -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors # yapf: disable @@ -284,29 +284,31 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo] ): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token, _, _ = self.info._get_image_token(processor) + + return image_token * num_images + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) hf_processor = 
self.info.get_hf_processor() image_processor: Idefics3ImageProcessor = hf_processor.image_processor longest_edge = image_processor.max_image_size['longest_edge'] - image_token, _, _ = self.info._get_image_token(hf_processor) - mm_data = { + return { "image": self._get_dummy_images(width=longest_edge, height=longest_edge, num_images=num_images) } - return ProcessorInputs( - prompt_text=image_token * num_images, - mm_data=mm_data, - ) - class Idefics3MultiModalProcessor( BaseMultiModalProcessor[Idefics3ProcessingInfo]): diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 08741b3a3c11e..8f5f454cbf607 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -25,14 +25,14 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel, InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, - NestedTensors) +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -504,27 +504,27 @@ _I = TypeVar("_I", bound=BaseInternVLProcessingInfo) class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + return "<image>" * num_images + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: target_width, target_height = \ self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } - return ProcessorInputs( - prompt_text="<image>" * num_images, - mm_data=mm_data, - ) - class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 5804cb4419b6c..fbd212d170044 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -34,7 +34,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, ProcessingCache, PromptReplacement, PromptUpdate, PromptUpdateDetails) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .clip import CLIPVisionModel @@ -186,30 +186,31 @@ _I = TypeVar("_I", bound=BaseLlavaProcessingInfo) class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]): - def get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) processor = self.info.get_hf_processor() image_token = processor.image_token + + return 
image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + target_width, target_height = \ self.info.get_image_size_with_most_features() - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } - return ProcessorInputs( - prompt_text=image_token * num_images, - mm_data=mm_data, - ) - class LlavaProcessingInfo(BaseLlavaProcessingInfo): diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 281c9c0e8ebe3..0221c6b237cbb 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -16,13 +16,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, VideoEmbeddingItems, VideoProcessorItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -130,22 +131,27 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo): class LlavaNextVideoDummyInputsBuilder( BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]): - def get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_videos = mm_counts.get("video", 0) processor = self.info.get_hf_processor() video_token = processor.video_token + return video_token * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_videos = mm_counts.get("video", 0) + target_width, target_height = \ self.info.get_image_size_with_most_features() target_num_frames = \ self.info.get_num_frames_with_most_features(seq_len, mm_counts) - mm_data = { + return { "video": self._get_dummy_videos( width=target_width, @@ -155,11 +161,6 @@ class LlavaNextVideoDummyInputsBuilder( ) } - return ProcessorInputs( - prompt_text=video_token * num_videos, - mm_data=mm_data, - ) - class LlavaNextVideoMultiModalProcessor( BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]): diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index f6256771d9828..60d32c924694c 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -19,11 +19,11 @@ from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) from vllm.multimodal.parse 
import (ImageSize, MultiModalDataItems, VideoEmbeddingItems, VideoProcessorItems) from vllm.multimodal.processing import PromptReplacement, PromptUpdate -from vllm.multimodal.profiling import ProcessorInputs from vllm.sequence import IntermediateTensors from .clip import CLIPVisionModel @@ -226,11 +226,7 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): class LlavaOnevisionDummyInputsBuilder( LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]): - def get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -238,13 +234,23 @@ class LlavaOnevisionDummyInputsBuilder( image_token = processor.image_token video_token = processor.video_token + return image_token * num_images + video_token * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + target_width, target_height = \ self.info.get_image_size_with_most_features() target_num_frames = \ self.info.get_num_frames_with_most_features(seq_len, mm_counts) - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, @@ -258,11 +264,6 @@ class LlavaOnevisionDummyInputsBuilder( ) } - return ProcessorInputs( - prompt_text=image_token * num_images + video_token * num_videos, - mm_data=mm_data, - ) - class LlavaOnevisionMultiModalProcessor( BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]): diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index 8bb41a108b5a9..29c3cc5e769b3 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -35,14 +35,14 @@ from transformers.models.whisper.modeling_whisper import ( from vllm.config import VllmConfig from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + NestedTensors) from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, DictEmbeddingItems, ModalityData, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, PromptUpdateDetails) -from vllm.multimodal.profiling import ProcessorInputs from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6, MiniCPMVDummyInputsBuilder, @@ -206,29 +206,31 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): class MiniCPMODummyInputsBuilder( MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]): - def get_dummy_processor_inputs( - self, seq_len: int, mm_counts: Mapping[str, - int]) -> ProcessorInputs: + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + + audio_prompt_texts = self.info.audio_pattern * num_audios + + return super().get_dummy_text(mm_counts) + audio_prompt_texts + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: num_audios = mm_counts.get("audio", 0) audio_len = self.info.get_max_audio_chunks_with_most_features() * \ self.info.get_default_audio_sampling_rate() - processor_inputs = super().get_dummy_processor_inputs( - seq_len, mm_counts) - - audio_prompt_texts = self.info.audio_pattern * num_audios 
audio_mm_data = { "audio": self._get_dummy_audios(length=audio_len, num_audios=num_audios) } - return ProcessorInputs( - prompt_text=processor_inputs.prompt_text + audio_prompt_texts, - mm_data={ - **processor_inputs.mm_data, - **audio_mm_data, - }, - ) + return { + **super().get_dummy_mm_data(seq_len, mm_counts), + **audio_mm_data, + } class MiniCPMOMultiModalProcessor( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 87c6902195831..c504737e1b335 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -48,7 +48,8 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + NestedTensors) from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageProcessorItems, ImageSize, ModalityData, ModalityDataItems, @@ -57,7 +58,7 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import flatten_2d_lists @@ -471,11 +472,20 @@ _I = TypeVar("_I", class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + image_prompt_texts = self.info.image_pattern * num_images + video_prompt_texts = self.info.video_pattern * num_videos + + return image_prompt_texts + video_prompt_texts + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -486,7 +496,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): num_video_frames = \ self.info.get_num_frames_with_most_features(seq_len, mm_counts) - mm_data = { + return { "image": self._get_dummy_images(width=image_width, height=image_height, @@ -498,13 +508,6 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): ] * num_videos, } - image_prompt_texts = self.info.image_pattern * num_images - video_prompt_texts = self.info.video_pattern * num_videos - - return ProcessorInputs(prompt_text=image_prompt_texts + - video_prompt_texts, - mm_data=mm_data) - class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index d2c600feb4b29..8b1a1d68fc3fa 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -22,14 +22,15 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, 
MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, ProcessingCache, PromptReplacement, PromptUpdate, PromptUpdateDetails) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -185,30 +186,31 @@ _I = TypeVar("_I", bound=BaseLlavaProcessingInfo) class Mistral3DummyInputsBuilder(BaseDummyInputsBuilder[_I]): - def get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) processor = self.info.get_hf_processor() image_token = processor.image_token + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + target_width, target_height = \ self.info.get_image_size_with_most_features() - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } - return ProcessorInputs( - prompt_text=image_token * num_images, - mm_data=mm_data, - ) - class Mistral3ProcessingInfo(BaseLlavaProcessingInfo): diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index b61e42f31d88b..251d95e41dc3d 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -54,14 +54,14 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalEncDecInputs, +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalFieldConfig, MultiModalKwargs) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataDict, MultiModalDataItems) + MultiModalDataItems) from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, PromptReplacement, PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from .clip import CLIPMLP from .interfaces import SupportsMultiModal, SupportsV0Only @@ -131,31 +131,31 @@ class MllamaProcessingInfo(BaseProcessingInfo): class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) target_width, target_height = \ self.info.get_image_size_with_most_features() - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } - hf_processor = self.info.get_hf_processor() - image_token: 
str = hf_processor.image_token - - return ProcessorInputs( - prompt_text=image_token * num_images, - mm_data=mm_data, - ) - class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] ): diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 4f709751ae629..0966f546ddf90 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -43,14 +43,14 @@ from vllm.model_executor.model_loader.loader import _initialize_model from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, - NestedTensors) +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -619,29 +619,31 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.fake_image_token + + return image_token * num_images + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) (target_width, target_height) = self.info.get_image_size_with_most_features() - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } - image_token = self.info.get_hf_processor().fake_image_token - return ProcessorInputs( - prompt_text=image_token * num_images, - mm_data=mm_data, - ) - @MULTIMODAL_REGISTRY.register_processor( Mllama4MultiModalProcessor, diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index d896431b166b2..d75845b45e733 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -41,14 +41,15 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptIndexTargets, PromptInsertion, PromptUpdate, PromptUpdateDetails) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors 
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, @@ -1216,27 +1217,25 @@ class MolmoProcessingInfo(BaseProcessingInfo): class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: target_width, target_height = \ self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } - return ProcessorInputs( - prompt_text="", - mm_data=mm_data, - ) - class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index 314f75c203012..62a7deab6a10c 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -15,12 +15,11 @@ from transformers import PretrainedConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs +from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, PromptUpdateDetails) -from vllm.multimodal.profiling import ProcessorInputs from .intern_vit import InternVisionModel from .internvl import (BaseInternVLProcessingInfo, BaseInternVLProcessor, @@ -87,29 +86,29 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo): class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + # The newline is necessary to separate ">" of the current item + # and "<" of the next item + return "<image>\n" * num_images + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: target_width, target_height = \ self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } - return ProcessorInputs( - # The newline is necessary to separate ">" of the current item - # and "<" of the next item - prompt_text="<image>\n" * num_images, - mm_data=mm_data, - ) - class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]): diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index ae8eee4515e04..6c1bd499f6398 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -19,7 +19,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptIndexTargets, PromptInsertion, PromptUpdate, PromptUpdateDetails) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -90,29 +90,27 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo): class PaliGemmaDummyInputsBuilder( 
BaseDummyInputsBuilder[PaliGemmaProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: hf_config = self.info.get_hf_config() vision_config = hf_config.vision_config max_image_size = vision_config.image_size num_images = mm_counts.get("image", 0) - mm_data = { + return { "image": self._get_dummy_images(width=max_image_size, height=max_image_size, num_images=num_images) } - return ProcessorInputs( - prompt_text="", - mm_data=mm_data, - ) - class PaliGemmaMultiModalProcessor( BaseMultiModalProcessor[PaliGemmaProcessingInfo]): diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index cce700f02f597..7f41ad2359df6 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -32,7 +32,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) # yapf conflicts with isort for this block @@ -42,7 +43,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PlaceholderFeaturesInfo, PromptReplacement, PromptUpdate) # yapf: enable -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -343,31 +344,31 @@ class Phi3VProcessingInfo(BaseProcessingInfo): class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + hf_processor = self.info.get_hf_processor() + image_tokens: list[str] = hf_processor.img_tokens # type: ignore + + return "".join(image_tokens[:num_images]) + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) target_width, target_height = \ self.info.get_image_size_with_most_features() - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } - hf_processor = self.info.get_hf_processor() - image_tokens: list[str] = hf_processor.img_tokens # type: ignore - - return ProcessorInputs( - prompt_text="".join(image_tokens[:num_images]), - mm_data=mm_data, - ) - class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index fdd342ccf6b56..ee1e7713e90e2 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -32,13 +32,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import 
MultiModalFieldConfig, NestedTensors +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + NestedTensors) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import (MistralTokenizer, cached_tokenizer_from_config) @@ -203,28 +204,26 @@ class PixtralProcessingInfo(BaseProcessingInfo): class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) target_width, target_height = \ self.info.get_image_size_with_most_features() - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } - return ProcessorInputs( - prompt_text="", - mm_data=mm_data, - ) - class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] ): diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index e3a93e95530c3..c10ef45440b11 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -35,7 +35,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import (IntermediateTensors, PoolerOutput, PoolingSequenceGroupOutput) @@ -49,20 +49,21 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): class PrithviGeoSpatialMAEInputBuilder( BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: - return ProcessorInputs( - prompt_text="", - # This model input is fixed and is in the form of a torch Tensor. - # The size of pixel_values might change in the cases where we resize - # the input but never exceeds the dimensions below. - mm_data={ - "pixel_values": torch.full((1, 6, 512, 512), 1.0), - "location_coords": torch.full((1, 2), 1.0) - }) + ) -> MultiModalDataDict: + # This model input is fixed and is in the form of a torch Tensor. + # The size of pixel_values might change in the cases where we resize + # the input but never exceeds the dimensions below. 
+ return { + "pixel_values": torch.full((1, 6, 512, 512), 1.0), + "location_coords": torch.full((1, 2), 1.0), + } class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index ba4646f5583f9..280cda0f68f1a 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -37,13 +37,14 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -113,27 +114,30 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo): class Qwen2AudioDummyInputsBuilder( BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + + hf_processor = self.info.get_hf_processor() + audio_token = hf_processor.audio_token + + return audio_token * num_audios + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: feature_extractor = self.info.get_feature_extractor() sampling_rate = feature_extractor.sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate num_audios = mm_counts.get("audio", 0) - mm_data = { + return { "audio": self._get_dummy_audios(length=audio_len, num_audios=num_audios) } - return ProcessorInputs( - prompt_text="<|AUDIO|>" * num_audios, - mm_data=mm_data, - ) - class Qwen2AudioMultiModalProcessor( BaseMultiModalProcessor[Qwen2AudioProcessingInfo]): diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 23f27e7ef9fb0..11950f78f1d25 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -56,15 +56,15 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (ImageItem, ModalityData, - MultiModalFieldConfig, MultiModalKwargs, - VideoItem) + MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, VideoItem) from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope @@ -965,11 +965,7 
@@ class Qwen2VLProcessingInfo(BaseProcessingInfo): class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): - def get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -977,12 +973,22 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): image_token: str = hf_processor.image_token video_token: str = hf_processor.video_token + return image_token * num_images + video_token * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + target_width, target_height = \ self.info.get_image_size_with_most_features() target_num_frames = \ self.info.get_num_frames_with_most_features(seq_len, mm_counts) - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, @@ -996,11 +1002,6 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): ) } - return ProcessorInputs( - prompt_text=image_token * num_images + video_token * num_videos, - mm_data=mm_data, - ) - class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ): diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 403d47a39d175..9f370d7aab4e4 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -32,12 +32,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .interfaces import (MultiModalEmbeddings, SupportsLoRA, @@ -542,34 +543,34 @@ class QwenVLProcessingInfo(BaseProcessingInfo): class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]): - def get_dummy_processor_inputs( + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + hf_processor = self.info.get_hf_processor() + img_start = hf_processor.image_start_tag + img_end = hf_processor.image_end_tag + + return "".join(f"Picture {i}: {img_start}{img_end}\n" + for i in range(1, num_images + 1)) + + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> ProcessorInputs: + ) -> MultiModalDataDict: hf_config = self.info.get_hf_config() vision_config = hf_config.visual - processor = self.info.get_hf_processor() - img_start = processor.image_start_tag - img_end = processor.image_end_tag - target_width = target_height = vision_config["image_size"] num_images = mm_counts.get("image", 0) - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, 
diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py
index 403d47a39d175..9f370d7aab4e4 100644
--- a/vllm/model_executor/models/qwen_vl.py
+++ b/vllm/model_executor/models/qwen_vl.py
@@ -32,12 +32,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalKwargs)
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
-from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -542,34 +543,34 @@ class QwenVLProcessingInfo(BaseProcessingInfo):
 class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]):
 
-    def get_dummy_processor_inputs(
+    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+        num_images = mm_counts.get("image", 0)
+
+        hf_processor = self.info.get_hf_processor()
+        img_start = hf_processor.image_start_tag
+        img_end = hf_processor.image_end_tag
+
+        return "".join(f"Picture {i}: {img_start}{img_end}\n"
+                       for i in range(1, num_images + 1))
+
+    def get_dummy_mm_data(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
-    ) -> ProcessorInputs:
+    ) -> MultiModalDataDict:
         hf_config = self.info.get_hf_config()
         vision_config = hf_config.visual
 
-        processor = self.info.get_hf_processor()
-        img_start = processor.image_start_tag
-        img_end = processor.image_end_tag
-
         target_width = target_height = vision_config["image_size"]
         num_images = mm_counts.get("image", 0)
 
-        mm_data = {
+        return {
             "image":
             self._get_dummy_images(width=target_width,
                                    height=target_height,
                                    num_images=num_images)
         }
 
-        return ProcessorInputs(
-            prompt_text="".join(f"Picture {i}: {img_start}{img_end}\n"
-                                for i in range(1, num_images + 1)),
-            mm_data=mm_data,
-        )
-
 
 class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py
index 09a212a9face0..19a23162aa840 100644
--- a/vllm/model_executor/models/skyworkr1v.py
+++ b/vllm/model_executor/models/skyworkr1v.py
@@ -26,14 +26,14 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                    InternVisionPatchModel)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalKwargs, NestedTensors)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
-from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
@@ -505,27 +505,27 @@ _I = TypeVar("_I", bound=BaseSkyworkR1VProcessingInfo)
 class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
 
-    def get_dummy_processor_inputs(
+    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+        num_images = mm_counts.get("image", 0)
+
+        return "<image>" * num_images
+
+    def get_dummy_mm_data(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
-    ) -> ProcessorInputs:
+    ) -> MultiModalDataDict:
         target_width, target_height = \
             self.info.get_image_size_with_most_features()
         num_images = mm_counts.get("image", 0)
 
-        mm_data = {
+        return {
             "image":
             self._get_dummy_images(width=target_width,
                                    height=target_height,
                                    num_images=num_images)
         }
 
-        return ProcessorInputs(
-            prompt_text="<image>" * num_images,
-            mm_data=mm_data,
-        )
-
 
 class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
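Because the two halves are now separate, a migrated builder can be spot-checked one method at a time. A rough usage sketch, assuming `builder` is an already-constructed dummy-inputs builder for an image model whose placeholder token is `<image>` (both are assumptions, not something this diff provides):

```python
# `builder` and the "<image>" token are assumed to exist; this is only a
# sketch of the new contract, not code from this change.
mm_counts = {"image": 2}

dummy_text = builder.get_dummy_text(mm_counts)
assert dummy_text.count("<image>") == 2  # one placeholder per image

dummy_mm_data = builder.get_dummy_mm_data(seq_len=8192, mm_counts=mm_counts)
assert len(dummy_mm_data["image"]) == 2  # one dummy image per count

# The legacy entry point is still available and now simply composes the two
# new methods (see the profiling.py hunks below).
dummy_inputs = builder.get_dummy_processor_inputs(8192, mm_counts)
assert dummy_inputs.prompt_text == dummy_text
```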
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 3ff5a0516b65e..cb5ff4ed6365b 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -23,13 +23,13 @@ from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalKwargs, NestedTensors)
 from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
-from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
 
@@ -110,11 +110,16 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
 class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
                                  ):
 
-    def get_dummy_processor_inputs(
+    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+        num_audios = mm_counts.get("audio", 0)
+
+        return "<|audio|>" * num_audios
+
+    def get_dummy_mm_data(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
-    ) -> ProcessorInputs:
+    ) -> MultiModalDataDict:
         feature_extractor = self.info.get_feature_extractor()
 
         sampling_rate = feature_extractor.sampling_rate
@@ -122,16 +127,11 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
                      _MAX_ENCODER_BATCH_SIZE)
         num_audios = mm_counts.get("audio", 0)
 
-        mm_data = {
+        return {
            "audio":
            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
        }
 
-        return ProcessorInputs(
-            prompt_text="<|audio|>" * num_audios,
-            mm_data=mm_data,
-        )
-
 
 class UltravoxMultiModalProcessor(
        BaseMultiModalProcessor[UltravoxProcessingInfo]):
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index 341e22a4a8bb1..63e71f2688057 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -26,13 +26,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
-from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
-from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
-                                   MultiModalDataParser)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalKwargs)
+from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
 from vllm.multimodal.processing import (BaseProcessingInfo,
                                         EncDecMultiModalProcessor,
                                         PromptReplacement, PromptUpdate)
-from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription, SupportsV0Only)
@@ -544,27 +544,27 @@ class WhisperProcessingInfo(BaseProcessingInfo):
 class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
 
-    def get_dummy_processor_inputs(
+    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+        num_audios = mm_counts.get("audio", 0)
+
+        return "<|startoftranscript|>" * num_audios
+
+    def get_dummy_mm_data(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
-    ) -> ProcessorInputs:
+    ) -> MultiModalDataDict:
         feature_extractor = self.info.get_feature_extractor()
 
         sampling_rate = feature_extractor.sampling_rate
         audio_len = feature_extractor.chunk_length * sampling_rate
         num_audios = mm_counts.get("audio", 0)
 
-        mm_data = {
+        return {
             "audio":
             self._get_dummy_audios(length=audio_len, num_audios=num_audios)
         }
 
-        return ProcessorInputs(
-            prompt_text="<|startoftranscript|>" * num_audios,
-            mm_data=mm_data,
-        )
-
 
 class WhisperMultiModalProcessor(
        EncDecMultiModalProcessor[WhisperProcessingInfo]):
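The library-side hunks that follow finish the transition. In `processing.py`, the call site below only ever used the prompt text from the dummy inputs, so it can now ask for that directly; stripped to its essence, the change is:

```python
# Simplified before/after of the call inside BaseMultiModalProcessor; the
# enclosing method and its other arguments are elided.

# Before: build a full dummy ProcessorInputs just to read .prompt_text.
dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
    self.info.ctx.model_config.max_model_len,
    mm_counts,
)
prompt_text = dummy_inputs.prompt_text

# After: ask for the dummy text directly; no sequence length is needed
# to build the text alone.
prompt_text = self.dummy_inputs.get_dummy_text(mm_counts)
```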
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index fefeefd21375e..f531314abedc7 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1051,12 +1051,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
                  *,
                  cache: Optional[ProcessingCache] = None,
                  enable_sanity_checks: bool = True) -> None:
-        if get_repls := getattr(self, "_get_prompt_replacements", None):
-            logger.warning_once("`_get_prompt_replacements` has been renamed "
-                                "to `_get_prompt_updates`. The old name will "
-                                "be removed in an upcoming release.")
-            self._get_prompt_updates = get_repls  # type: ignore[method-assign]
-
         super().__init__()
 
         self.info = info
@@ -1274,13 +1268,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         """
         mm_counts = mm_items.get_all_counts()
 
-        dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
-            self.info.ctx.model_config.max_model_len,
-            mm_counts,
-        )
-
         _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
-            prompt_text=dummy_inputs.prompt_text,
+            prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
         )
diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index 7efe86448fdd0..29de9b7cda03c 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from abc import ABC, abstractmethod
+from abc import ABC
 from collections.abc import Mapping
 from dataclasses import dataclass, field
 from typing import Generic, NamedTuple, Optional, TypeVar, cast
@@ -60,7 +60,35 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
 
         self.info = info
 
-    @abstractmethod
+    # TODO: @abstractmethod after transition
+    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+        """
+        Build the text input corresponding to :code:`mm_counts`.
+        """
+        if (type(self).get_dummy_processor_inputs ==
+                BaseDummyInputsBuilder.get_dummy_processor_inputs):
+            raise NotImplementedError
+
+        logger.warning_once("`get_dummy_processor_inputs` has been split up "
+                            "into `get_dummy_text` and `get_dummy_mm_data`. "
+                            "These two methods will be marked as abstract "
+                            "in an upcoming release.")
+
+        seq_len = self.info.ctx.model_config.max_model_len
+        return self.get_dummy_processor_inputs(seq_len, mm_counts).prompt_text
+
+    # TODO: @abstractmethod after transition
+    def get_dummy_mm_data(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> MultiModalDataDict:
+        """
+        Build the multimodal input which, after processing, results in
+        the maximum possible number of placeholder tokens.
+        """
+        raise NotImplementedError
+
     def get_dummy_processor_inputs(
         self,
         seq_len: int,
@@ -70,7 +98,10 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
         Build the input which, after processing, results in
         the maximum possible number of placeholder tokens.
         """
-        raise NotImplementedError
+        dummy_text = self.get_dummy_text(mm_counts)
+        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
+
+        return ProcessorInputs(prompt_text=dummy_text, mm_data=dummy_mm_data)
 
     def _get_dummy_audios(
         self,