[Misc] Add placeholder module (#11501)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent f57ee5650d
commit eec906d811
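The pattern this commit rolls out across the tree is, roughly, the following sketch. It is not taken verbatim from any one file: the `librosa` / `vllm[audio]` pairing comes from the hunks below, while the file name passed to `librosa.load` is purely illustrative.

    from vllm.utils import PlaceholderModule

    try:
        import librosa  # optional dependency, shipped via the vllm[audio] extra
    except ImportError:
        librosa = PlaceholderModule("librosa")  # type: ignore[assignment]

    # Importing the module above never fails. Only an attribute access on the
    # placeholder fails, and it fails with a hint such as
    # "Please install vllm[audio] for audio support".
    y, sr = librosa.load("example.ogg", sr=None)  # "example.ogg" is illustrative

This module-level fallback replaces the per-call-site try_import_*_packages helpers that the hunks below remove.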
@@ -9,7 +9,6 @@ import openai
 import pytest
 import torch
 from huggingface_hub import snapshot_download
-from tensorizer import EncryptionParams
 
 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
@@ -23,12 +22,18 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                          serialize_vllm_model,
                                                          tensorize_vllm_model)
 # yapf: enable
-from vllm.utils import import_from_path
+from vllm.utils import PlaceholderModule, import_from_path
 
 from ..conftest import VllmRunner
 from ..utils import VLLM_PATH, RemoteOpenAIServer
 from .conftest import retry_until_skip
 
+try:
+    from tensorizer import EncryptionParams
+except ImportError:
+    tensorizer = PlaceholderModule("tensorizer")  # type: ignore[assignment]
+    EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
+
 EXAMPLES_PATH = VLLM_PATH / "examples"
 
 prompts = [
@@ -1,11 +1,17 @@
 from dataclasses import dataclass
-from typing import Literal, Tuple
+from typing import Literal
 from urllib.parse import urljoin
 
-import librosa
 import numpy as np
 import numpy.typing as npt
 
-from vllm.assets.base import get_vllm_public_assets, vLLM_S3_BUCKET_URL
+from vllm.utils import PlaceholderModule
 
+from .base import VLLM_S3_BUCKET_URL, get_vllm_public_assets
+
+try:
+    import librosa
+except ImportError:
+    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]
+
 ASSET_DIR = "multimodal_asset"
@@ -15,8 +21,7 @@ class AudioAsset:
     name: Literal["winning_call", "mary_had_lamb"]
 
     @property
-    def audio_and_sample_rate(self) -> Tuple[np.ndarray, int]:
-
+    def audio_and_sample_rate(self) -> tuple[npt.NDArray, int]:
         audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
                                             s3_prefix=ASSET_DIR)
         y, sr = librosa.load(audio_path, sr=None)
@@ -25,4 +30,4 @@ class AudioAsset:
 
     @property
     def url(self) -> str:
-        return urljoin(vLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")
+        return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")
@@ -4,9 +4,8 @@ from typing import Optional
 
+import vllm.envs as envs
 from vllm.connections import global_http_connection
-from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
 
-vLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"
+VLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"
 
 
 def get_cache_dir() -> Path:
@@ -32,8 +31,8 @@ def get_vllm_public_assets(filename: str,
         if s3_prefix is not None:
             filename = s3_prefix + "/" + filename
         global_http_connection.download_file(
-            f"{vLLM_S3_BUCKET_URL}/{filename}",
+            f"{VLLM_S3_BUCKET_URL}/{filename}",
             asset_path,
-            timeout=VLLM_IMAGE_FETCH_TIMEOUT)
+            timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT)
 
     return asset_path
@@ -4,7 +4,7 @@ from typing import Literal
 import torch
 from PIL import Image
 
-from vllm.assets.base import get_vllm_public_assets
+from .base import get_vllm_public_assets
 
 VLM_IMAGES_DIR = "vision_model_images"
 
@@ -15,7 +15,6 @@ class ImageAsset:
 
     @property
     def pil_image(self) -> Image.Image:
-
         image_path = get_vllm_public_assets(filename=f"{self.name}.jpg",
                                             s3_prefix=VLM_IMAGES_DIR)
         return Image.open(image_path)
@@ -2,13 +2,13 @@ from dataclasses import dataclass
 from functools import lru_cache
 from typing import List, Literal
 
+import cv2
 import numpy as np
 import numpy.typing as npt
 from huggingface_hub import hf_hub_download
 from PIL import Image
 
-from vllm.multimodal.video import (sample_frames_from_video,
-                                   try_import_video_packages)
+from vllm.multimodal.video import sample_frames_from_video
 
 from .base import get_cache_dir
 
@@ -19,7 +19,7 @@ def download_video_asset(filename: str) -> str:
     Download and open an image from huggingface
     repo: raushan-testing-hf/videos-test
     """
-    video_directory = get_cache_dir() / "video-eample-data"
+    video_directory = get_cache_dir() / "video-example-data"
     video_directory.mkdir(parents=True, exist_ok=True)
 
     video_path = video_directory / filename
@@ -35,8 +35,6 @@ def download_video_asset(filename: str) -> str:
 
 
 def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
-    cv2, _ = try_import_video_packages()
-
     cap = cv2.VideoCapture(path)
     if not cap.isOpened():
         raise ValueError(f"Could not open video file {path}")
@@ -59,7 +57,6 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
 
 def video_to_pil_images_list(path: str,
                              num_frames: int = -1) -> List[Image.Image]:
-    cv2, _ = try_import_video_packages()
     frames = video_to_ndarrays(path, num_frames)
     return [
         Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
@@ -29,6 +29,7 @@ from vllm.transformers_utils.config import (
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder,
     try_get_generation_config, uses_mrope)
+from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3
 from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
                         get_cpu_memory, print_warning_once, random_uuid,
@@ -372,15 +373,6 @@ class ModelConfig:
 
         """
         if is_s3(model) or is_s3(tokenizer):
-            try:
-                from vllm.transformers_utils.s3_utils import S3Model
-            except ImportError as err:
-                raise ImportError(
-                    "Please install Run:ai optional dependency "
-                    "to use the S3 capabilities. "
-                    "You can install it with: pip install vllm[runai]"
-                ) from err
-
             if is_s3(model):
                 self.s3_model = S3Model()
                 self.s3_model.pull_files(model, allow_pattern=["*config.json"])
@@ -48,6 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     runai_safetensors_weights_iterator, safetensors_weights_iterator)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.transformers_utils.s3_utils import glob as s3_glob
 from vllm.transformers_utils.utils import is_s3
 from vllm.utils import is_pin_memory_available
 
@@ -1269,16 +1270,6 @@ class RunaiModelStreamerLoader(BaseModelLoader):
 
         If the model is not local, it will be downloaded."""
         is_s3_path = is_s3(model_name_or_path)
-        if is_s3_path:
-            try:
-                from vllm.transformers_utils.s3_utils import glob as s3_glob
-            except ImportError as err:
-                raise ImportError(
-                    "Please install Run:ai optional dependency "
-                    "to use the S3 capabilities. "
-                    "You can install it with: pip install vllm[runai]"
-                ) from err
-
         is_local = os.path.isdir(model_name_or_path)
         safetensors_pattern = "*.safetensors"
         index_file = SAFE_WEIGHTS_INDEX_NAME
@@ -19,9 +19,7 @@ from vllm.engine.llm_engine import LLMEngine
 from vllm.logger import init_logger
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.utils import FlexibleArgumentParser
-
-tensorizer_error_msg = None
+from vllm.utils import FlexibleArgumentParser, PlaceholderModule
 
 try:
     from tensorizer import (DecryptionParams, EncryptionParams,
@@ -34,8 +32,19 @@ try:
         open_stream,
         mode=mode,
     ) for mode in ("rb", "wb+"))
-except ImportError as e:
-    tensorizer_error_msg = str(e)
+except ImportError:
+    tensorizer = PlaceholderModule("tensorizer")
+    DecryptionParams = tensorizer.placeholder_attr("DecryptionParams")
+    EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
+    TensorDeserializer = tensorizer.placeholder_attr("TensorDeserializer")
+    TensorSerializer = tensorizer.placeholder_attr("TensorSerializer")
+    open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
+    convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
+    get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
+    no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
+
+    _read_stream = tensorizer.placeholder_attr("_read_stream")
+    _write_stream = tensorizer.placeholder_attr("_write_stream")
 
 __all__ = [
     'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
@@ -267,12 +276,6 @@ class TensorizerAgent:
     """
 
     def __init__(self, tensorizer_config: TensorizerConfig, vllm_config):
-        if tensorizer_error_msg is not None:
-            raise ImportError(
-                "Tensorizer is not installed. Please install tensorizer "
-                "to use this feature with `pip install vllm[tensorizer]`. "
-                "Error message: {}".format(tensorizer_error_msg))
-
         self.tensorizer_config = tensorizer_config
         self.tensorizer_args = (
             self.tensorizer_config._construct_tensorizer_args())
@@ -25,7 +25,15 @@ from vllm.model_executor.layers.quantization import (QuantizationConfig,
                                                       get_quantization_config)
 from vllm.model_executor.layers.quantization.schema import QuantParamSchema
 from vllm.platforms import current_platform
-from vllm.utils import print_warning_once
+from vllm.utils import PlaceholderModule, print_warning_once
 
+try:
+    from runai_model_streamer import SafetensorsStreamer
+except ImportError:
+    runai_model_streamer = PlaceholderModule(
+        "runai_model_streamer")  # type: ignore[assignment]
+    SafetensorsStreamer = runai_model_streamer.placeholder_attr(
+        "SafetensorsStreamer")
+
 logger = init_logger(__name__)
 
@@ -414,13 +422,6 @@ def runai_safetensors_weights_iterator(
         hf_weights_files: List[str]
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    try:
-        from runai_model_streamer import SafetensorsStreamer
-    except ImportError as err:
-        raise ImportError(
-            "Please install Run:ai optional dependency."
-            "You can install it with: pip install vllm[runai]") from err
-
     enable_tqdm = not torch.distributed.is_initialized(
     ) or torch.distributed.get_rank() == 0
     with SafetensorsStreamer() as streamer:
@@ -1,13 +1,17 @@
-from typing import Any
-
 import numpy as np
 import numpy.typing as npt
 
 from vllm.inputs.registry import InputContext
+from vllm.utils import PlaceholderModule
 
 from .base import MultiModalPlugin
 from .inputs import AudioItem, MultiModalData, MultiModalKwargs
 
+try:
+    import librosa
+except ImportError:
+    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]
+
 
 class AudioPlugin(MultiModalPlugin):
     """Plugin for audio data."""
@@ -28,26 +32,10 @@ class AudioPlugin(MultiModalPlugin):
             "There is no default maximum multimodal tokens")
 
 
-def try_import_audio_packages() -> tuple[Any, Any]:
-    try:
-        import librosa
-        import soundfile
-    except ImportError as exc:
-        raise ImportError(
-            "Please install vllm[audio] for audio support.") from exc
-    return librosa, soundfile
-
-
 def resample_audio(
     audio: npt.NDArray[np.floating],
     *,
     orig_sr: float,
     target_sr: float,
 ) -> npt.NDArray[np.floating]:
-    try:
-        import librosa
-    except ImportError as exc:
-        msg = "Please install vllm[audio] for audio support."
-        raise ImportError(msg) from exc
-
     return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
@@ -13,10 +13,24 @@ import vllm.envs as envs
 from vllm.connections import global_http_connection
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
+from vllm.utils import PlaceholderModule
 
-from .audio import try_import_audio_packages
 from .inputs import MultiModalDataDict, PlaceholderRange
-from .video import try_import_video_packages
 
+try:
+    import decord
+except ImportError:
+    decord = PlaceholderModule("decord")  # type: ignore[assignment]
+
+try:
+    import librosa
+except ImportError:
+    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]
+
+try:
+    import soundfile
+except ImportError:
+    soundfile = PlaceholderModule("soundfile")  # type: ignore[assignment]
+
 logger = init_logger(__name__)
 
@@ -128,8 +142,6 @@ async def async_fetch_image(image_url: str,
 
 
 def _load_video_from_bytes(b: bytes, num_frames: int = 32) -> npt.NDArray:
-    _, decord = try_import_video_packages()
-
     video_path = BytesIO(b)
     vr = decord.VideoReader(video_path, num_threads=1)
     total_frame_num = len(vr)
@@ -204,8 +216,6 @@ def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
     """
     Load audio from a URL.
     """
-    librosa, _ = try_import_audio_packages()
-
     if audio_url.startswith("http"):
         audio_bytes = global_http_connection.get_bytes(
             audio_url,
@@ -226,8 +236,6 @@ async def async_fetch_audio(
     """
     Asynchronously fetch audio from a URL.
     """
-    librosa, _ = try_import_audio_packages()
-
     if audio_url.startswith("http"):
         audio_bytes = await global_http_connection.async_get_bytes(
             audio_url,
@@ -286,8 +294,6 @@ def encode_audio_base64(
     sampling_rate: int,
 ) -> str:
     """Encode audio as base64."""
-    _, soundfile = try_import_audio_packages()
-
     buffered = BytesIO()
     soundfile.write(buffered, audio, sampling_rate, format="WAV")
 
@@ -1,6 +1,7 @@
 from functools import lru_cache
 from typing import TYPE_CHECKING, Any, Dict, Optional
 
+import cv2
 import numpy as np
 import numpy.typing as npt
 
@@ -78,19 +79,7 @@ class VideoPlugin(ImagePlugin):
         return 4096
 
 
-def try_import_video_packages() -> tuple[Any, Any]:
-    try:
-        import cv2
-        import decord
-    except ImportError as exc:
-        raise ImportError(
-            "Please install vllm[video] for video support.") from exc
-    return cv2, decord
-
-
 def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
-    cv2, _ = try_import_video_packages()
-
     num_frames, _, _, channels = frames.shape
     new_height, new_width = size
     resized_frames = np.empty((num_frames, new_height, new_width, channels),
@@ -6,7 +6,12 @@ import tempfile
 from pathlib import Path
 from typing import Optional
 
-import boto3
+from vllm.utils import PlaceholderModule
 
+try:
+    import boto3
+except ImportError:
+    boto3 = PlaceholderModule("boto3")  # type: ignore[assignment]
+
 
 def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
@@ -6,10 +6,12 @@ import datetime
 import enum
 import gc
 import getpass
+import importlib.metadata
 import importlib.util
 import inspect
 import ipaddress
 import os
+import re
 import signal
 import socket
 import subprocess
@@ -1550,6 +1552,67 @@ def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
     return module
 
 
+@lru_cache(maxsize=None)
+def get_vllm_optional_dependencies():
+    metadata = importlib.metadata.metadata("vllm")
+    requirements = metadata.get_all("Requires-Dist", [])
+    extras = metadata.get_all("Provides-Extra", [])
+
+    return {
+        extra: [
+            re.split(r";|>=|<=|==", req)[0] for req in requirements
+            if req.endswith(f'extra == "{extra}"')
+        ]
+        for extra in extras
+    }
+
+
+@dataclass(frozen=True)
+class PlaceholderModule:
+    """
+    A placeholder object to use when a module does not exist.
+
+    This enables more informative errors when trying to access attributes
+    of a module that does not exist.
+    """
+    name: str
+
+    def placeholder_attr(self, attr_path: str):
+        return _PlaceholderModuleAttr(self, attr_path)
+
+    def __getattr__(self, key: str):
+        name = self.name
+
+        try:
+            importlib.import_module(self.name)
+        except ImportError as exc:
+            for extra, names in get_vllm_optional_dependencies().items():
+                if name in names:
+                    msg = f"Please install vllm[{extra}] for {extra} support"
+                    raise ImportError(msg) from exc
+
+            raise exc
+
+        raise AssertionError("PlaceholderModule should not be used "
+                             "when the original module can be imported")
+
+
+@dataclass(frozen=True)
+class _PlaceholderModuleAttr:
+    module: PlaceholderModule
+    attr_path: str
+
+    def placeholder_attr(self, attr_path: str):
+        return _PlaceholderModuleAttr(self.module,
+                                      f"{self.attr_path}.{attr_path}")
+
+    def __getattr__(self, key: str):
+        getattr(self.module, f"{self.attr_path}.{key}")
+
+        raise AssertionError("PlaceholderModule should not be used "
+                             "when the original module can be imported")
+
+
 # create a library to hold the custom op
 vllm_lib = Library("vllm", "FRAGMENT")  # noqa
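For reference, here is a minimal sketch (not part of the diff) of how the two helpers added above interact. The extras mapping shown in the comments is illustrative; the real contents come from the Requires-Dist / Provides-Extra metadata of the installed vllm package.

    from vllm.utils import PlaceholderModule, get_vllm_optional_dependencies

    # get_vllm_optional_dependencies() returns {extra: [package, ...]},
    # e.g. roughly {"audio": ["librosa", "soundfile"], "tensorizer": ["tensorizer"], ...}
    # depending on the installed wheel.
    extras = get_vllm_optional_dependencies()

    try:
        import tensorizer
    except ImportError:
        tensorizer = PlaceholderModule("tensorizer")  # type: ignore[assignment]
        # placeholder_attr records a dotted path so call sites can still bind
        # names eagerly, as the tensorizer hunk above does for EncryptionParams.
        EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")

    # Nothing above raises. Only a later attribute access on the placeholder
    # triggers __getattr__, which retries the import and, when the package is
    # listed under an extra, raises an ImportError such as
    # "Please install vllm[tensorizer] for tensorizer support".
    _ = tensorizer.EncryptionParams

Compared with the removed try_import_*_packages helpers and ad-hoc error strings, the error text is now derived from the package metadata in one place instead of being repeated at every call site.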