[Frontend] decrease import time of vllm.multimodal (#18031)

Co-authored-by: Aaron Pham <Aaronpham0103@gmail.com>
David Xia 2025-05-14 18:43:32 -04:00 committed by GitHub
parent 856865008e
commit 749f792553
3 changed files with 45 additions and 34 deletions
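
The commit trims module-level imports of torch, PIL, and transformers out of the vllm.multimodal import path. A quick way to verify this kind of change is CPython's `-X importtime` flag; the snippet below is an illustrative harness (not part of the commit) that assumes vllm is installed and prints the slowest cumulative imports triggered by `import vllm.multimodal`.

# Illustrative only: profile what importing vllm.multimodal pulls in, using
# CPython's built-in -X importtime instrumentation in a fresh interpreter.
import subprocess
import sys

result = subprocess.run(
    [sys.executable, "-X", "importtime", "-c", "import vllm.multimodal"],
    capture_output=True,
    text=True,
    check=True,
)

rows = []
for line in result.stderr.splitlines():
    # Lines look like: "import time:  self [us] | cumulative | imported package"
    if not line.startswith("import time:"):
        continue
    _, cumulative, module = line.split("|")
    if "cumulative" in cumulative:
        continue  # skip the header row
    rows.append((int(cumulative), module.strip()))

for cumulative_us, module in sorted(rows, reverse=True)[:15]:
    print(f"{cumulative_us / 1000:9.1f} ms  {module}")

Comparing this output before and after a change like this one shows which heavyweight packages dropped out of the import graph.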

vllm/multimodal/inputs.py

@@ -10,40 +10,43 @@ from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
                     Union, cast, final)

 import numpy as np
-import torch
-import torch.types
-from PIL.Image import Image
-from transformers import BatchFeature
 from typing_extensions import NotRequired, TypeAlias

 from vllm.jsontree import JSONTree, json_map_leaves
-from vllm.utils import full_groupby, is_list_of
+from vllm.utils import LazyLoader, full_groupby, is_list_of

 if TYPE_CHECKING:
+    import torch
+    import torch.types
+    from PIL.Image import Image
+    from transformers.feature_extraction_utils import BatchFeature
+
     from .hasher import MultiModalHashDict
+else:
+    torch = LazyLoader("torch", globals(), "torch")

 _T = TypeVar("_T")

-HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
+HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
 """
 A {class}`transformers.image_utils.ImageInput` representing a single image
 item, which can be passed to a HuggingFace `ImageProcessor`.
 """

-HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor,
-                               list[np.ndarray], list[torch.Tensor]]
+HfVideoItem: TypeAlias = Union[list["Image"], np.ndarray, "torch.Tensor",
+                               list[np.ndarray], list["torch.Tensor"]]
 """
 A {class}`transformers.image_utils.VideoInput` representing a single video
 item, which can be passed to a HuggingFace `VideoProcessor`.
 """

-HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor]
+HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
 """
 Represents a single audio
 item, which can be passed to a HuggingFace `AudioProcessor`.
 """

-ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor]
+ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
 """
 A {class}`transformers.image_utils.ImageInput` representing a single image
 item, which can be passed to a HuggingFace `ImageProcessor`.
@@ -53,7 +56,7 @@ which are treated as image embeddings;
 these are directly passed to the model without HF processing.
 """

-VideoItem: TypeAlias = Union[HfVideoItem, torch.Tensor]
+VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor"]
 """
 A {class}`transformers.image_utils.VideoInput` representing a single video
 item, which can be passed to a HuggingFace `VideoProcessor`.
@@ -64,7 +67,7 @@ these are directly passed to the model without HF processing.
 """

 AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
-                             torch.Tensor]
+                             "torch.Tensor"]
 """
 Represents a single audio
 item, which can be passed to a HuggingFace `AudioProcessor`.
@@ -132,7 +135,7 @@ class PlaceholderRange:
     length: int
     """The length of the placeholder."""

-    is_embed: Optional[torch.Tensor] = None
+    is_embed: Optional["torch.Tensor"] = None
     """
     A boolean mask of shape `(length,)` indicating which positions
     between `offset` and `offset + length` to assign embeddings to.
@@ -158,8 +161,8 @@ class PlaceholderRange:
         return nested_tensors_equal(self.is_embed, other.is_embed)


-NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor,
-                      tuple[torch.Tensor, ...]]
+NestedTensors: TypeAlias = Union[list["NestedTensors"], list["torch.Tensor"],
+                                 "torch.Tensor", tuple["torch.Tensor", ...]]
 """
 Uses a list instead of a tensor if the dimensions of each element do not match.
 """
@@ -261,7 +264,7 @@ class BaseMultiModalField(ABC):
         """
         Construct {class}`MultiModalFieldElem` instances to represent
         the provided data.

         This is the inverse of {meth}`reduce_data`.
         """
         raise NotImplementedError
@@ -422,7 +425,7 @@ class MultiModalFieldConfig:
             modality: The modality of the multi-modal item that uses this
                 keyword argument.
             slices: For each multi-modal item, a slice (dim=0) or a tuple of
                 slices (dim>0) that is used to extract the data corresponding
                 to it.
             dim: The dimension to extract data, default to 0.
@@ -465,7 +468,7 @@ class MultiModalFieldConfig:
     @staticmethod
     def flat_from_sizes(modality: str,
-                        size_per_item: torch.Tensor,
+                        size_per_item: "torch.Tensor",
                         dim: int = 0):
         """
         Defines a field where an element in the batch is obtained by
@@ -602,7 +605,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
     @staticmethod
     def from_hf_inputs(
-        hf_inputs: BatchFeature,
+        hf_inputs: "BatchFeature",
         config_by_key: Mapping[str, MultiModalFieldConfig],
     ):
         # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
@@ -792,7 +795,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
         return self._items_by_modality[modality]


-MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
+MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
 """
 A dictionary containing placeholder ranges for each modality.
 """
@@ -823,7 +826,7 @@ class MultiModalInputs(TypedDict):
     mm_hashes: Optional["MultiModalHashDict"]
     """The hashes of the multi-modal data."""

-    mm_placeholders: MultiModalPlaceholderDict
+    mm_placeholders: "MultiModalPlaceholderDict"
     """
     For each modality, information about the placeholder tokens in
     `prompt_token_ids`.
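
The pattern in the hunks above is: heavy modules are imported only under `TYPE_CHECKING`, annotations that mention them are quoted so they stay plain strings at runtime, and runtime access goes through `LazyLoader` from `vllm.utils`, which imports the real module on first attribute access. The following is a minimal sketch of how such a proxy can be written; it is a simplified stand-in for illustration, not vllm's actual `LazyLoader`.

import importlib
import types
from typing import TYPE_CHECKING, Any


class _LazyModule(types.ModuleType):
    """Imports `module_name` only when an attribute is first accessed."""

    def __init__(self, local_name: str, parent_globals: dict, module_name: str):
        super().__init__(module_name)
        self._local_name = local_name
        self._parent_globals = parent_globals
        self._module_name = module_name

    def _load(self) -> types.ModuleType:
        module = importlib.import_module(self._module_name)
        # Swap the proxy out of the caller's namespace so later lookups
        # hit the real module directly.
        self._parent_globals[self._local_name] = module
        self.__dict__.update(module.__dict__)
        return module

    def __getattr__(self, item: str) -> Any:
        return getattr(self._load(), item)


if TYPE_CHECKING:
    import torch  # type checkers see the real module
else:
    torch = _LazyModule("torch", globals(), "torch")  # runtime gets the proxy

# Importing this module is now cheap; the first `torch.<attr>` access pays
# the torch import cost instead.

Quoting `torch.Tensor` in the aliases matters too: an unquoted annotation would touch `torch.Tensor` while the module is being imported, which would trigger the proxy and pull torch in anyway.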

vllm/multimodal/parse.py

@@ -8,11 +8,9 @@ from typing import (TYPE_CHECKING, Any, Generic, Literal, NamedTuple, Optional,
 import numpy as np
 import torch
-from PIL.Image import Image
-from transformers import BatchFeature
 from typing_extensions import TypeAlias, TypeGuard, assert_never

-from vllm.utils import is_list_of
+from vllm.utils import LazyLoader, is_list_of

 from .audio import AudioResampler
 from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
@@ -22,6 +20,11 @@ from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
 _T = TypeVar("_T")
 _I = TypeVar("_I")

+if TYPE_CHECKING:
+    import PIL.Image as PILImage
+else:
+    PILImage = LazyLoader("PILImage", globals(), "PIL.Image")
+

 class ModalityDataItems(ABC, Generic[_T, _I]):
     """
@@ -131,6 +134,8 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
             Mapping[str, MultiModalFieldConfig],
         ],
     ) -> None:
+        from transformers.feature_extraction_utils import BatchFeature
+
         super().__init__(data, modality)

         missing_required_data_keys = required_fields - data.keys()
@@ -200,7 +205,7 @@ class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
     def get_image_size(self, item_idx: int) -> ImageSize:
         image = self.get(item_idx)

-        if isinstance(image, Image):
+        if isinstance(image, PILImage.Image):
             return ImageSize(*image.size)
         if isinstance(image, (np.ndarray, torch.Tensor)):
             _, h, w = image.shape
@@ -226,7 +231,7 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
     def get_frame_size(self, item_idx: int) -> ImageSize:
         image = self.get(item_idx)[0]  # Assume that the video isn't empty

-        if isinstance(image, Image):
+        if isinstance(image, PILImage.Image):
             return ImageSize(*image.size)
         if isinstance(image, (np.ndarray, torch.Tensor)):
             _, h, w = image.shape
@@ -253,7 +258,7 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
     def get_count(self, modality: str, *, strict: bool = True) -> int:
         """
         Get the number of data items belonging to a modality.

         If `strict=False`, return `0` instead of raising {exc}`KeyError`
         even if the modality is not found.
         """
@@ -399,7 +404,7 @@ class MultiModalDataParser:
         if self._is_embeddings(data):
             return ImageEmbeddingItems(data)

-        if (isinstance(data, Image)
+        if (isinstance(data, PILImage.Image)
                 or isinstance(data,
                               (np.ndarray, torch.Tensor)) and data.ndim == 3):
             data_items = [data]
@@ -420,7 +425,7 @@ class MultiModalDataParser:
         if self._is_embeddings(data):
             return VideoEmbeddingItems(data)

-        if (is_list_of(data, Image)
+        if (is_list_of(data, PILImage.Image)
                 or isinstance(data,
                               (np.ndarray, torch.Tensor)) and data.ndim == 4):
             data_items = [data]

vllm/multimodal/processing.py

@@ -13,7 +13,6 @@ from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
                     TypeVar, Union, cast)

 import torch
-from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
 from typing_extensions import assert_never

 from vllm.inputs import InputProcessingContext
@@ -31,6 +30,10 @@ from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
                     MultiModalDataParser)

 if TYPE_CHECKING:
+    from transformers.configuration_utils import PretrainedConfig
+    from transformers.feature_extraction_utils import BatchFeature
+    from transformers.processing_utils import ProcessorMixin
+
     from .profiling import BaseDummyInputsBuilder

 logger = init_logger(__name__)
@@ -1047,10 +1050,10 @@ class BaseProcessingInfo:
     def get_tokenizer(self) -> AnyTokenizer:
         return self.ctx.tokenizer

-    def get_hf_config(self) -> PretrainedConfig:
+    def get_hf_config(self) -> "PretrainedConfig":
         return self.ctx.get_hf_config()

-    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
+    def get_hf_processor(self, **kwargs: object) -> "ProcessorMixin":
         """
         Subclasses can override this method to handle
         specific kwargs from model config or user inputs.
@@ -1165,7 +1168,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
     @abstractmethod
     def _get_mm_fields_config(
         self,
-        hf_inputs: BatchFeature,
+        hf_inputs: "BatchFeature",
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         """Given the HF-processed data, output the metadata of each field."""
@@ -1222,7 +1225,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         # This refers to the data to be passed to HF processor.
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
-    ) -> BatchFeature:
+    ) -> "BatchFeature":
         """
         Call the HF processor on the prompt text and
         associated multi-modal data.
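
Here the transformers imports become type-only: `PretrainedConfig`, `BatchFeature`, and `ProcessorMixin` now exist solely for the type checker, and every annotation that names them is quoted so nothing is evaluated at runtime. A minimal sketch of the pattern in isolation (the class below is a hypothetical stand-in, not vllm's `BaseProcessingInfo`):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Executed by type checkers only; never at runtime.
    from transformers.configuration_utils import PretrainedConfig


class ProcessingInfoSketch:
    """Hypothetical example of annotating with a type that is never imported eagerly."""

    def __init__(self, hf_config: "PretrainedConfig") -> None:
        self.hf_config = hf_config

    def get_hf_config(self) -> "PretrainedConfig":
        # The quoted annotation is just metadata; transformers is only
        # imported by whoever constructed the PretrainedConfig passed in.
        return self.hf_config

The trade-off is that anything resolving these annotations at runtime (for example `typing.get_type_hints`) would raise a `NameError` unless the quoted names are importable and brought into scope first.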