From 6ebffafbb609963d696d27c5e334fd0b9fe0add7 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 27 Oct 2025 23:30:38 +0800 Subject: [PATCH] [Misc] Clean up more utils (#27567) Signed-off-by: DarkLight1337 --- requirements/docs.txt | 2 + tools/pre_commit/check_pickle_imports.py | 1 - vllm/config/model.py | 23 +++ vllm/config/vllm.py | 28 ++- vllm/engine/arg_utils.py | 3 +- vllm/entrypoints/anthropic/api_server.py | 2 +- vllm/entrypoints/api_server.py | 3 +- vllm/entrypoints/openai/api_server.py | 4 +- vllm/platforms/__init__.py | 2 +- vllm/platforms/cuda.py | 2 +- vllm/utils/__init__.py | 246 +---------------------- vllm/utils/argparse_utils.py | 26 +-- vllm/utils/import_utils.py | 43 ++++ vllm/utils/system_utils.py | 122 ++++++++++- vllm/v1/engine/coordinator.py | 3 +- vllm/v1/engine/utils.py | 2 +- vllm/v1/executor/multiproc_executor.py | 8 +- vllm/v1/executor/uniproc_executor.py | 2 +- vllm/v1/serial_utils.py | 28 +++ vllm/v1/utils.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 7 +- vllm/v1/worker/gpu_worker.py | 2 +- vllm/v1/worker/tpu_worker.py | 2 +- vllm/v1/worker/worker_base.py | 34 ++-- 24 files changed, 282 insertions(+), 315 deletions(-) diff --git a/requirements/docs.txt b/requirements/docs.txt index d1c546398780..00c314874016 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -13,6 +13,8 @@ ruff # Required for argparse hook only -f https://download.pytorch.org/whl/cpu cachetools +cloudpickle +py-cpuinfo msgspec pydantic torch diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py index c9256cd91a4e..b96a6701333d 100644 --- a/tools/pre_commit/check_pickle_imports.py +++ b/tools/pre_commit/check_pickle_imports.py @@ -39,7 +39,6 @@ ALLOWED_FILES = { "vllm/v1/executor/multiproc_executor.py", "vllm/v1/executor/ray_executor.py", "vllm/entrypoints/llm.py", - "vllm/utils/__init__.py", "tests/utils.py", # pickle and cloudpickle "vllm/v1/serial_utils.py", diff --git a/vllm/config/model.py b/vllm/config/model.py index adb0dd9ac9f5..c335c5c25e9e 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1618,6 +1618,29 @@ class ModelConfig: """Extract the HF encoder/decoder model flag.""" return is_encoder_decoder(self.hf_config) + @property + def uses_alibi(self) -> bool: + cfg = self.hf_text_config + + return ( + getattr(cfg, "alibi", False) # Falcon + or "BloomForCausalLM" in self.architectures # Bloom + or getattr(cfg, "position_encoding_type", "") == "alibi" # codellm_1b_alibi + or ( + hasattr(cfg, "attn_config") # MPT + and ( + ( + isinstance(cfg.attn_config, dict) + and cfg.attn_config.get("alibi", False) + ) + or ( + not isinstance(cfg.attn_config, dict) + and getattr(cfg.attn_config, "alibi", False) + ) + ) + ) + ) + @property def uses_mrope(self) -> bool: return uses_mrope(self.hf_config) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 916f258d6586..597cf5793963 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -2,12 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +import getpass import hashlib import json import os +import tempfile +import threading import time from contextlib import contextmanager from dataclasses import replace +from datetime import datetime from functools import lru_cache from pathlib import Path from typing import TYPE_CHECKING, Any, TypeVar @@ -17,7 +21,7 @@ from pydantic import ConfigDict, Field from pydantic.dataclasses import dataclass import vllm.envs as envs -from vllm.logger import init_logger +from vllm.logger import 
enable_trace_function_call, init_logger from vllm.transformers_utils.runai_utils import is_runai_obj_uri from vllm.utils import random_uuid @@ -206,6 +210,28 @@ class VllmConfig: # i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size return self.compilation_config.bs_to_padded_graph_size[batch_size] + def enable_trace_function_call_for_thread(self) -> None: + """ + Set up function tracing for the current thread, + if enabled via the `VLLM_TRACE_FUNCTION` environment variable. + """ + if envs.VLLM_TRACE_FUNCTION: + tmp_dir = tempfile.gettempdir() + # add username to tmp_dir to avoid permission issues + tmp_dir = os.path.join(tmp_dir, getpass.getuser()) + filename = ( + f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" + f"_thread_{threading.get_ident()}_at_{datetime.now()}.log" + ).replace(" ", "_") + log_path = os.path.join( + tmp_dir, + "vllm", + f"vllm-instance-{self.instance_id}", + filename, + ) + os.makedirs(os.path.dirname(log_path), exist_ok=True) + enable_trace_function_call(log_path) + @staticmethod def _get_quantization_config( model_config: ModelConfig, load_config: LoadConfig diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 617c464cff25..24f9d18dc958 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -73,7 +73,7 @@ from vllm.config.utils import get_field from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform from vllm.plugins import load_general_plugins -from vllm.ray.lazy_utils import is_ray_initialized +from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized from vllm.reasoning import ReasoningParserManager from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.transformers_utils.config import ( @@ -82,7 +82,6 @@ from vllm.transformers_utils.config import ( maybe_override_with_speculators, ) from vllm.transformers_utils.utils import check_gguf_file -from vllm.utils import is_in_ray_actor from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.mem_constants import GiB_bytes from vllm.utils.network_utils import get_ip diff --git a/vllm/entrypoints/anthropic/api_server.py b/vllm/entrypoints/anthropic/api_server.py index b575dcdc8e77..df877f99b084 100644 --- a/vllm/entrypoints/anthropic/api_server.py +++ b/vllm/entrypoints/anthropic/api_server.py @@ -51,9 +51,9 @@ from vllm.entrypoints.utils import ( with_cancellation, ) from vllm.logger import init_logger -from vllm.utils import set_ulimit from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.network_utils import is_valid_ipv6_address +from vllm.utils.system_utils import set_ulimit from vllm.version import __version__ as VLLM_VERSION prometheus_multiproc_dir: tempfile.TemporaryDirectory diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 184cc47ceb83..154cdeb42a3e 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -26,8 +26,9 @@ from vllm.entrypoints.utils import with_cancellation from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.usage.usage_lib import UsageContext -from vllm.utils import random_uuid, set_ulimit +from vllm.utils import random_uuid from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.system_utils import set_ulimit from vllm.version import __version__ as VLLM_VERSION logger = init_logger("vllm.entrypoints.api_server") diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py 
index 1a785e49df2b..632bd741290b 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -108,10 +108,10 @@ from vllm.entrypoints.utils import ( from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager from vllm.usage.usage_lib import UsageContext -from vllm.utils import Device, set_ulimit +from vllm.utils import Device from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.network_utils import is_valid_ipv6_address -from vllm.utils.system_utils import decorate_logs +from vllm.utils.system_utils import decorate_logs, set_ulimit from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.metrics.prometheus import get_prometheus_registry from vllm.version import __version__ as VLLM_VERSION diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index f64d7a010b5f..badf72de4a90 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -60,7 +60,7 @@ def cuda_platform_plugin() -> str | None: is_cuda = False logger.debug("Checking if CUDA platform is available.") try: - from vllm.utils import import_pynvml + from vllm.utils.import_utils import import_pynvml pynvml = import_pynvml() pynvml.nvmlInit() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 637f35a4920e..66cffde9503d 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -16,7 +16,7 @@ from typing_extensions import ParamSpec import vllm._C # noqa import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import import_pynvml +from vllm.utils.import_utils import import_pynvml from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 55efeb41fe53..eaa78839cf3f 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -1,34 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import contextlib -import datetime import enum -import getpass import inspect -import multiprocessing -import os -import signal -import sys -import tempfile import threading import uuid import warnings -from collections.abc import Callable -from functools import partial, wraps -from typing import TYPE_CHECKING, Any, TypeVar +from functools import wraps +from typing import Any, TypeVar -import cloudpickle -import psutil import torch -import vllm.envs as envs -from vllm.logger import enable_trace_function_call, init_logger -from vllm.ray.lazy_utils import is_in_ray_actor +from vllm.logger import init_logger _DEPRECATED_MAPPINGS = { "cprofile": "profiling", "cprofile_context": "profiling", + # Used by lm-eval "get_open_port": "network_utils", } @@ -53,12 +41,6 @@ def __dir__() -> list[str]: return sorted(list(globals().keys()) + list(_DEPRECATED_MAPPINGS.keys())) -if TYPE_CHECKING: - from vllm.config import ModelConfig, VllmConfig -else: - ModelConfig = object - VllmConfig = object - logger = init_logger(__name__) # This value is chosen to have a balance between ITL and TTFT. 
Note it is @@ -83,13 +65,7 @@ STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_INVALID_VAL: str = "INVALID" -# ANSI color codes -CYAN = "\033[1;36m" -RESET = "\033[0;0m" - - T = TypeVar("T") -U = TypeVar("U") class Device(enum.Enum): @@ -144,195 +120,6 @@ def random_uuid() -> str: return str(uuid.uuid4().hex) -# TODO: This function can be removed if transformer_modules classes are -# serialized by value when communicating between processes -def init_cached_hf_modules() -> None: - """ - Lazy initialization of the Hugging Face modules. - """ - from transformers.dynamic_module_utils import init_hf_modules - - init_hf_modules() - - -def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None: - """Set up function tracing for the current thread, - if enabled via the VLLM_TRACE_FUNCTION environment variable - """ - - if envs.VLLM_TRACE_FUNCTION: - tmp_dir = tempfile.gettempdir() - # add username to tmp_dir to avoid permission issues - tmp_dir = os.path.join(tmp_dir, getpass.getuser()) - filename = ( - f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" - f"_thread_{threading.get_ident()}_" - f"at_{datetime.datetime.now()}.log" - ).replace(" ", "_") - log_path = os.path.join( - tmp_dir, "vllm", f"vllm-instance-{vllm_config.instance_id}", filename - ) - os.makedirs(os.path.dirname(log_path), exist_ok=True) - enable_trace_function_call(log_path) - - -def kill_process_tree(pid: int): - """ - Kills all descendant processes of the given pid by sending SIGKILL. - - Args: - pid (int): Process ID of the parent process - """ - try: - parent = psutil.Process(pid) - except psutil.NoSuchProcess: - return - - # Get all children recursively - children = parent.children(recursive=True) - - # Send SIGKILL to all children first - for child in children: - with contextlib.suppress(ProcessLookupError): - os.kill(child.pid, signal.SIGKILL) - - # Finally kill the parent - with contextlib.suppress(ProcessLookupError): - os.kill(pid, signal.SIGKILL) - - -# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501 -def set_ulimit(target_soft_limit=65535): - if sys.platform.startswith("win"): - logger.info("Windows detected, skipping ulimit adjustment.") - return - - import resource - - resource_type = resource.RLIMIT_NOFILE - current_soft, current_hard = resource.getrlimit(resource_type) - - if current_soft < target_soft_limit: - try: - resource.setrlimit(resource_type, (target_soft_limit, current_hard)) - except ValueError as e: - logger.warning( - "Found ulimit of %s and failed to automatically increase " - "with error %s. This can cause fd limit errors like " - "`OSError: [Errno 24] Too many open files`. Consider " - "increasing with ulimit -n", - current_soft, - e, - ) - - -def _maybe_force_spawn(): - """Check if we need to force the use of the `spawn` multiprocessing start - method. - """ - if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn": - return - - reasons = [] - if is_in_ray_actor(): - # even if we choose to spawn, we need to pass the ray address - # to the subprocess so that it knows how to connect to the ray cluster. - # env vars are inherited by subprocesses, even if we use spawn. 
- import ray - - os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address - reasons.append("In a Ray actor and can only be spawned") - - from .platform_utils import cuda_is_initialized, xpu_is_initialized - - if cuda_is_initialized(): - reasons.append("CUDA is initialized") - elif xpu_is_initialized(): - reasons.append("XPU is initialized") - - if reasons: - logger.warning( - "We must use the `spawn` multiprocessing start method. " - "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. " - "See https://docs.vllm.ai/en/latest/usage/" - "troubleshooting.html#python-multiprocessing " - "for more information. Reasons: %s", - "; ".join(reasons), - ) - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - - -def get_mp_context(): - """Get a multiprocessing context with a particular method (spawn or fork). - By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to - determine the multiprocessing method (default is fork). However, under - certain conditions, we may enforce spawn and override the value of - VLLM_WORKER_MULTIPROC_METHOD. - """ - _maybe_force_spawn() - mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD - return multiprocessing.get_context(mp_method) - - -def run_method( - obj: Any, - method: str | bytes | Callable, - args: tuple[Any], - kwargs: dict[str, Any], -) -> Any: - """ - Run a method of an object with the given arguments and keyword arguments. - If the method is string, it will be converted to a method using getattr. - If the method is serialized bytes and will be deserialized using - cloudpickle. - If the method is a callable, it will be called directly. - """ - if isinstance(method, bytes): - func = partial(cloudpickle.loads(method), obj) - elif isinstance(method, str): - try: - func = getattr(obj, method) - except AttributeError: - raise NotImplementedError( - f"Method {method!r} is not implemented." - ) from None - else: - func = partial(method, obj) # type: ignore - return func(*args, **kwargs) - - -def import_pynvml(): - """ - Historical comments: - - libnvml.so is the library behind nvidia-smi, and - pynvml is a Python wrapper around it. We use it to get GPU - status without initializing CUDA context in the current process. - Historically, there are two packages that provide pynvml: - - `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official - wrapper. It is a dependency of vLLM, and is installed when users - install vLLM. It provides a Python module named `pynvml`. - - `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper. - Prior to version 12.0, it also provides a Python module `pynvml`, - and therefore conflicts with the official one. What's worse, - the module is a Python package, and has higher priority than - the official one which is a standalone Python file. - This causes errors when both of them are installed. - Starting from version 12.0, it migrates to a new module - named `pynvml_utils` to avoid the conflict. - It is so confusing that many packages in the community use the - unofficial one by mistake, and we have to handle this case. - For example, `nvcr.io/nvidia/pytorch:24.12-py3` uses the unofficial - one, and it will cause errors, see the issue - https://github.com/vllm-project/vllm/issues/12847 for example. - After all the troubles, we decide to copy the official `pynvml` - module to our codebase, and use it directly. - """ - import vllm.third_party.pynvml as pynvml - - return pynvml - - def warn_for_unimplemented_methods(cls: type[T]) -> type[T]: """ A replacement for `abc.ABC`. 
@@ -376,31 +163,6 @@ def warn_for_unimplemented_methods(cls: type[T]) -> type[T]: return cls -# Only relevant for models using ALiBi (e.g, MPT) -def check_use_alibi(model_config: ModelConfig) -> bool: - cfg = model_config.hf_text_config - return ( - getattr(cfg, "alibi", False) # Falcon - or ( - "BloomForCausalLM" in getattr(model_config.hf_config, "architectures", []) - ) # Bloom - or getattr(cfg, "position_encoding_type", "") == "alibi" # codellm_1b_alibi - or ( - hasattr(cfg, "attn_config") # MPT - and ( - ( - isinstance(cfg.attn_config, dict) - and cfg.attn_config.get("alibi", False) - ) - or ( - not isinstance(cfg.attn_config, dict) - and getattr(cfg.attn_config, "alibi", False) - ) - ) - ) - ) - - def length_from_prompt_token_ids_or_embeds( prompt_token_ids: list[int] | None, prompt_embeds: torch.Tensor | None, diff --git a/vllm/utils/argparse_utils.py b/vllm/utils/argparse_utils.py index 0007c72f1e38..3d105a3685b3 100644 --- a/vllm/utils/argparse_utils.py +++ b/vllm/utils/argparse_utils.py @@ -10,37 +10,21 @@ from argparse import ( ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError, + Namespace, RawDescriptionHelpFormatter, _ArgumentGroup, ) from collections import defaultdict -from typing import TYPE_CHECKING, Any +from typing import Any import regex as re import yaml from vllm.logger import init_logger -if TYPE_CHECKING: - from argparse import Namespace -else: - Namespace = object - logger = init_logger(__name__) -class StoreBoolean(Action): - def __call__(self, parser, namespace, values, option_string=None): - if values.lower() == "true": - setattr(namespace, self.dest, True) - elif values.lower() == "false": - setattr(namespace, self.dest, False) - else: - raise ValueError( - f"Invalid boolean value: {values}. Expected 'true' or 'false'." - ) - - class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter): """SortedHelpFormatter that sorts arguments by their option strings.""" @@ -487,12 +471,8 @@ class FlexibleArgumentParser(ArgumentParser): ) raise ex - store_boolean_arguments = [ - action.dest for action in self._actions if isinstance(action, StoreBoolean) - ] - for key, value in config.items(): - if isinstance(value, bool) and key not in store_boolean_arguments: + if isinstance(value, bool): if value: processed_args.append("--" + key) elif isinstance(value, list): diff --git a/vllm/utils/import_utils.py b/vllm/utils/import_utils.py index 65f588b52e5e..409a5a6cd302 100644 --- a/vllm/utils/import_utils.py +++ b/vllm/utils/import_utils.py @@ -19,6 +19,49 @@ import regex as re from typing_extensions import Never +# TODO: This function can be removed if transformer_modules classes are +# serialized by value when communicating between processes +def init_cached_hf_modules() -> None: + """ + Lazy initialization of the Hugging Face modules. + """ + from transformers.dynamic_module_utils import init_hf_modules + + init_hf_modules() + + +def import_pynvml(): + """ + Historical comments: + + libnvml.so is the library behind nvidia-smi, and + pynvml is a Python wrapper around it. We use it to get GPU + status without initializing CUDA context in the current process. + Historically, there are two packages that provide pynvml: + - `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official + wrapper. It is a dependency of vLLM, and is installed when users + install vLLM. It provides a Python module named `pynvml`. + - `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper. 
+ Prior to version 12.0, it also provides a Python module `pynvml`, + and therefore conflicts with the official one. What's worse, + the module is a Python package, and has higher priority than + the official one which is a standalone Python file. + This causes errors when both of them are installed. + Starting from version 12.0, it migrates to a new module + named `pynvml_utils` to avoid the conflict. + It is so confusing that many packages in the community use the + unofficial one by mistake, and we have to handle this case. + For example, `nvcr.io/nvidia/pytorch:24.12-py3` uses the unofficial + one, and it will cause errors, see the issue + https://github.com/vllm-project/vllm/issues/12847 for example. + After all the troubles, we decide to copy the official `pynvml` + module to our codebase, and use it directly. + """ + import vllm.third_party.pynvml as pynvml + + return pynvml + + def import_from_path(module_name: str, file_path: str | os.PathLike): """ Import a Python file according to its file path. diff --git a/vllm/utils/system_utils.py b/vllm/utils/system_utils.py index dd18adf55e1f..5968884e232a 100644 --- a/vllm/utils/system_utils.py +++ b/vllm/utils/system_utils.py @@ -4,19 +4,21 @@ from __future__ import annotations import contextlib +import multiprocessing import os +import signal import sys from collections.abc import Callable, Iterator from pathlib import Path from typing import TextIO -try: - import setproctitle -except ImportError: - setproctitle = None # type: ignore[assignment] +import psutil import vllm.envs as envs from vllm.logger import init_logger +from vllm.ray.lazy_utils import is_in_ray_actor + +from .platform_utils import cuda_is_initialized, xpu_is_initialized logger = init_logger(__name__) @@ -75,14 +77,66 @@ def unique_filepath(fn: Callable[[int], Path]) -> Path: # Process management utilities +def _maybe_force_spawn(): + """Check if we need to force the use of the `spawn` multiprocessing start + method. + """ + if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn": + return + + reasons = [] + if is_in_ray_actor(): + # even if we choose to spawn, we need to pass the ray address + # to the subprocess so that it knows how to connect to the ray cluster. + # env vars are inherited by subprocesses, even if we use spawn. + import ray + + os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address + reasons.append("In a Ray actor and can only be spawned") + + if cuda_is_initialized(): + reasons.append("CUDA is initialized") + elif xpu_is_initialized(): + reasons.append("XPU is initialized") + + if reasons: + logger.warning( + "We must use the `spawn` multiprocessing start method. " + "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. " + "See https://docs.vllm.ai/en/latest/usage/" + "troubleshooting.html#python-multiprocessing " + "for more information. Reasons: %s", + "; ".join(reasons), + ) + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +def get_mp_context(): + """Get a multiprocessing context with a particular method (spawn or fork). + By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to + determine the multiprocessing method (default is fork). However, under + certain conditions, we may enforce spawn and override the value of + VLLM_WORKER_MULTIPROC_METHOD. 
+ """ + _maybe_force_spawn() + mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD + return multiprocessing.get_context(mp_method) + + def set_process_title( - name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX + name: str, + suffix: str = "", + prefix: str = envs.VLLM_PROCESS_NAME_PREFIX, ) -> None: """Set the current process title with optional suffix.""" - if setproctitle is None: + try: + import setproctitle + except ImportError: return + if suffix: name = f"{name}_{suffix}" + setproctitle.setproctitle(f"{prefix}::{name}") @@ -114,10 +168,62 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: def decorate_logs(process_name: str | None = None) -> None: """Decorate stdout/stderr with process name and PID prefix.""" - from vllm.utils import get_mp_context - if process_name is None: process_name = get_mp_context().current_process().name + pid = os.getpid() _add_prefix(sys.stdout, process_name, pid) _add_prefix(sys.stderr, process_name, pid) + + +def kill_process_tree(pid: int): + """ + Kills all descendant processes of the given pid by sending SIGKILL. + + Args: + pid (int): Process ID of the parent process + """ + try: + parent = psutil.Process(pid) + except psutil.NoSuchProcess: + return + + # Get all children recursively + children = parent.children(recursive=True) + + # Send SIGKILL to all children first + for child in children: + with contextlib.suppress(ProcessLookupError): + os.kill(child.pid, signal.SIGKILL) + + # Finally kill the parent + with contextlib.suppress(ProcessLookupError): + os.kill(pid, signal.SIGKILL) + + +# Resource utilities + + +# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 +def set_ulimit(target_soft_limit: int = 65535): + if sys.platform.startswith("win"): + logger.info("Windows detected, skipping ulimit adjustment.") + return + + import resource + + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) + except ValueError as e: + logger.warning( + "Found ulimit of %s and failed to automatically increase " + "with error %s. This can cause fd limit errors like " + "`OSError: [Errno 24] Too many open files`. 
Consider " + "increasing with ulimit -n", + current_soft, + e, + ) diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 39d8655ff858..953342cdd5d0 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -10,9 +10,8 @@ import zmq from vllm.config import ParallelConfig from vllm.logger import init_logger -from vllm.utils import get_mp_context from vllm.utils.network_utils import make_zmq_socket -from vllm.utils.system_utils import set_process_title +from vllm.utils.system_utils import get_mp_context, set_process_title from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType from vllm.v1.serial_utils import MsgpackDecoder from vllm.v1.utils import get_engine_client_zmq_addr, shutdown diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index ca416dbc0df9..bdc124b0571c 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -20,8 +20,8 @@ from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.ray.ray_env import get_env_vars_to_copy -from vllm.utils import get_mp_context from vllm.utils.network_utils import get_open_zmq_ipc_path, zmq_socket_ctx +from vllm.utils.system_utils import get_mp_context from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.executor import Executor from vllm.v1.utils import get_engine_client_zmq_addr, shutdown diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 1b4b9c4550f7..4c58d5771c39 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -35,13 +35,17 @@ from vllm.distributed.parallel_state import ( ) from vllm.envs import enable_envs_cache from vllm.logger import init_logger -from vllm.utils import _maybe_force_spawn, get_mp_context from vllm.utils.network_utils import ( get_distributed_init_method, get_loopback_ip, get_open_port, ) -from vllm.utils.system_utils import decorate_logs, set_process_title +from vllm.utils.system_utils import ( + _maybe_force_spawn, + decorate_logs, + get_mp_context, + set_process_title, +) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.executor.abstract import Executor, FailureCallback from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index 0d072172fdf3..f17d3c309270 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -12,11 +12,11 @@ import torch.distributed as dist import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import run_method from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import AsyncModelRunnerOutput +from vllm.v1.serial_utils import run_method from vllm.v1.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 39147a67d6cf..102357ca7c64 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -5,6 +5,7 @@ import dataclasses import importlib import pickle from collections.abc import Callable, Sequence +from functools import partial from inspect import isclass from types import FunctionType from typing import Any, TypeAlias @@ -429,3 +430,30 @@ 
class MsgpackDecoder: return cloudpickle.loads(data) raise NotImplementedError(f"Extension type code {code} is not supported") + + +def run_method( + obj: Any, + method: str | bytes | Callable, + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> Any: + """ + Run a method of an object with the given arguments and keyword arguments. + If the method is string, it will be converted to a method using getattr. + If the method is serialized bytes and will be deserialized using + cloudpickle. + If the method is a callable, it will be called directly. + """ + if isinstance(method, bytes): + func = partial(cloudpickle.loads(method), obj) + elif isinstance(method, str): + try: + func = getattr(obj, method) + except AttributeError: + raise NotImplementedError( + f"Method {method!r} is not implemented." + ) from None + else: + func = partial(method, obj) # type: ignore + return func(*args, **kwargs) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 789a74cc6c4a..a401f6d74cdd 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -25,8 +25,8 @@ from torch.autograd.profiler import record_function import vllm.envs as envs from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message -from vllm.utils import kill_process_tree from vllm.utils.network_utils import get_open_port, get_open_zmq_ipc_path, get_tcp_uri +from vllm.utils.system_utils import kill_process_tree if TYPE_CHECKING: import numpy as np diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6759fe630e62..a110ad54a05e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -69,10 +69,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from vllm.utils import ( - check_use_alibi, - length_from_prompt_token_ids_or_embeds, -) +from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.utils.jsontree import json_map_leaves from vllm.utils.math_utils import cdiv, round_up from vllm.utils.mem_constants import GiB_bytes @@ -266,7 +263,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.hidden_size = model_config.get_hidden_size() self.attention_chunk_size = model_config.attention_chunk_size # Only relevant for models using ALiBi (e.g, MPT) - self.use_alibi = check_use_alibi(model_config) + self.use_alibi = model_config.uses_alibi self.cascade_attn_enabled = not self.model_config.disable_cascade_attn diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 3ed9cab42a14..29b6532e4366 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -72,7 +72,7 @@ class Worker(WorkerBase): if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules + from vllm.utils.import_utils import init_cached_hf_modules init_cached_hf_modules() diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index f1885f9b34a1..e867e3c07caa 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -89,7 +89,7 @@ class TPUWorker: if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules + from vllm.utils.import_utils import init_cached_hf_modules init_cached_hf_modules() diff 
--git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index d912589ef73a..9162e2e85a51 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -13,14 +13,11 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import worker_receiver_cache_from_config -from vllm.utils import ( - enable_trace_function_call_for_thread, - run_method, - warn_for_unimplemented_methods, -) +from vllm.utils import warn_for_unimplemented_methods from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.system_utils import update_environment_variables from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.serial_utils import run_method if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -182,19 +179,20 @@ class WorkerWrapperBase: """ self.rpc_rank = rpc_rank self.worker: WorkerBase | None = None - self.vllm_config: VllmConfig | None = None - # do not store this `vllm_config`, `init_worker` will set the final - # one. TODO: investigate if we can remove this field in - # `WorkerWrapperBase`, `init_cached_hf_modules` should be - # unnecessary now. - if vllm_config.model_config is not None: - # it can be None in tests - trust_remote_code = vllm_config.model_config.trust_remote_code - if trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() + # do not store this `vllm_config`, `init_worker` will set the final + # one. + # TODO: investigate if we can remove this field in `WorkerWrapperBase`, + # `init_cached_hf_modules` should be unnecessary now. + self.vllm_config: VllmConfig | None = None + + # `model_config` can be None in tests + model_config = vllm_config.model_config + if model_config and model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils.import_utils import init_cached_hf_modules + + init_cached_hf_modules() def shutdown(self) -> None: if self.worker is not None: @@ -231,7 +229,7 @@ class WorkerWrapperBase: assert self.vllm_config is not None, ( "vllm_config is required to initialize the worker" ) - enable_trace_function_call_for_thread(self.vllm_config) + self.vllm_config.enable_trace_function_call_for_thread() from vllm.plugins import load_general_plugins
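
The net effect of this cleanup on import sites is summarized below. The import paths are taken directly from the hunks above; the snippet itself is only an illustrative sketch (it assumes a vLLM tree that already includes this commit) and is not part of the patch.

# Where the relocated helpers live after this commit (illustrative only).

# Process/system helpers moved from vllm.utils to vllm.utils.system_utils:
from vllm.utils.system_utils import (
    get_mp_context,     # spawn/fork-aware multiprocessing context
    kill_process_tree,  # SIGKILL a process and all of its descendants
    set_ulimit,         # raise the RLIMIT_NOFILE soft limit
)

# Import-time helpers moved from vllm.utils to vllm.utils.import_utils:
from vllm.utils.import_utils import (
    import_pynvml,           # vendored official pynvml module
    init_cached_hf_modules,  # lazy init of HF dynamic modules
)

# The RPC dispatch helper moved from vllm.utils to vllm.v1.serial_utils:
from vllm.v1.serial_utils import run_method

# Two free functions became members of the config objects:
#   check_use_alibi(model_config)
#       -> model_config.uses_alibi
#   enable_trace_function_call_for_thread(vllm_config)
#       -> vllm_config.enable_trace_function_call_for_thread()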
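
`run_method` keeps the dispatch behavior its docstring describes: a string is resolved with getattr, bytes are deserialized with cloudpickle, and a callable is invoked directly with the object prepended. The short sketch below exercises all three forms from the new location; the `_Toy` class and the lambda are hypothetical stand-ins, not code from this patch.

import cloudpickle

from vllm.v1.serial_utils import run_method  # new location as of this commit


class _Toy:
    """Hypothetical object used only to exercise run_method."""

    def add(self, a: int, b: int = 0) -> int:
        return a + b


obj = _Toy()

# String: resolved to a bound method via getattr(obj, "add").
assert run_method(obj, "add", (1,), {"b": 2}) == 3

# Bytes: cloudpickle.loads(...) then called as func(obj, *args, **kwargs).
payload = cloudpickle.dumps(lambda self, a: a * 10)
assert run_method(obj, payload, (4,), {}) == 40

# Callable: called as method(obj, *args, **kwargs).
assert run_method(obj, _Toy.add, (1,), {"b": 5}) == 6

# An unknown string name raises NotImplementedError rather than AttributeError.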