[Frontend] decrease import time of vllm.multimodal (#18031)
Co-authored-by: Aaron Pham <Aaronpham0103@gmail.com>
parent 856865008e
commit 749f792553
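The commit applies one recurring pattern across `vllm.multimodal`: import heavyweight dependencies (`torch`, `PIL`, `transformers`) for real only under `typing.TYPE_CHECKING`, bind a `vllm.utils.LazyLoader` proxy at runtime where module-level access is still needed, and quote the type annotations so they are never evaluated at import time. A minimal sketch of that pattern, in miniature (the `scale` function is illustrative, not from this diff):

from typing import TYPE_CHECKING

from vllm.utils import LazyLoader

if TYPE_CHECKING:
    import torch  # real import, seen only by type checkers
else:
    # At runtime, a proxy that imports torch on first attribute access.
    torch = LazyLoader("torch", globals(), "torch")


def scale(x: "torch.Tensor") -> "torch.Tensor":  # quoted: not evaluated at import
    # torch is actually imported here, the first time scale() runs.
    return x * 2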
vllm/multimodal/inputs.py

@@ -10,40 +10,43 @@ from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
                     Union, cast, final)
 
 import numpy as np
-import torch
-import torch.types
-from PIL.Image import Image
-from transformers import BatchFeature
 from typing_extensions import NotRequired, TypeAlias
 
 from vllm.jsontree import JSONTree, json_map_leaves
-from vllm.utils import full_groupby, is_list_of
+from vllm.utils import LazyLoader, full_groupby, is_list_of
 
 if TYPE_CHECKING:
+    import torch
+    import torch.types
+    from PIL.Image import Image
+    from transformers.feature_extraction_utils import BatchFeature
+
     from .hasher import MultiModalHashDict
+else:
+    torch = LazyLoader("torch", globals(), "torch")
 
 _T = TypeVar("_T")
 
-HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
+HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
 """
 A {class}`transformers.image_utils.ImageInput` representing a single image
 item, which can be passed to a HuggingFace `ImageProcessor`.
 """
 
-HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor,
-                               list[np.ndarray], list[torch.Tensor]]
+HfVideoItem: TypeAlias = Union[list["Image"], np.ndarray, "torch.Tensor",
+                               list[np.ndarray], list["torch.Tensor"]]
 """
 A {class}`transformers.image_utils.VideoInput` representing a single video
 item, which can be passed to a HuggingFace `VideoProcessor`.
 """
 
-HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor]
+HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
 """
 Represents a single audio
 item, which can be passed to a HuggingFace `AudioProcessor`.
 """
 
-ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor]
+ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
 """
 A {class}`transformers.image_utils.ImageInput` representing a single image
 item, which can be passed to a HuggingFace `ImageProcessor`.
@@ -53,7 +56,7 @@ which are treated as image embeddings;
 these are directly passed to the model without HF processing.
 """
 
-VideoItem: TypeAlias = Union[HfVideoItem, torch.Tensor]
+VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor"]
 """
 A {class}`transformers.image_utils.VideoInput` representing a single video
 item, which can be passed to a HuggingFace `VideoProcessor`.
@@ -64,7 +67,7 @@ these are directly passed to the model without HF processing.
 """
 
 AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
-                             torch.Tensor]
+                             "torch.Tensor"]
 """
 Represents a single audio
 item, which can be passed to a HuggingFace `AudioProcessor`.
@@ -132,7 +135,7 @@ class PlaceholderRange:
     length: int
     """The length of the placeholder."""
 
-    is_embed: Optional[torch.Tensor] = None
+    is_embed: Optional["torch.Tensor"] = None
     """
     A boolean mask of shape `(length,)` indicating which positions
     between `offset` and `offset + length` to assign embeddings to.
@@ -158,8 +161,8 @@ class PlaceholderRange:
         return nested_tensors_equal(self.is_embed, other.is_embed)
 
 
-NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor,
-                      tuple[torch.Tensor, ...]]
+NestedTensors: TypeAlias = Union[list["NestedTensors"], list["torch.Tensor"],
+                                 "torch.Tensor", tuple["torch.Tensor", ...]]
 """
 Uses a list instead of a tensor if the dimensions of each element do not match.
 """
@@ -465,7 +468,7 @@ class MultiModalFieldConfig:
 
     @staticmethod
     def flat_from_sizes(modality: str,
-                        size_per_item: torch.Tensor,
+                        size_per_item: "torch.Tensor",
                         dim: int = 0):
         """
         Defines a field where an element in the batch is obtained by
@@ -602,7 +605,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
 
     @staticmethod
     def from_hf_inputs(
-        hf_inputs: BatchFeature,
+        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
     ):
         # NOTE: This skips fields in hf_inputs that are not in config_by_key
@@ -792,7 +795,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
         return self._items_by_modality[modality]
 
 
-MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
+MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
 """
 A dictionary containing placeholder ranges for each modality.
 """
@@ -823,7 +826,7 @@ class MultiModalInputs(TypedDict):
     mm_hashes: Optional["MultiModalHashDict"]
     """The hashes of the multi-modal data."""
 
-    mm_placeholders: MultiModalPlaceholderDict
+    mm_placeholders: "MultiModalPlaceholderDict"
     """
     For each modality, information about the placeholder tokens in
     `prompt_token_ids`.
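The quoted annotations above work because string annotations are never evaluated at runtime; only type checkers resolve them. The `LazyLoader` proxy itself is not part of this diff; what follows is a minimal sketch of the idea behind `vllm.utils.LazyLoader` (the class name `LazyModule` is hypothetical, and the real implementation may differ in details):

import importlib
import types


class LazyModule(types.ModuleType):
    """Proxy that defers the real import until first attribute access."""

    def __init__(self, local_name: str, parent_globals: dict, name: str):
        super().__init__(name)
        self._local_name = local_name
        self._parent_globals = parent_globals

    def _load(self) -> types.ModuleType:
        module = importlib.import_module(self.__name__)
        # Replace the proxy so later lookups hit the real module directly.
        self._parent_globals[self._local_name] = module
        self.__dict__.update(module.__dict__)
        return module

    def __getattr__(self, item: str):
        # Called only for attributes the proxy doesn't have yet.
        return getattr(self._load(), item)


np = LazyModule("np", globals(), "numpy")  # numpy is not imported yet
print(np.arange(3))                        # first attribute access imports it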
vllm/multimodal/parse.py

@@ -8,11 +8,9 @@ from typing import (TYPE_CHECKING, Any, Generic, Literal, NamedTuple, Optional,
 
 import numpy as np
 import torch
-from PIL.Image import Image
-from transformers import BatchFeature
 from typing_extensions import TypeAlias, TypeGuard, assert_never
 
-from vllm.utils import is_list_of
+from vllm.utils import LazyLoader, is_list_of
 
 from .audio import AudioResampler
 from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
@@ -22,6 +20,11 @@ from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
 _T = TypeVar("_T")
 _I = TypeVar("_I")
 
+if TYPE_CHECKING:
+    import PIL.Image as PILImage
+else:
+    PILImage = LazyLoader("PILImage", globals(), "PIL.Image")
+
 
 class ModalityDataItems(ABC, Generic[_T, _I]):
     """
@@ -131,6 +134,8 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
             Mapping[str, MultiModalFieldConfig],
         ],
     ) -> None:
+        from transformers.feature_extraction_utils import BatchFeature
+
         super().__init__(data, modality)
 
         missing_required_data_keys = required_fields - data.keys()
@@ -200,7 +205,7 @@ class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
     def get_image_size(self, item_idx: int) -> ImageSize:
         image = self.get(item_idx)
 
-        if isinstance(image, Image):
+        if isinstance(image, PILImage.Image):
             return ImageSize(*image.size)
         if isinstance(image, (np.ndarray, torch.Tensor)):
             _, h, w = image.shape
@@ -226,7 +231,7 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
     def get_frame_size(self, item_idx: int) -> ImageSize:
         image = self.get(item_idx)[0]  # Assume that the video isn't empty
 
-        if isinstance(image, Image):
+        if isinstance(image, PILImage.Image):
             return ImageSize(*image.size)
         if isinstance(image, (np.ndarray, torch.Tensor)):
             _, h, w = image.shape
@@ -399,7 +404,7 @@ class MultiModalDataParser:
         if self._is_embeddings(data):
             return ImageEmbeddingItems(data)
 
-        if (isinstance(data, Image)
+        if (isinstance(data, PILImage.Image)
                 or isinstance(data,
                               (np.ndarray, torch.Tensor)) and data.ndim == 3):
             data_items = [data]
@@ -420,7 +425,7 @@ class MultiModalDataParser:
         if self._is_embeddings(data):
             return VideoEmbeddingItems(data)
 
-        if (is_list_of(data, Image)
+        if (is_list_of(data, PILImage.Image)
                 or isinstance(data,
                               (np.ndarray, torch.Tensor)) and data.ndim == 4):
             data_items = [data]
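One subtlety in parse.py: `isinstance` needs a real class object, so `PILImage.Image` cannot be hidden behind a string annotation the way the return types are. Routing the check through the lazy `PILImage` module means PIL is imported only when an `isinstance` check first touches the attribute. A condensed restatement of the pattern (the `width_height` helper is hypothetical, not from the diff):

from typing import TYPE_CHECKING

from vllm.utils import LazyLoader

if TYPE_CHECKING:
    import PIL.Image as PILImage
else:
    PILImage = LazyLoader("PILImage", globals(), "PIL.Image")


def width_height(image) -> tuple[int, int]:
    # Attribute access on the proxy triggers the real PIL import here.
    if isinstance(image, PILImage.Image):
        return image.size  # PIL reports (width, height)
    _, h, w = image.shape  # otherwise assume a CHW array/tensor
    return (w, h)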
vllm/multimodal/processing.py

@@ -13,7 +13,6 @@ from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
                     TypeVar, Union, cast)
 
 import torch
-from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
 from typing_extensions import assert_never
 
 from vllm.inputs import InputProcessingContext
@@ -31,6 +30,10 @@ from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
                     MultiModalDataParser)
 
 if TYPE_CHECKING:
+    from transformers.configuration_utils import PretrainedConfig
+    from transformers.feature_extraction_utils import BatchFeature
+    from transformers.processing_utils import ProcessorMixin
+
     from .profiling import BaseDummyInputsBuilder
 
 logger = init_logger(__name__)
@@ -1047,10 +1050,10 @@ class BaseProcessingInfo:
     def get_tokenizer(self) -> AnyTokenizer:
         return self.ctx.tokenizer
 
-    def get_hf_config(self) -> PretrainedConfig:
+    def get_hf_config(self) -> "PretrainedConfig":
         return self.ctx.get_hf_config()
 
-    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
+    def get_hf_processor(self, **kwargs: object) -> "ProcessorMixin":
         """
         Subclasses can override this method to handle
         specific kwargs from model config or user inputs.
@@ -1165,7 +1168,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
     @abstractmethod
     def _get_mm_fields_config(
         self,
-        hf_inputs: BatchFeature,
+        hf_inputs: "BatchFeature",
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         """Given the HF-processed data, output the metadata of each field."""
@@ -1222,7 +1225,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         # This refers to the data to be passed to HF processor.
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
-    ) -> BatchFeature:
+    ) -> "BatchFeature":
         """
         Call the HF processor on the prompt text and
         associated multi-modal data.
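To see what a change like this buys, the import can be timed directly; CPython's `-X importtime` flag (`python -X importtime -c "import vllm.multimodal"`) gives a per-module breakdown. A rough stopwatch sketch (numbers vary by machine; no figures are claimed here):

import time

t0 = time.perf_counter()
import vllm.multimodal  # noqa: E402  (deliberate mid-script import for timing)

print(f"import vllm.multimodal took {time.perf_counter() - t0:.3f}s")
# With the lazy pattern, the torch/PIL/transformers costs move to first use
# instead of being paid during this import.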