[Misc] Rename MultiModalInputsV2 -> MultiModalInputs (#12244)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-01-21 15:31:19 +08:00 (committed by GitHub)
parent 2fc6944c5e
commit 96912550c8
12 changed files with 31 additions and 31 deletions
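
For downstream code, this is a pure rename: the `TypedDict` itself is unchanged, only the `V2` suffix is dropped. A minimal before/after sketch (the `build_inputs` helper is hypothetical, not part of vLLM):

```python
# Before this commit:
#   from vllm.multimodal.inputs import MultiModalInputsV2
#   def build_inputs(...) -> MultiModalInputsV2: ...

# After this commit, only the imported name and annotations change:
from vllm.multimodal.inputs import MultiModalInputs


def build_inputs(prompt_token_ids: list[int]) -> MultiModalInputs:
    """Hypothetical downstream helper; the class behaves exactly as before."""
    ...
```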

View File

@@ -43,7 +43,7 @@
```
```{eval-rst}
-.. autoclass:: vllm.multimodal.inputs.MultiModalInputsV2
+.. autoclass:: vllm.multimodal.inputs.MultiModalInputs
:members:
:show-inheritance:
```

View File

@@ -9,7 +9,7 @@ from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never
if TYPE_CHECKING:
from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs,
MultiModalPlaceholderDict)
-from vllm.multimodal.inputs import MultiModalInputsV2
+from vllm.multimodal.inputs import MultiModalInputs
class TextPrompt(TypedDict):
@@ -207,7 +207,7 @@ def token_inputs(
return inputs
-DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputsV2"]
+DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputs"]
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
@@ -222,14 +222,14 @@ class EncoderDecoderInputs(TypedDict):
This specifies the required data for encoder-decoder models.
"""
-encoder: Union[TokenInputs, "MultiModalInputsV2"]
+encoder: Union[TokenInputs, "MultiModalInputs"]
"""The inputs for the encoder portion."""
-decoder: Union[TokenInputs, "MultiModalInputsV2"]
+decoder: Union[TokenInputs, "MultiModalInputs"]
"""The inputs for the decoder portion."""
-SingletonInputs = Union[TokenInputs, "MultiModalInputsV2"]
+SingletonInputs = Union[TokenInputs, "MultiModalInputs"]
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
@@ -311,7 +311,7 @@ class SingletonInputsAdapter:
return inputs.get("multi_modal_hashes", [])
if inputs["type"] == "multimodal":
-# only the case when we use MultiModalInputsV2
+# only the case when we use MultiModalInputs
return inputs.get("mm_hashes", []) # type: ignore[return-value]
assert_never(inputs) # type: ignore[arg-type]
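
`SingletonInputs` remains a union of `TokenInputs` and the renamed `MultiModalInputs`, so callers keep discriminating on the `type` key exactly as `SingletonInputsAdapter` does above. A minimal sketch, assuming `SingletonInputs` is re-exported from `vllm.inputs` and that the `"token"`/`"multimodal"` literals match this release (the `describe_inputs` helper is hypothetical):

```python
from vllm.inputs import SingletonInputs  # Union[TokenInputs, "MultiModalInputs"]


def describe_inputs(inputs: SingletonInputs) -> str:
    """Hypothetical helper that branches on the TypedDict discriminator."""
    if inputs["type"] == "token":
        return f"token inputs with {len(inputs['prompt_token_ids'])} prompt tokens"
    if inputs["type"] == "multimodal":
        # mm_hashes is optional, mirroring the .get("mm_hashes", []) call above.
        return f"multimodal inputs with {len(inputs.get('mm_hashes') or [])} hashes"
    raise AssertionError(f"unexpected inputs type: {inputs['type']}")
```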

View File

@@ -7,7 +7,7 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputsV2
+from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
@@ -247,7 +247,7 @@ class InputPreprocessor:
mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]],
lora_request: Optional[LoRARequest],
-) -> MultiModalInputsV2:
+) -> MultiModalInputs:
"""
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
@@ -271,7 +271,7 @@
mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]],
lora_request: Optional[LoRARequest],
-) -> MultiModalInputsV2:
+) -> MultiModalInputs:
"""Async version of :meth:`_process_multimodal`."""
tokenizer_group = self.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async(lora_request
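
Both the sync and async variants of `_process_multimodal` end up delegating to a processor built from the multi-modal registry and return its `MultiModalInputs` result. A rough sketch of that flow, assuming the registry's `create_processor(model_config, tokenizer)` call used in this vLLM version (the diff itself only shows the surrounding signatures):

```python
from collections.abc import Mapping
from typing import Optional, Union

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs


def process_multimodal_prompt(
    model_config,  # vllm.config.ModelConfig
    tokenizer,     # tokenizer already resolved for the (optional) LoRA request
    prompt: Union[str, list[int]],
    mm_data: MultiModalDataDict,
    mm_processor_kwargs: Optional[Mapping[str, object]],
) -> MultiModalInputs:
    # Build the model's multi-modal processor from the registry, then apply it
    # to the prompt and multi-modal data; InputPreprocessor additionally wraps
    # this with tokenizer-group and LoRA handling (see the async variant above).
    mm_processor = MULTIMODAL_REGISTRY.create_processor(model_config, tokenizer)
    return mm_processor.apply(prompt, mm_data, mm_processor_kwargs or {})
```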

View File

@@ -15,7 +15,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-MultiModalInputsV2, MultiModalKwargs,
+MultiModalInputs, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -490,7 +490,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
-) -> MultiModalInputsV2:
+) -> MultiModalInputs:
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only <image> tokens should be considered as placeholders,
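
Blip2, Chameleon, Fuyu, Phi3V, and Qwen2Audio all override `apply()` the same way: run the base processor, then narrow the reported placeholders to the modality's own tokens. A hedged sketch of that shape (the `Example*` names are placeholders, and the placeholder-narrowing body is intentionally omitted because it is model-specific):

```python
from collections.abc import Mapping
from typing import Union

from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo


class ExampleProcessingInfo(BaseProcessingInfo):
    """Stand-in for Blip2ProcessingInfo, FuyuProcessingInfo, etc."""


class ExampleMultiModalProcessor(BaseMultiModalProcessor[ExampleProcessingInfo]):

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

        # The processors touched by this commit then rewrite the placeholder
        # ranges in `result` so that only the modality's own tokens (e.g.
        # <image> or <|AUDIO|>) count as placeholders; that filtering is
        # model-specific and omitted here.
        return result
```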

View File

@@ -29,7 +29,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-MultiModalInputsV2, MultiModalKwargs,
+MultiModalInputs, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -159,7 +159,7 @@ class ChameleonMultiModalProcessor(
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
-) -> MultiModalInputsV2:
+) -> MultiModalInputs:
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only <image> tokens should be considered as placeholders,

View File

@@ -31,7 +31,7 @@ from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-MultiModalInputsV2, MultiModalKwargs,
+MultiModalInputs, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
@@ -232,7 +232,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
-) -> MultiModalInputsV2:
+) -> MultiModalInputs:
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only |SPEAKER| (image) tokens should be considered as placeholders,

View File

@@ -24,7 +24,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-MultiModalInputsV2, MultiModalKwargs,
+MultiModalInputs, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
@@ -746,7 +746,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
-) -> MultiModalInputsV2:
+) -> MultiModalInputs:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
@@ -805,7 +805,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
for modality, placeholders in mm_placeholders.items()
}
-return MultiModalInputsV2(
+return MultiModalInputs(
type="multimodal",
prompt=prompt,
prompt_token_ids=prompt_ids,

View File

@@ -31,7 +31,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-MultiModalInputsV2, MultiModalKwargs,
+MultiModalInputs, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
@@ -484,7 +484,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
-) -> MultiModalInputsV2:
+) -> MultiModalInputs:
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only <|image|> tokens should be considered as placeholders,

View File

@@ -37,7 +37,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-MultiModalInputsV2, MultiModalKwargs,
+MultiModalInputs, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
@@ -245,7 +245,7 @@ class Qwen2AudioMultiModalProcessor(
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
-) -> MultiModalInputsV2:
+) -> MultiModalInputs:
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only <|AUDIO|> tokens should be considered as placeholders,

View File

@@ -491,7 +491,7 @@ A dictionary containing placeholder ranges for each modality.
"""
-class MultiModalInputsV2(TypedDict):
+class MultiModalInputs(TypedDict):
"""
Represents the outputs of
:class:`vllm.multimodal.processing.BaseMultiModalProcessor`,
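
Because `MultiModalInputs` is a `TypedDict`, instances are plain dicts built by keyword, as the `return MultiModalInputs(...)` calls elsewhere in this commit show. A deliberately partial sketch using only the keys that appear in these hunks (`type`, `prompt`, `prompt_token_ids`); a real processor also fills the remaining fields of the class (processed tensors, hashes, placeholder ranges), which are not reproduced here:

```python
from vllm.multimodal.inputs import MultiModalInputs

# Partial construction for illustration only; TypedDicts are not validated at
# runtime, so the missing fields would only be flagged by a type checker.
example_inputs = MultiModalInputs(
    type="multimodal",
    prompt="USER: <image>\nWhat is in this picture? ASSISTANT:",
    prompt_token_ids=[1, 2, 3],  # placeholder IDs, not a real tokenization
)
```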

View File

@@ -18,8 +18,8 @@ from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
-MultiModalInputsV2, MultiModalKwargs,
-MultiModalKwargsItem, PlaceholderRange)
+MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem,
+PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser
if TYPE_CHECKING:
@@ -609,7 +609,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt: str,
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
-) -> MultiModalInputsV2:
+) -> MultiModalInputs:
return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
def _get_data_parser(self) -> MultiModalDataParser:
@@ -1067,7 +1067,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
-) -> MultiModalInputsV2:
+) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.
@@ -1169,7 +1169,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
for modality, placeholders in mm_placeholders.items()
}
-return MultiModalInputsV2(
+return MultiModalInputs(
type="multimodal",
prompt=prompt,
prompt_token_ids=prompt_ids,

View File

@@ -11,7 +11,7 @@ import vllm.envs as envs
from vllm.inputs import DummyData
from vllm.logger import init_logger
-from .inputs import MultiModalDataDict, MultiModalInputsV2
+from .inputs import MultiModalDataDict, MultiModalInputs
from .processing import BaseMultiModalProcessor, BaseProcessingInfo
logger = init_logger(__name__)
@@ -131,7 +131,7 @@ class MultiModalProfiler(Generic[_I]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
-) -> MultiModalInputsV2:
+) -> MultiModalInputs:
factory = self.dummy_inputs
processor_inputs = factory.get_dummy_processor_inputs(
seq_len, mm_counts)
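
The profiler's dummy path returns the same renamed `MultiModalInputs`: a dummy-inputs factory produces processor inputs sized for `seq_len` and `mm_counts`, and the multi-modal processor turns them into the dict used for memory profiling. A rough sketch of that flow; the `profiler.processor` attribute and the field names on the factory's return value (`prompt_text`, `mm_data`, `hf_processor_mm_kwargs`) are assumptions about this vLLM version, not shown in the diff:

```python
from collections.abc import Mapping

from vllm.multimodal.inputs import MultiModalInputs


def get_dummy_mm_inputs(profiler, seq_len: int,
                        mm_counts: Mapping[str, int]) -> MultiModalInputs:
    # Mirrors the hunk above: ask the dummy-inputs factory for processor
    # inputs, then run them through the processor to get a MultiModalInputs
    # dict for profiling.
    factory = profiler.dummy_inputs
    processor_inputs = factory.get_dummy_processor_inputs(seq_len, mm_counts)
    return profiler.processor.apply(
        prompt=processor_inputs.prompt_text,  # assumed field name
        mm_data=processor_inputs.mm_data,     # assumed field name
        hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
    )
```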