[Frontend] Multithreaded async multimodal load_bytes (#22710)

Signed-off-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com>
Co-authored-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com>
This commit is contained in:
milesial 2025-08-13 06:09:26 -07:00 committed by GitHub
parent b159c0a67a
commit 20d65aa755
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 6 deletions

View File

@ -63,6 +63,7 @@ if TYPE_CHECKING:
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_VIDEO_FETCH_TIMEOUT: int = 30
VLLM_AUDIO_FETCH_TIMEOUT: int = 10 VLLM_AUDIO_FETCH_TIMEOUT: int = 10
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
VLLM_VIDEO_LOADER_BACKEND: str = "opencv" VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
VLLM_MM_INPUT_CACHE_GIB: int = 4 VLLM_MM_INPUT_CACHE_GIB: int = 4
@ -555,6 +556,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_AUDIO_FETCH_TIMEOUT": "VLLM_AUDIO_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")), lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
# Max number of workers for the thread pool handling
# media bytes loading. Set to 1 to disable parallel processing.
# Default is 8
"VLLM_MEDIA_LOADING_THREAD_COUNT":
lambda: int(os.getenv("VLLM_MEDIA_LOADING_THREAD_COUNT", "8")),
# Maximum filesize in MB for a single audio file when processing # Maximum filesize in MB for a single audio file when processing
# speech-to-text requests. Files larger than this will be rejected. # speech-to-text requests. Files larger than this will be rejected.
# Default is 25 MB # Default is 25 MB

View File

@ -1,6 +1,9 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import atexit
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby from itertools import groupby
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
@ -33,6 +36,10 @@ else:
MultiModalKwargs = Any MultiModalKwargs = Any
MultiModalPlaceholderDict = Any MultiModalPlaceholderDict = Any
global_thread_pool = ThreadPoolExecutor(
max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT)
atexit.register(global_thread_pool.shutdown)
class MediaConnector: class MediaConnector:
@ -139,19 +146,26 @@ class MediaConnector:
fetch_timeout: Optional[int] = None, fetch_timeout: Optional[int] = None,
) -> _M: ) -> _M:
url_spec = urlparse(url) url_spec = urlparse(url)
loop = asyncio.get_running_loop()
if url_spec.scheme.startswith("http"): if url_spec.scheme.startswith("http"):
connection = self.connection connection = self.connection
data = await connection.async_get_bytes(url, timeout=fetch_timeout) data = await connection.async_get_bytes(url, timeout=fetch_timeout)
future = loop.run_in_executor(global_thread_pool,
return media_io.load_bytes(data) media_io.load_bytes, data)
return await future
if url_spec.scheme == "data": if url_spec.scheme == "data":
return self._load_data_url(url_spec, media_io) future = loop.run_in_executor(global_thread_pool,
self._load_data_url, url_spec,
media_io)
return await future
if url_spec.scheme == "file": if url_spec.scheme == "file":
return self._load_file_url(url_spec, media_io) future = loop.run_in_executor(global_thread_pool,
self._load_file_url, url_spec,
media_io)
return await future
msg = "The URL must be either a HTTP, data or file URL." msg = "The URL must be either a HTTP, data or file URL."
raise ValueError(msg) raise ValueError(msg)