[Misc] Clean up more utils (#27567)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-10-27 23:30:38 +08:00 committed by GitHub
parent 3b96f85c36
commit 6ebffafbb6
24 changed files with 282 additions and 315 deletions

View File

@ -13,6 +13,8 @@ ruff
# Required for argparse hook only
-f https://download.pytorch.org/whl/cpu
cachetools
cloudpickle
py-cpuinfo
msgspec
pydantic
torch

View File

@ -39,7 +39,6 @@ ALLOWED_FILES = {
"vllm/v1/executor/multiproc_executor.py",
"vllm/v1/executor/ray_executor.py",
"vllm/entrypoints/llm.py",
"vllm/utils/__init__.py",
"tests/utils.py",
# pickle and cloudpickle
"vllm/v1/serial_utils.py",

View File

@ -1618,6 +1618,29 @@ class ModelConfig:
"""Extract the HF encoder/decoder model flag."""
return is_encoder_decoder(self.hf_config)
@property
def uses_alibi(self) -> bool:
cfg = self.hf_text_config
return (
getattr(cfg, "alibi", False) # Falcon
or "BloomForCausalLM" in self.architectures # Bloom
or getattr(cfg, "position_encoding_type", "") == "alibi" # codellm_1b_alibi
or (
hasattr(cfg, "attn_config") # MPT
and (
(
isinstance(cfg.attn_config, dict)
and cfg.attn_config.get("alibi", False)
)
or (
not isinstance(cfg.attn_config, dict)
and getattr(cfg.attn_config, "alibi", False)
)
)
)
)
@property
def uses_mrope(self) -> bool:
return uses_mrope(self.hf_config)
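
As a point of reference for what the new `uses_alibi` property above inspects, here is a hedged standalone sketch of the same ALiBi detection applied directly to a Hugging Face config. It assumes `transformers` is installed and the Hub is reachable; the checkpoint name is illustrative only.

from transformers import AutoConfig

# Illustrative checkpoint; any config that exposes these fields works.
hf_config = AutoConfig.from_pretrained("bigscience/bloom-560m")
text_cfg = hf_config.get_text_config()
attn_cfg = getattr(text_cfg, "attn_config", None)

uses_alibi = (
    getattr(text_cfg, "alibi", False)  # Falcon-style flag
    or "BloomForCausalLM" in (getattr(hf_config, "architectures", None) or [])  # Bloom
    or getattr(text_cfg, "position_encoding_type", "") == "alibi"  # codellm_1b_alibi
    or (isinstance(attn_cfg, dict) and attn_cfg.get("alibi", False))  # MPT, dict config
    or bool(getattr(attn_cfg, "alibi", False))  # MPT, object config
)
print(uses_alibi)  # True for BLOOM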

View File

@ -2,12 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import getpass
import hashlib
import json
import os
import tempfile
import threading
import time
from contextlib import contextmanager
from dataclasses import replace
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
@ -17,7 +21,7 @@ from pydantic import ConfigDict, Field
from pydantic.dataclasses import dataclass
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.logger import enable_trace_function_call, init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
@ -206,6 +210,28 @@ class VllmConfig:
# i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size
return self.compilation_config.bs_to_padded_graph_size[batch_size]
def enable_trace_function_call_for_thread(self) -> None:
"""
Set up function tracing for the current thread,
if enabled via the `VLLM_TRACE_FUNCTION` environment variable.
"""
if envs.VLLM_TRACE_FUNCTION:
tmp_dir = tempfile.gettempdir()
# add username to tmp_dir to avoid permission issues
tmp_dir = os.path.join(tmp_dir, getpass.getuser())
filename = (
f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
f"_thread_{threading.get_ident()}_at_{datetime.now()}.log"
).replace(" ", "_")
log_path = os.path.join(
tmp_dir,
"vllm",
f"vllm-instance-{self.instance_id}",
filename,
)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)
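
The tracing hook itself is installed by `enable_trace_function_call` from `vllm.logger`. The sketch below only illustrates the underlying mechanism, a per-thread `sys.settrace` hook that logs every function call to a file; it is not vLLM's implementation, and `install_call_tracer` is a made-up name.

import sys

def install_call_tracer(log_path: str) -> None:
    # Hypothetical helper for illustration; vLLM's real hook lives in vllm.logger.
    log_file = open(log_path, "a")

    def tracer(frame, event, arg):
        if event == "call":
            code = frame.f_code
            log_file.write(f"call {code.co_name} in {code.co_filename}:{frame.f_lineno}\n")
            log_file.flush()
        return tracer

    # sys.settrace only affects the calling thread, which is why the method
    # above is scoped "for the current thread".
    sys.settrace(tracer)

install_call_tracer("/tmp/vllm_trace_example.log")

As the docstring notes, the real hook is only installed when the `VLLM_TRACE_FUNCTION` environment variable is set.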
@staticmethod
def _get_quantization_config(
model_config: ModelConfig, load_config: LoadConfig

View File

@ -73,7 +73,7 @@ from vllm.config.utils import get_field
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
from vllm.ray.lazy_utils import is_ray_initialized
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.config import (
@ -82,7 +82,6 @@ from vllm.transformers_utils.config import (
maybe_override_with_speculators,
)
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import is_in_ray_actor
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_ip

View File

@ -51,9 +51,9 @@ from vllm.entrypoints.utils import (
with_cancellation,
)
from vllm.logger import init_logger
from vllm.utils import set_ulimit
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import is_valid_ipv6_address
from vllm.utils.system_utils import set_ulimit
from vllm.version import __version__ as VLLM_VERSION
prometheus_multiproc_dir: tempfile.TemporaryDirectory

View File

@ -26,8 +26,9 @@ from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid, set_ulimit
from vllm.utils import random_uuid
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.system_utils import set_ulimit
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger("vllm.entrypoints.api_server")

View File

@ -108,10 +108,10 @@ from vllm.entrypoints.utils import (
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, set_ulimit
from vllm.utils import Device
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import is_valid_ipv6_address
from vllm.utils.system_utils import decorate_logs
from vllm.utils.system_utils import decorate_logs, set_ulimit
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.metrics.prometheus import get_prometheus_registry
from vllm.version import __version__ as VLLM_VERSION

View File

@ -60,7 +60,7 @@ def cuda_platform_plugin() -> str | None:
is_cuda = False
logger.debug("Checking if CUDA platform is available.")
try:
from vllm.utils import import_pynvml
from vllm.utils.import_utils import import_pynvml
pynvml = import_pynvml()
pynvml.nvmlInit()

View File

@ -16,7 +16,7 @@ from typing_extensions import ParamSpec
import vllm._C # noqa
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import import_pynvml
from vllm.utils.import_utils import import_pynvml
from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum

View File

@ -1,34 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import datetime
import enum
import getpass
import inspect
import multiprocessing
import os
import signal
import sys
import tempfile
import threading
import uuid
import warnings
from collections.abc import Callable
from functools import partial, wraps
from typing import TYPE_CHECKING, Any, TypeVar
from functools import wraps
from typing import Any, TypeVar
import cloudpickle
import psutil
import torch
import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
from vllm.ray.lazy_utils import is_in_ray_actor
from vllm.logger import init_logger
_DEPRECATED_MAPPINGS = {
"cprofile": "profiling",
"cprofile_context": "profiling",
# Used by lm-eval
"get_open_port": "network_utils",
}
@ -53,12 +41,6 @@ def __dir__() -> list[str]:
return sorted(list(globals().keys()) + list(_DEPRECATED_MAPPINGS.keys()))
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
else:
ModelConfig = object
VllmConfig = object
logger = init_logger(__name__)
# This value is chosen to have a balance between ITL and TTFT. Note it is
@ -83,13 +65,7 @@ STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
# ANSI color codes
CYAN = "\033[1;36m"
RESET = "\033[0;0m"
T = TypeVar("T")
U = TypeVar("U")
class Device(enum.Enum):
@ -144,195 +120,6 @@ def random_uuid() -> str:
return str(uuid.uuid4().hex)
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None:
"""
Lazy initialization of the Hugging Face modules.
"""
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
"""Set up function tracing for the current thread,
if enabled via the VLLM_TRACE_FUNCTION environment variable
"""
if envs.VLLM_TRACE_FUNCTION:
tmp_dir = tempfile.gettempdir()
# add username to tmp_dir to avoid permission issues
tmp_dir = os.path.join(tmp_dir, getpass.getuser())
filename = (
f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
f"_thread_{threading.get_ident()}_"
f"at_{datetime.datetime.now()}.log"
).replace(" ", "_")
log_path = os.path.join(
tmp_dir, "vllm", f"vllm-instance-{vllm_config.instance_id}", filename
)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)
def kill_process_tree(pid: int):
"""
Kills all descendant processes of the given pid by sending SIGKILL.
Args:
pid (int): Process ID of the parent process
"""
try:
parent = psutil.Process(pid)
except psutil.NoSuchProcess:
return
# Get all children recursively
children = parent.children(recursive=True)
# Send SIGKILL to all children first
for child in children:
with contextlib.suppress(ProcessLookupError):
os.kill(child.pid, signal.SIGKILL)
# Finally kill the parent
with contextlib.suppress(ProcessLookupError):
os.kill(pid, signal.SIGKILL)
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
def set_ulimit(target_soft_limit=65535):
if sys.platform.startswith("win"):
logger.info("Windows detected, skipping ulimit adjustment.")
return
import resource
resource_type = resource.RLIMIT_NOFILE
current_soft, current_hard = resource.getrlimit(resource_type)
if current_soft < target_soft_limit:
try:
resource.setrlimit(resource_type, (target_soft_limit, current_hard))
except ValueError as e:
logger.warning(
"Found ulimit of %s and failed to automatically increase "
"with error %s. This can cause fd limit errors like "
"`OSError: [Errno 24] Too many open files`. Consider "
"increasing with ulimit -n",
current_soft,
e,
)
def _maybe_force_spawn():
"""Check if we need to force the use of the `spawn` multiprocessing start
method.
"""
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
return
reasons = []
if is_in_ray_actor():
# even if we choose to spawn, we need to pass the ray address
# to the subprocess so that it knows how to connect to the ray cluster.
# env vars are inherited by subprocesses, even if we use spawn.
import ray
os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
reasons.append("In a Ray actor and can only be spawned")
from .platform_utils import cuda_is_initialized, xpu_is_initialized
if cuda_is_initialized():
reasons.append("CUDA is initialized")
elif xpu_is_initialized():
reasons.append("XPU is initialized")
if reasons:
logger.warning(
"We must use the `spawn` multiprocessing start method. "
"Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"See https://docs.vllm.ai/en/latest/usage/"
"troubleshooting.html#python-multiprocessing "
"for more information. Reasons: %s",
"; ".join(reasons),
)
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def get_mp_context():
"""Get a multiprocessing context with a particular method (spawn or fork).
By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
determine the multiprocessing method (default is fork). However, under
certain conditions, we may enforce spawn and override the value of
VLLM_WORKER_MULTIPROC_METHOD.
"""
_maybe_force_spawn()
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
return multiprocessing.get_context(mp_method)
def run_method(
obj: Any,
method: str | bytes | Callable,
args: tuple[Any],
kwargs: dict[str, Any],
) -> Any:
"""
Run a method of an object with the given arguments and keyword arguments.
If the method is string, it will be converted to a method using getattr.
If the method is serialized bytes and will be deserialized using
cloudpickle.
If the method is a callable, it will be called directly.
"""
if isinstance(method, bytes):
func = partial(cloudpickle.loads(method), obj)
elif isinstance(method, str):
try:
func = getattr(obj, method)
except AttributeError:
raise NotImplementedError(
f"Method {method!r} is not implemented."
) from None
else:
func = partial(method, obj) # type: ignore
return func(*args, **kwargs)
def import_pynvml():
"""
Historical comments:
libnvml.so is the library behind nvidia-smi, and
pynvml is a Python wrapper around it. We use it to get GPU
status without initializing CUDA context in the current process.
Historically, there are two packages that provide pynvml:
- `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official
wrapper. It is a dependency of vLLM, and is installed when users
install vLLM. It provides a Python module named `pynvml`.
- `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper.
Prior to version 12.0, it also provides a Python module `pynvml`,
and therefore conflicts with the official one. What's worse,
the module is a Python package, and has higher priority than
the official one which is a standalone Python file.
This causes errors when both of them are installed.
Starting from version 12.0, it migrates to a new module
named `pynvml_utils` to avoid the conflict.
It is so confusing that many packages in the community use the
unofficial one by mistake, and we have to handle this case.
For example, `nvcr.io/nvidia/pytorch:24.12-py3` uses the unofficial
one, and it will cause errors, see the issue
https://github.com/vllm-project/vllm/issues/12847 for example.
After all the troubles, we decide to copy the official `pynvml`
module to our codebase, and use it directly.
"""
import vllm.third_party.pynvml as pynvml
return pynvml
def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:
"""
A replacement for `abc.ABC`.
@ -376,31 +163,6 @@ def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:
return cls
# Only relevant for models using ALiBi (e.g, MPT)
def check_use_alibi(model_config: ModelConfig) -> bool:
cfg = model_config.hf_text_config
return (
getattr(cfg, "alibi", False) # Falcon
or (
"BloomForCausalLM" in getattr(model_config.hf_config, "architectures", [])
) # Bloom
or getattr(cfg, "position_encoding_type", "") == "alibi" # codellm_1b_alibi
or (
hasattr(cfg, "attn_config") # MPT
and (
(
isinstance(cfg.attn_config, dict)
and cfg.attn_config.get("alibi", False)
)
or (
not isinstance(cfg.attn_config, dict)
and getattr(cfg.attn_config, "alibi", False)
)
)
)
)
def length_from_prompt_token_ids_or_embeds(
prompt_token_ids: list[int] | None,
prompt_embeds: torch.Tensor | None,

View File

@ -10,37 +10,21 @@ from argparse import (
ArgumentDefaultsHelpFormatter,
ArgumentParser,
ArgumentTypeError,
Namespace,
RawDescriptionHelpFormatter,
_ArgumentGroup,
)
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from typing import Any
import regex as re
import yaml
from vllm.logger import init_logger
if TYPE_CHECKING:
from argparse import Namespace
else:
Namespace = object
logger = init_logger(__name__)
class StoreBoolean(Action):
def __call__(self, parser, namespace, values, option_string=None):
if values.lower() == "true":
setattr(namespace, self.dest, True)
elif values.lower() == "false":
setattr(namespace, self.dest, False)
else:
raise ValueError(
f"Invalid boolean value: {values}. Expected 'true' or 'false'."
)
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
"""SortedHelpFormatter that sorts arguments by their option strings."""
@ -487,12 +471,8 @@ class FlexibleArgumentParser(ArgumentParser):
)
raise ex
store_boolean_arguments = [
action.dest for action in self._actions if isinstance(action, StoreBoolean)
]
for key, value in config.items():
if isinstance(value, bool) and key not in store_boolean_arguments:
if isinstance(value, bool):
if value:
processed_args.append("--" + key)
elif isinstance(value, list):

View File

@ -19,6 +19,49 @@ import regex as re
from typing_extensions import Never
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None:
"""
Lazy initialization of the Hugging Face modules.
"""
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
def import_pynvml():
"""
Historical context:
libnvml.so is the library behind nvidia-smi, and
pynvml is a Python wrapper around it. We use it to get GPU
status without initializing a CUDA context in the current process.
Historically, two packages have provided pynvml:
- `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official
wrapper. It is a dependency of vLLM and is installed when users
install vLLM. It provides a Python module named `pynvml`.
- `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper.
Prior to version 12.0, it also provided a Python module named `pynvml`,
and therefore conflicted with the official one. Worse, that module
is a Python package, so it takes precedence over the official one,
which is a standalone Python file. This causes errors when both
packages are installed.
Starting from version 12.0, the unofficial wrapper moved to a new
module named `pynvml_utils` to avoid the conflict.
The situation is confusing enough that many community packages use the
unofficial one by mistake, and we have to handle this case.
For example, `nvcr.io/nvidia/pytorch:24.12-py3` ships the unofficial
one, which causes errors; see
https://github.com/vllm-project/vllm/issues/12847 for an example.
After all this trouble, we decided to copy the official `pynvml`
module into our codebase and use it directly.
"""
import vllm.third_party.pynvml as pynvml
return pynvml
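
A hedged usage sketch of the vendored wrapper: query device memory through NVML without initializing a CUDA context in this process. It assumes an NVIDIA driver is present; the NVML calls shown are standard `nvidia-ml-py` API.

from vllm.utils.import_utils import import_pynvml

pynvml = import_pynvml()
pynvml.nvmlInit()
try:
    for i in range(pynvml.nvmlDeviceGetCount()):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
        print(f"GPU {i}: {mem.free / 2**30:.1f} GiB free / {mem.total / 2**30:.1f} GiB total")
finally:
    pynvml.nvmlShutdown()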
def import_from_path(module_name: str, file_path: str | os.PathLike):
"""
Import a Python file according to its file path.

View File

@ -4,19 +4,21 @@
from __future__ import annotations
import contextlib
import multiprocessing
import os
import signal
import sys
from collections.abc import Callable, Iterator
from pathlib import Path
from typing import TextIO
try:
import setproctitle
except ImportError:
setproctitle = None # type: ignore[assignment]
import psutil
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.ray.lazy_utils import is_in_ray_actor
from .platform_utils import cuda_is_initialized, xpu_is_initialized
logger = init_logger(__name__)
@ -75,14 +77,66 @@ def unique_filepath(fn: Callable[[int], Path]) -> Path:
# Process management utilities
def _maybe_force_spawn():
"""Check if we need to force the use of the `spawn` multiprocessing start
method.
"""
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
return
reasons = []
if is_in_ray_actor():
# even if we choose to spawn, we need to pass the ray address
# to the subprocess so that it knows how to connect to the ray cluster.
# env vars are inherited by subprocesses, even if we use spawn.
import ray
os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
reasons.append("In a Ray actor and can only be spawned")
if cuda_is_initialized():
reasons.append("CUDA is initialized")
elif xpu_is_initialized():
reasons.append("XPU is initialized")
if reasons:
logger.warning(
"We must use the `spawn` multiprocessing start method. "
"Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"See https://docs.vllm.ai/en/latest/usage/"
"troubleshooting.html#python-multiprocessing "
"for more information. Reasons: %s",
"; ".join(reasons),
)
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def get_mp_context():
"""Get a multiprocessing context with a particular method (spawn or fork).
By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD
environment variable to determine the multiprocessing method (the default
is fork). However, under certain conditions, we may enforce spawn and
override the value of VLLM_WORKER_MULTIPROC_METHOD.
"""
_maybe_force_spawn()
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
return multiprocessing.get_context(mp_method)
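
A minimal usage sketch, assuming vLLM is installed: launch a worker through whatever start method `get_mp_context` settles on (spawn is forced when CUDA/XPU is already initialized or when running inside a Ray actor). The `_worker` function is illustrative.

from vllm.utils.system_utils import get_mp_context

def _worker(rank: int) -> None:
    # Runs in the child process; defined at module level so it stays
    # picklable under the spawn start method.
    print(f"worker {rank} started")

if __name__ == "__main__":
    ctx = get_mp_context()
    proc = ctx.Process(target=_worker, args=(0,), daemon=True)
    proc.start()
    proc.join()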
def set_process_title(
name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX
name: str,
suffix: str = "",
prefix: str = envs.VLLM_PROCESS_NAME_PREFIX,
) -> None:
"""Set the current process title with optional suffix."""
if setproctitle is None:
try:
import setproctitle
except ImportError:
return
if suffix:
name = f"{name}_{suffix}"
setproctitle.setproctitle(f"{prefix}::{name}")
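
A small usage sketch: the title is composed as `{prefix}::{name}` with an optional `_{suffix}`, so vLLM processes show up with readable names in ps/top. The name, suffix, and default prefix value below are illustrative assumptions.

from vllm.utils.system_utils import set_process_title

# Shows up as e.g. "VLLM::EngineCore_DP0" in ps/top, assuming the default
# VLLM_PROCESS_NAME_PREFIX of "VLLM"; the name and suffix are made up.
set_process_title("EngineCore", suffix="DP0")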
@ -114,10 +168,62 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
def decorate_logs(process_name: str | None = None) -> None:
"""Decorate stdout/stderr with process name and PID prefix."""
from vllm.utils import get_mp_context
if process_name is None:
process_name = get_mp_context().current_process().name
pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
def kill_process_tree(pid: int):
"""
Kills all descendant processes of the given pid by sending SIGKILL.
Args:
pid (int): Process ID of the parent process
"""
try:
parent = psutil.Process(pid)
except psutil.NoSuchProcess:
return
# Get all children recursively
children = parent.children(recursive=True)
# Send SIGKILL to all children first
for child in children:
with contextlib.suppress(ProcessLookupError):
os.kill(child.pid, signal.SIGKILL)
# Finally kill the parent
with contextlib.suppress(ProcessLookupError):
os.kill(pid, signal.SIGKILL)
# Resource utilities
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630
def set_ulimit(target_soft_limit: int = 65535):
if sys.platform.startswith("win"):
logger.info("Windows detected, skipping ulimit adjustment.")
return
import resource
resource_type = resource.RLIMIT_NOFILE
current_soft, current_hard = resource.getrlimit(resource_type)
if current_soft < target_soft_limit:
try:
resource.setrlimit(resource_type, (target_soft_limit, current_hard))
except ValueError as e:
logger.warning(
"Found ulimit of %s and failed to automatically increase "
"with error %s. This can cause fd limit errors like "
"`OSError: [Errno 24] Too many open files`. Consider "
"increasing with ulimit -n",
current_soft,
e,
)
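
A POSIX-only usage sketch: raise the soft open-file limit before serving, then read the limit back to confirm. The helper skips the adjustment on Windows, as shown above.

import resource  # POSIX-only; set_ulimit itself skips the adjustment on Windows

from vllm.utils.system_utils import set_ulimit

set_ulimit(target_soft_limit=65535)
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
print(f"RLIMIT_NOFILE: soft={soft}, hard={hard}")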

View File

@ -10,9 +10,8 @@ import zmq
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_mp_context
from vllm.utils.network_utils import make_zmq_socket
from vllm.utils.system_utils import set_process_title
from vllm.utils.system_utils import get_mp_context, set_process_title
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
from vllm.v1.serial_utils import MsgpackDecoder
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown

View File

@ -20,8 +20,8 @@ from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.utils import get_mp_context
from vllm.utils.network_utils import get_open_zmq_ipc_path, zmq_socket_ctx
from vllm.utils.system_utils import get_mp_context
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.executor import Executor
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown

View File

@ -35,13 +35,17 @@ from vllm.distributed.parallel_state import (
)
from vllm.envs import enable_envs_cache
from vllm.logger import init_logger
from vllm.utils import _maybe_force_spawn, get_mp_context
from vllm.utils.network_utils import (
get_distributed_init_method,
get_loopback_ip,
get_open_port,
)
from vllm.utils.system_utils import decorate_logs, set_process_title
from vllm.utils.system_utils import (
_maybe_force_spawn,
decorate_logs,
get_mp_context,
set_process_title,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput

View File

@ -12,11 +12,11 @@ import torch.distributed as dist
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import run_method
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.serial_utils import run_method
from vllm.v1.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)

View File

@ -5,6 +5,7 @@ import dataclasses
import importlib
import pickle
from collections.abc import Callable, Sequence
from functools import partial
from inspect import isclass
from types import FunctionType
from typing import Any, TypeAlias
@ -429,3 +430,30 @@ class MsgpackDecoder:
return cloudpickle.loads(data)
raise NotImplementedError(f"Extension type code {code} is not supported")
def run_method(
obj: Any,
method: str | bytes | Callable,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> Any:
"""
Run a method of an object with the given arguments and keyword arguments.
If the method is a string, it will be resolved to an attribute of the
object using getattr.
If the method is serialized bytes, it will be deserialized using
cloudpickle.
If the method is a callable, it will be called directly.
"""
if isinstance(method, bytes):
func = partial(cloudpickle.loads(method), obj)
elif isinstance(method, str):
try:
func = getattr(obj, method)
except AttributeError:
raise NotImplementedError(
f"Method {method!r} is not implemented."
) from None
else:
func = partial(method, obj) # type: ignore
return func(*args, **kwargs)
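
A usage sketch covering the three dispatch forms `run_method` accepts: an attribute name, a plain callable, and a cloudpickle-serialized callable. The `Counter` class is made up for illustration.

import cloudpickle

from vllm.v1.serial_utils import run_method

class Counter:
    # Made-up class for illustration only.
    def __init__(self) -> None:
        self.value = 0

    def add(self, amount: int) -> int:
        self.value += amount
        return self.value

counter = Counter()
print(run_method(counter, "add", (3,), {}))                      # by name -> 3
print(run_method(counter, lambda obj, x: obj.add(x), (4,), {}))  # callable -> 7
payload = cloudpickle.dumps(lambda obj, x: obj.add(x))
print(run_method(counter, payload, (5,), {}))                    # serialized bytes -> 12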

View File

@ -25,8 +25,8 @@ from torch.autograd.profiler import record_function
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message
from vllm.utils import kill_process_tree
from vllm.utils.network_utils import get_open_port, get_open_zmq_ipc_path, get_tcp_uri
from vllm.utils.system_utils import kill_process_tree
if TYPE_CHECKING:
import numpy as np

View File

@ -69,10 +69,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (
check_use_alibi,
length_from_prompt_token_ids_or_embeds,
)
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.mem_constants import GiB_bytes
@ -266,7 +263,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.hidden_size = model_config.get_hidden_size()
self.attention_chunk_size = model_config.attention_chunk_size
# Only relevant for models using ALiBi (e.g., MPT)
self.use_alibi = check_use_alibi(model_config)
self.use_alibi = model_config.uses_alibi
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

View File

@ -72,7 +72,7 @@ class Worker(WorkerBase):
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
from vllm.utils.import_utils import init_cached_hf_modules
init_cached_hf_modules()

View File

@ -89,7 +89,7 @@ class TPUWorker:
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
from vllm.utils.import_utils import init_cached_hf_modules
init_cached_hf_modules()

View File

@ -13,14 +13,11 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import (
enable_trace_function_call_for_thread,
run_method,
warn_for_unimplemented_methods,
)
from vllm.utils import warn_for_unimplemented_methods
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.system_utils import update_environment_variables
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.serial_utils import run_method
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@ -182,19 +179,20 @@ class WorkerWrapperBase:
"""
self.rpc_rank = rpc_rank
self.worker: WorkerBase | None = None
self.vllm_config: VllmConfig | None = None
# do not store this `vllm_config`, `init_worker` will set the final
# one. TODO: investigate if we can remove this field in
# `WorkerWrapperBase`, `init_cached_hf_modules` should be
# unnecessary now.
if vllm_config.model_config is not None:
# it can be None in tests
trust_remote_code = vllm_config.model_config.trust_remote_code
if trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
# Do not store this `vllm_config`; `init_worker` will set the final one.
# TODO: investigate whether we can remove this field from
# `WorkerWrapperBase`; `init_cached_hf_modules` should be unnecessary now.
self.vllm_config: VllmConfig | None = None
# `model_config` can be None in tests
model_config = vllm_config.model_config
if model_config and model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils.import_utils import init_cached_hf_modules
init_cached_hf_modules()
def shutdown(self) -> None:
if self.worker is not None:
@ -231,7 +229,7 @@ class WorkerWrapperBase:
assert self.vllm_config is not None, (
"vllm_config is required to initialize the worker"
)
enable_trace_function_call_for_thread(self.vllm_config)
self.vllm_config.enable_trace_function_call_for_thread()
from vllm.plugins import load_general_plugins