[Misc] Add placeholder module (#11501)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2024-12-26 21:12:51 +08:00 committed by GitHub
parent f57ee5650d
commit eec906d811
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 143 additions and 100 deletions

View File

@ -9,7 +9,6 @@ import openai
import pytest
import torch
from huggingface_hub import snapshot_download
from tensorizer import EncryptionParams
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
@ -23,12 +22,18 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
serialize_vllm_model,
tensorize_vllm_model)
# yapf: enable
from vllm.utils import import_from_path
from vllm.utils import PlaceholderModule, import_from_path
from ..conftest import VllmRunner
from ..utils import VLLM_PATH, RemoteOpenAIServer
from .conftest import retry_until_skip
try:
from tensorizer import EncryptionParams
except ImportError:
tensorizer = PlaceholderModule("tensorizer") # type: ignore[assignment]
EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
EXAMPLES_PATH = VLLM_PATH / "examples"
prompts = [

View File

@ -1,11 +1,17 @@
from dataclasses import dataclass
from typing import Literal, Tuple
from typing import Literal
from urllib.parse import urljoin
import librosa
import numpy as np
import numpy.typing as npt
from vllm.assets.base import get_vllm_public_assets, vLLM_S3_BUCKET_URL
from vllm.utils import PlaceholderModule
from .base import VLLM_S3_BUCKET_URL, get_vllm_public_assets
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
ASSET_DIR = "multimodal_asset"
@ -15,8 +21,7 @@ class AudioAsset:
name: Literal["winning_call", "mary_had_lamb"]
@property
def audio_and_sample_rate(self) -> Tuple[np.ndarray, int]:
def audio_and_sample_rate(self) -> tuple[npt.NDArray, int]:
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
s3_prefix=ASSET_DIR)
y, sr = librosa.load(audio_path, sr=None)
@ -25,4 +30,4 @@ class AudioAsset:
@property
def url(self) -> str:
return urljoin(vLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")
return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")

View File

@ -4,9 +4,8 @@ from typing import Optional
import vllm.envs as envs
from vllm.connections import global_http_connection
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
vLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"
VLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"
def get_cache_dir() -> Path:
@ -32,8 +31,8 @@ def get_vllm_public_assets(filename: str,
if s3_prefix is not None:
filename = s3_prefix + "/" + filename
global_http_connection.download_file(
f"{vLLM_S3_BUCKET_URL}/{filename}",
f"{VLLM_S3_BUCKET_URL}/{filename}",
asset_path,
timeout=VLLM_IMAGE_FETCH_TIMEOUT)
timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT)
return asset_path

View File

@ -4,7 +4,7 @@ from typing import Literal
import torch
from PIL import Image
from vllm.assets.base import get_vllm_public_assets
from .base import get_vllm_public_assets
VLM_IMAGES_DIR = "vision_model_images"
@ -15,7 +15,6 @@ class ImageAsset:
@property
def pil_image(self) -> Image.Image:
image_path = get_vllm_public_assets(filename=f"{self.name}.jpg",
s3_prefix=VLM_IMAGES_DIR)
return Image.open(image_path)

View File

@ -2,13 +2,13 @@ from dataclasses import dataclass
from functools import lru_cache
from typing import List, Literal
import cv2
import numpy as np
import numpy.typing as npt
from huggingface_hub import hf_hub_download
from PIL import Image
from vllm.multimodal.video import (sample_frames_from_video,
try_import_video_packages)
from vllm.multimodal.video import sample_frames_from_video
from .base import get_cache_dir
@ -19,7 +19,7 @@ def download_video_asset(filename: str) -> str:
Download and open an image from huggingface
repo: raushan-testing-hf/videos-test
"""
video_directory = get_cache_dir() / "video-eample-data"
video_directory = get_cache_dir() / "video-example-data"
video_directory.mkdir(parents=True, exist_ok=True)
video_path = video_directory / filename
@ -35,8 +35,6 @@ def download_video_asset(filename: str) -> str:
def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
cv2, _ = try_import_video_packages()
cap = cv2.VideoCapture(path)
if not cap.isOpened():
raise ValueError(f"Could not open video file {path}")
@ -59,7 +57,6 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
def video_to_pil_images_list(path: str,
num_frames: int = -1) -> List[Image.Image]:
cv2, _ = try_import_video_packages()
frames = video_to_ndarrays(path, num_frames)
return [
Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

View File

@ -29,6 +29,7 @@ from vllm.transformers_utils.config import (
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
try_get_generation_config, uses_mrope)
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, print_warning_once, random_uuid,
@ -372,15 +373,6 @@ class ModelConfig:
"""
if is_s3(model) or is_s3(tokenizer):
try:
from vllm.transformers_utils.s3_utils import S3Model
except ImportError as err:
raise ImportError(
"Please install Run:ai optional dependency "
"to use the S3 capabilities. "
"You can install it with: pip install vllm[runai]"
) from err
if is_s3(model):
self.s3_model = S3Model()
self.s3_model.pull_files(model, allow_pattern=["*config.json"])

View File

@ -48,6 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import (
runai_safetensors_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.transformers_utils.s3_utils import glob as s3_glob
from vllm.transformers_utils.utils import is_s3
from vllm.utils import is_pin_memory_available
@ -1269,16 +1270,6 @@ class RunaiModelStreamerLoader(BaseModelLoader):
If the model is not local, it will be downloaded."""
is_s3_path = is_s3(model_name_or_path)
if is_s3_path:
try:
from vllm.transformers_utils.s3_utils import glob as s3_glob
except ImportError as err:
raise ImportError(
"Please install Run:ai optional dependency "
"to use the S3 capabilities. "
"You can install it with: pip install vllm[runai]"
) from err
is_local = os.path.isdir(model_name_or_path)
safetensors_pattern = "*.safetensors"
index_file = SAFE_WEIGHTS_INDEX_NAME

View File

@ -19,9 +19,7 @@ from vllm.engine.llm_engine import LLMEngine
from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.utils import FlexibleArgumentParser
tensorizer_error_msg = None
from vllm.utils import FlexibleArgumentParser, PlaceholderModule
try:
from tensorizer import (DecryptionParams, EncryptionParams,
@ -34,8 +32,19 @@ try:
open_stream,
mode=mode,
) for mode in ("rb", "wb+"))
except ImportError as e:
tensorizer_error_msg = str(e)
except ImportError:
tensorizer = PlaceholderModule("tensorizer")
DecryptionParams = tensorizer.placeholder_attr("DecryptionParams")
EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
TensorDeserializer = tensorizer.placeholder_attr("TensorDeserializer")
TensorSerializer = tensorizer.placeholder_attr("TensorSerializer")
open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
_read_stream = tensorizer.placeholder_attr("_read_stream")
_write_stream = tensorizer.placeholder_attr("_write_stream")
__all__ = [
'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
@ -267,12 +276,6 @@ class TensorizerAgent:
"""
def __init__(self, tensorizer_config: TensorizerConfig, vllm_config):
if tensorizer_error_msg is not None:
raise ImportError(
"Tensorizer is not installed. Please install tensorizer "
"to use this feature with `pip install vllm[tensorizer]`. "
"Error message: {}".format(tensorizer_error_msg))
self.tensorizer_config = tensorizer_config
self.tensorizer_args = (
self.tensorizer_config._construct_tensorizer_args())

View File

@ -25,7 +25,15 @@ from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
from vllm.utils import PlaceholderModule, print_warning_once
try:
from runai_model_streamer import SafetensorsStreamer
except ImportError:
runai_model_streamer = PlaceholderModule(
"runai_model_streamer") # type: ignore[assignment]
SafetensorsStreamer = runai_model_streamer.placeholder_attr(
"SafetensorsStreamer")
logger = init_logger(__name__)
@ -414,13 +422,6 @@ def runai_safetensors_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
try:
from runai_model_streamer import SafetensorsStreamer
except ImportError as err:
raise ImportError(
"Please install Run:ai optional dependency."
"You can install it with: pip install vllm[runai]") from err
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
with SafetensorsStreamer() as streamer:

View File

@ -1,13 +1,17 @@
from typing import Any
import numpy as np
import numpy.typing as npt
from vllm.inputs.registry import InputContext
from vllm.utils import PlaceholderModule
from .base import MultiModalPlugin
from .inputs import AudioItem, MultiModalData, MultiModalKwargs
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
class AudioPlugin(MultiModalPlugin):
"""Plugin for audio data."""
@ -28,26 +32,10 @@ class AudioPlugin(MultiModalPlugin):
"There is no default maximum multimodal tokens")
def try_import_audio_packages() -> tuple[Any, Any]:
try:
import librosa
import soundfile
except ImportError as exc:
raise ImportError(
"Please install vllm[audio] for audio support.") from exc
return librosa, soundfile
def resample_audio(
audio: npt.NDArray[np.floating],
*,
orig_sr: float,
target_sr: float,
) -> npt.NDArray[np.floating]:
try:
import librosa
except ImportError as exc:
msg = "Please install vllm[audio] for audio support."
raise ImportError(msg) from exc
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)

View File

@ -13,10 +13,24 @@ import vllm.envs as envs
from vllm.connections import global_http_connection
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.utils import PlaceholderModule
from .audio import try_import_audio_packages
from .inputs import MultiModalDataDict, PlaceholderRange
from .video import try_import_video_packages
try:
import decord
except ImportError:
decord = PlaceholderModule("decord") # type: ignore[assignment]
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
try:
import soundfile
except ImportError:
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
logger = init_logger(__name__)
@ -128,8 +142,6 @@ async def async_fetch_image(image_url: str,
def _load_video_from_bytes(b: bytes, num_frames: int = 32) -> npt.NDArray:
_, decord = try_import_video_packages()
video_path = BytesIO(b)
vr = decord.VideoReader(video_path, num_threads=1)
total_frame_num = len(vr)
@ -204,8 +216,6 @@ def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
"""
Load audio from a URL.
"""
librosa, _ = try_import_audio_packages()
if audio_url.startswith("http"):
audio_bytes = global_http_connection.get_bytes(
audio_url,
@ -226,8 +236,6 @@ async def async_fetch_audio(
"""
Asynchronously fetch audio from a URL.
"""
librosa, _ = try_import_audio_packages()
if audio_url.startswith("http"):
audio_bytes = await global_http_connection.async_get_bytes(
audio_url,
@ -286,8 +294,6 @@ def encode_audio_base64(
sampling_rate: int,
) -> str:
"""Encode audio as base64."""
_, soundfile = try_import_audio_packages()
buffered = BytesIO()
soundfile.write(buffered, audio, sampling_rate, format="WAV")

View File

@ -1,6 +1,7 @@
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, Optional
import cv2
import numpy as np
import numpy.typing as npt
@ -78,19 +79,7 @@ class VideoPlugin(ImagePlugin):
return 4096
def try_import_video_packages() -> tuple[Any, Any]:
try:
import cv2
import decord
except ImportError as exc:
raise ImportError(
"Please install vllm[video] for video support.") from exc
return cv2, decord
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
cv2, _ = try_import_video_packages()
num_frames, _, _, channels = frames.shape
new_height, new_width = size
resized_frames = np.empty((num_frames, new_height, new_width, channels),

View File

@ -6,7 +6,12 @@ import tempfile
from pathlib import Path
from typing import Optional
import boto3
from vllm.utils import PlaceholderModule
try:
import boto3
except ImportError:
boto3 = PlaceholderModule("boto3") # type: ignore[assignment]
def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:

View File

@ -6,10 +6,12 @@ import datetime
import enum
import gc
import getpass
import importlib.metadata
import importlib.util
import inspect
import ipaddress
import os
import re
import signal
import socket
import subprocess
@ -1550,6 +1552,67 @@ def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
return module
@lru_cache(maxsize=None)
def get_vllm_optional_dependencies():
metadata = importlib.metadata.metadata("vllm")
requirements = metadata.get_all("Requires-Dist", [])
extras = metadata.get_all("Provides-Extra", [])
return {
extra: [
re.split(r";|>=|<=|==", req)[0] for req in requirements
if req.endswith(f'extra == "{extra}"')
]
for extra in extras
}
@dataclass(frozen=True)
class PlaceholderModule:
"""
A placeholder object to use when a module does not exist.
This enables more informative errors when trying to access attributes
of a module that does not exists.
"""
name: str
def placeholder_attr(self, attr_path: str):
return _PlaceholderModuleAttr(self, attr_path)
def __getattr__(self, key: str):
name = self.name
try:
importlib.import_module(self.name)
except ImportError as exc:
for extra, names in get_vllm_optional_dependencies().items():
if name in names:
msg = f"Please install vllm[{extra}] for {extra} support"
raise ImportError(msg) from exc
raise exc
raise AssertionError("PlaceholderModule should not be used "
"when the original module can be imported")
@dataclass(frozen=True)
class _PlaceholderModuleAttr:
module: PlaceholderModule
attr_path: str
def placeholder_attr(self, attr_path: str):
return _PlaceholderModuleAttr(self.module,
f"{self.attr_path}.{attr_path}")
def __getattr__(self, key: str):
getattr(self.module, f"{self.attr_path}.{key}")
raise AssertionError("PlaceholderModule should not be used "
"when the original module can be imported")
# create a library to hold the custom op
vllm_lib = Library("vllm", "FRAGMENT") # noqa