[Misc] Clean up more utils (#27567)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 3b96f85c36
commit 6ebffafbb6
@@ -13,6 +13,8 @@ ruff
# Required for argparse hook only
-f https://download.pytorch.org/whl/cpu
cachetools
cloudpickle
py-cpuinfo
msgspec
pydantic
torch

@@ -39,7 +39,6 @@ ALLOWED_FILES = {
    "vllm/v1/executor/multiproc_executor.py",
    "vllm/v1/executor/ray_executor.py",
    "vllm/entrypoints/llm.py",
    "vllm/utils/__init__.py",
    "tests/utils.py",
    # pickle and cloudpickle
    "vllm/v1/serial_utils.py",

@@ -1618,6 +1618,29 @@ class ModelConfig:
        """Extract the HF encoder/decoder model flag."""
        return is_encoder_decoder(self.hf_config)

    @property
    def uses_alibi(self) -> bool:
        cfg = self.hf_text_config

        return (
            getattr(cfg, "alibi", False)  # Falcon
            or "BloomForCausalLM" in self.architectures  # Bloom
            or getattr(cfg, "position_encoding_type", "") == "alibi"  # codellm_1b_alibi
            or (
                hasattr(cfg, "attn_config")  # MPT
                and (
                    (
                        isinstance(cfg.attn_config, dict)
                        and cfg.attn_config.get("alibi", False)
                    )
                    or (
                        not isinstance(cfg.attn_config, dict)
                        and getattr(cfg.attn_config, "alibi", False)
                    )
                )
            )
        )

    @property
    def uses_mrope(self) -> bool:
        return uses_mrope(self.hf_config)

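Aside (not part of the diff): the new `uses_alibi` property has to handle MPT-style configs where `attn_config` may be either a dict or an object. A minimal standalone sketch of that branch, using `SimpleNamespace` stand-ins for a HF text config:

from types import SimpleNamespace

def detect_alibi_in_attn_config(cfg) -> bool:
    # Mirrors the dict-vs-object branch of ModelConfig.uses_alibi above.
    attn_config = getattr(cfg, "attn_config", None)
    if isinstance(attn_config, dict):
        return attn_config.get("alibi", False)
    return getattr(attn_config, "alibi", False)

dict_style = SimpleNamespace(attn_config={"alibi": True})
obj_style = SimpleNamespace(attn_config=SimpleNamespace(alibi=True))
assert detect_alibi_in_attn_config(dict_style)
assert detect_alibi_in_attn_config(obj_style)
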
@@ -2,12 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
import getpass
import hashlib
import json
import os
import tempfile
import threading
import time
from contextlib import contextmanager
from dataclasses import replace
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar

@@ -17,7 +21,7 @@ from pydantic import ConfigDict, Field
from pydantic.dataclasses import dataclass

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.logger import enable_trace_function_call, init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid

@@ -206,6 +210,28 @@ class VllmConfig:
        # i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size
        return self.compilation_config.bs_to_padded_graph_size[batch_size]

    def enable_trace_function_call_for_thread(self) -> None:
        """
        Set up function tracing for the current thread,
        if enabled via the `VLLM_TRACE_FUNCTION` environment variable.
        """
        if envs.VLLM_TRACE_FUNCTION:
            tmp_dir = tempfile.gettempdir()
            # add username to tmp_dir to avoid permission issues
            tmp_dir = os.path.join(tmp_dir, getpass.getuser())
            filename = (
                f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
                f"_thread_{threading.get_ident()}_at_{datetime.now()}.log"
            ).replace(" ", "_")
            log_path = os.path.join(
                tmp_dir,
                "vllm",
                f"vllm-instance-{self.instance_id}",
                filename,
            )
            os.makedirs(os.path.dirname(log_path), exist_ok=True)
            enable_trace_function_call(log_path)

    @staticmethod
    def _get_quantization_config(
        model_config: ModelConfig, load_config: LoadConfig

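Aside (not part of the diff): a sketch of the log path this relocated method produces when `VLLM_TRACE_FUNCTION` is set; `instance_id` below is a hypothetical stand-in for `VllmConfig.instance_id`:

import getpass
import os
import tempfile
import threading
from datetime import datetime

instance_id = "demo"  # hypothetical stand-in for VllmConfig.instance_id
tmp_dir = os.path.join(tempfile.gettempdir(), getpass.getuser())
filename = (
    f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
    f"_thread_{threading.get_ident()}_at_{datetime.now()}.log"
).replace(" ", "_")
# e.g. /tmp/<user>/vllm/vllm-instance-demo/VLLM_TRACE_FUNCTION_for_process_<pid>_thread_<tid>_at_<ts>.log
print(os.path.join(tmp_dir, "vllm", f"vllm-instance-{instance_id}", filename))
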
@@ -73,7 +73,7 @@ from vllm.config.utils import get_field
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
from vllm.ray.lazy_utils import is_ray_initialized
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.config import (

@@ -82,7 +82,6 @@ from vllm.transformers_utils.config import (
    maybe_override_with_speculators,
)
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import is_in_ray_actor
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_ip

@@ -51,9 +51,9 @@ from vllm.entrypoints.utils import (
    with_cancellation,
)
from vllm.logger import init_logger
from vllm.utils import set_ulimit
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import is_valid_ipv6_address
from vllm.utils.system_utils import set_ulimit
from vllm.version import __version__ as VLLM_VERSION

prometheus_multiproc_dir: tempfile.TemporaryDirectory

@@ -26,8 +26,9 @@ from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid, set_ulimit
from vllm.utils import random_uuid
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.system_utils import set_ulimit
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger("vllm.entrypoints.api_server")

@@ -108,10 +108,10 @@ from vllm.entrypoints.utils import (
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, set_ulimit
from vllm.utils import Device
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import is_valid_ipv6_address
from vllm.utils.system_utils import decorate_logs
from vllm.utils.system_utils import decorate_logs, set_ulimit
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.metrics.prometheus import get_prometheus_registry
from vllm.version import __version__ as VLLM_VERSION

@@ -60,7 +60,7 @@ def cuda_platform_plugin() -> str | None:
    is_cuda = False
    logger.debug("Checking if CUDA platform is available.")
    try:
        from vllm.utils import import_pynvml
        from vllm.utils.import_utils import import_pynvml

        pynvml = import_pynvml()
        pynvml.nvmlInit()

@@ -16,7 +16,7 @@ from typing_extensions import ParamSpec
import vllm._C  # noqa
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import import_pynvml
from vllm.utils.import_utils import import_pynvml
from vllm.utils.torch_utils import cuda_device_count_stateless

from .interface import DeviceCapability, Platform, PlatformEnum

@@ -1,34 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import datetime
import enum
import getpass
import inspect
import multiprocessing
import os
import signal
import sys
import tempfile
import threading
import uuid
import warnings
from collections.abc import Callable
from functools import partial, wraps
from typing import TYPE_CHECKING, Any, TypeVar
from functools import wraps
from typing import Any, TypeVar

import cloudpickle
import psutil
import torch

import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
from vllm.ray.lazy_utils import is_in_ray_actor
from vllm.logger import init_logger

_DEPRECATED_MAPPINGS = {
    "cprofile": "profiling",
    "cprofile_context": "profiling",
    # Used by lm-eval
    "get_open_port": "network_utils",
}

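Aside (not part of the diff): `_DEPRECATED_MAPPINGS` feeds a module-level `__getattr__` (PEP 562) that redirects old `vllm.utils` names to their new submodules. The hunk shows only the mapping and `__dir__`; the body below is a sketch of the general pattern, not a copy of vLLM's exact implementation:

import importlib
import warnings

_DEPRECATED_MAPPINGS = {"get_open_port": "network_utils"}

def __getattr__(name: str):
    if name in _DEPRECATED_MAPPINGS:
        submodule = _DEPRECATED_MAPPINGS[name]
        warnings.warn(
            f"vllm.utils.{name} has moved to vllm.utils.{submodule}.{name}",
            DeprecationWarning,
            stacklevel=2,
        )
        return getattr(importlib.import_module(f"vllm.utils.{submodule}"), name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
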
@@ -53,12 +41,6 @@ def __dir__() -> list[str]:
    return sorted(list(globals().keys()) + list(_DEPRECATED_MAPPINGS.keys()))


if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig
else:
    ModelConfig = object
    VllmConfig = object

logger = init_logger(__name__)

# This value is chosen to have a balance between ITL and TTFT. Note it is

@@ -83,13 +65,7 @@ STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"


# ANSI color codes
CYAN = "\033[1;36m"
RESET = "\033[0;0m"


T = TypeVar("T")
U = TypeVar("U")


class Device(enum.Enum):

@@ -144,195 +120,6 @@ def random_uuid() -> str:
    return str(uuid.uuid4().hex)


# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None:
    """
    Lazy initialization of the Hugging Face modules.
    """
    from transformers.dynamic_module_utils import init_hf_modules

    init_hf_modules()


def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
    """Set up function tracing for the current thread,
    if enabled via the VLLM_TRACE_FUNCTION environment variable
    """

    if envs.VLLM_TRACE_FUNCTION:
        tmp_dir = tempfile.gettempdir()
        # add username to tmp_dir to avoid permission issues
        tmp_dir = os.path.join(tmp_dir, getpass.getuser())
        filename = (
            f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
            f"_thread_{threading.get_ident()}_"
            f"at_{datetime.datetime.now()}.log"
        ).replace(" ", "_")
        log_path = os.path.join(
            tmp_dir, "vllm", f"vllm-instance-{vllm_config.instance_id}", filename
        )
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        enable_trace_function_call(log_path)


def kill_process_tree(pid: int):
    """
    Kills all descendant processes of the given pid by sending SIGKILL.

    Args:
        pid (int): Process ID of the parent process
    """
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return

    # Get all children recursively
    children = parent.children(recursive=True)

    # Send SIGKILL to all children first
    for child in children:
        with contextlib.suppress(ProcessLookupError):
            os.kill(child.pid, signal.SIGKILL)

    # Finally kill the parent
    with contextlib.suppress(ProcessLookupError):
        os.kill(pid, signal.SIGKILL)


# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
def set_ulimit(target_soft_limit=65535):
    if sys.platform.startswith("win"):
        logger.info("Windows detected, skipping ulimit adjustment.")
        return

    import resource

    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except ValueError as e:
            logger.warning(
                "Found ulimit of %s and failed to automatically increase "
                "with error %s. This can cause fd limit errors like "
                "`OSError: [Errno 24] Too many open files`. Consider "
                "increasing with ulimit -n",
                current_soft,
                e,
            )


def _maybe_force_spawn():
    """Check if we need to force the use of the `spawn` multiprocessing start
    method.
    """
    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
        return

    reasons = []
    if is_in_ray_actor():
        # even if we choose to spawn, we need to pass the ray address
        # to the subprocess so that it knows how to connect to the ray cluster.
        # env vars are inherited by subprocesses, even if we use spawn.
        import ray

        os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
        reasons.append("In a Ray actor and can only be spawned")

    from .platform_utils import cuda_is_initialized, xpu_is_initialized

    if cuda_is_initialized():
        reasons.append("CUDA is initialized")
    elif xpu_is_initialized():
        reasons.append("XPU is initialized")

    if reasons:
        logger.warning(
            "We must use the `spawn` multiprocessing start method. "
            "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
            "See https://docs.vllm.ai/en/latest/usage/"
            "troubleshooting.html#python-multiprocessing "
            "for more information. Reasons: %s",
            "; ".join(reasons),
        )
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


def get_mp_context():
    """Get a multiprocessing context with a particular method (spawn or fork).
    By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
    determine the multiprocessing method (default is fork). However, under
    certain conditions, we may enforce spawn and override the value of
    VLLM_WORKER_MULTIPROC_METHOD.
    """
    _maybe_force_spawn()
    mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
    return multiprocessing.get_context(mp_method)


def run_method(
    obj: Any,
    method: str | bytes | Callable,
    args: tuple[Any],
    kwargs: dict[str, Any],
) -> Any:
    """
    Run a method of an object with the given arguments and keyword arguments.
    If the method is a string, it will be converted to a method using getattr.
    If the method is serialized bytes, it will be deserialized using
    cloudpickle.
    If the method is a callable, it will be called directly.
    """
    if isinstance(method, bytes):
        func = partial(cloudpickle.loads(method), obj)
    elif isinstance(method, str):
        try:
            func = getattr(obj, method)
        except AttributeError:
            raise NotImplementedError(
                f"Method {method!r} is not implemented."
            ) from None
    else:
        func = partial(method, obj)  # type: ignore
    return func(*args, **kwargs)


def import_pynvml():
    """
    Historical comments:

    libnvml.so is the library behind nvidia-smi, and
    pynvml is a Python wrapper around it. We use it to get GPU
    status without initializing CUDA context in the current process.
    Historically, there are two packages that provide pynvml:
    - `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official
      wrapper. It is a dependency of vLLM, and is installed when users
      install vLLM. It provides a Python module named `pynvml`.
    - `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper.
      Prior to version 12.0, it also provides a Python module `pynvml`,
      and therefore conflicts with the official one. What's worse,
      the module is a Python package, and has higher priority than
      the official one which is a standalone Python file.
      This causes errors when both of them are installed.
      Starting from version 12.0, it migrates to a new module
      named `pynvml_utils` to avoid the conflict.
    It is so confusing that many packages in the community use the
    unofficial one by mistake, and we have to handle this case.
    For example, `nvcr.io/nvidia/pytorch:24.12-py3` uses the unofficial
    one, and it will cause errors, see the issue
    https://github.com/vllm-project/vllm/issues/12847 for example.
    After all the troubles, we decide to copy the official `pynvml`
    module to our codebase, and use it directly.
    """
    import vllm.third_party.pynvml as pynvml

    return pynvml


def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:
    """
    A replacement for `abc.ABC`.

@@ -376,31 +163,6 @@ def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:
    return cls


# Only relevant for models using ALiBi (e.g., MPT)
def check_use_alibi(model_config: ModelConfig) -> bool:
    cfg = model_config.hf_text_config
    return (
        getattr(cfg, "alibi", False)  # Falcon
        or (
            "BloomForCausalLM" in getattr(model_config.hf_config, "architectures", [])
        )  # Bloom
        or getattr(cfg, "position_encoding_type", "") == "alibi"  # codellm_1b_alibi
        or (
            hasattr(cfg, "attn_config")  # MPT
            and (
                (
                    isinstance(cfg.attn_config, dict)
                    and cfg.attn_config.get("alibi", False)
                )
                or (
                    not isinstance(cfg.attn_config, dict)
                    and getattr(cfg.attn_config, "alibi", False)
                )
            )
        )
    )


def length_from_prompt_token_ids_or_embeds(
    prompt_token_ids: list[int] | None,
    prompt_embeds: torch.Tensor | None,

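Aside (not part of the diff): with `check_use_alibi` deleted here, call sites migrate to the `ModelConfig.uses_alibi` property added earlier in this commit:

# before (free function, now removed):
#   use_alibi = check_use_alibi(model_config)
# after (property on ModelConfig):
use_alibi = model_config.uses_alibi
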
@@ -10,37 +10,21 @@ from argparse import (
    ArgumentDefaultsHelpFormatter,
    ArgumentParser,
    ArgumentTypeError,
    Namespace,
    RawDescriptionHelpFormatter,
    _ArgumentGroup,
)
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from typing import Any

import regex as re
import yaml

from vllm.logger import init_logger

if TYPE_CHECKING:
    from argparse import Namespace
else:
    Namespace = object

logger = init_logger(__name__)


class StoreBoolean(Action):
    def __call__(self, parser, namespace, values, option_string=None):
        if values.lower() == "true":
            setattr(namespace, self.dest, True)
        elif values.lower() == "false":
            setattr(namespace, self.dest, False)
        else:
            raise ValueError(
                f"Invalid boolean value: {values}. Expected 'true' or 'false'."
            )


class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
    """SortedHelpFormatter that sorts arguments by their option strings."""

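Aside (not part of the diff): the `StoreBoolean` action being removed here parsed explicit "true"/"false" values. A minimal usage sketch, reusing the class shown above:

from argparse import ArgumentParser

parser = ArgumentParser()
# StoreBoolean consumes one value and maps "true"/"false" to a bool
parser.add_argument("--enable-foo", action=StoreBoolean, default=False)
args = parser.parse_args(["--enable-foo", "true"])
assert args.enable_foo is True
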
@@ -487,12 +471,8 @@ class FlexibleArgumentParser(ArgumentParser):
            )
            raise ex

        store_boolean_arguments = [
            action.dest for action in self._actions if isinstance(action, StoreBoolean)
        ]

        for key, value in config.items():
            if isinstance(value, bool) and key not in store_boolean_arguments:
            if isinstance(value, bool):
                if value:
                    processed_args.append("--" + key)
            elif isinstance(value, list):

@@ -19,6 +19,49 @@ import regex as re
from typing_extensions import Never


# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None:
    """
    Lazy initialization of the Hugging Face modules.
    """
    from transformers.dynamic_module_utils import init_hf_modules

    init_hf_modules()


def import_pynvml():
    """
    Historical comments:

    libnvml.so is the library behind nvidia-smi, and
    pynvml is a Python wrapper around it. We use it to get GPU
    status without initializing CUDA context in the current process.
    Historically, there are two packages that provide pynvml:
    - `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official
      wrapper. It is a dependency of vLLM, and is installed when users
      install vLLM. It provides a Python module named `pynvml`.
    - `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper.
      Prior to version 12.0, it also provides a Python module `pynvml`,
      and therefore conflicts with the official one. What's worse,
      the module is a Python package, and has higher priority than
      the official one which is a standalone Python file.
      This causes errors when both of them are installed.
      Starting from version 12.0, it migrates to a new module
      named `pynvml_utils` to avoid the conflict.
    It is so confusing that many packages in the community use the
    unofficial one by mistake, and we have to handle this case.
    For example, `nvcr.io/nvidia/pytorch:24.12-py3` uses the unofficial
    one, and it will cause errors, see the issue
    https://github.com/vllm-project/vllm/issues/12847 for example.
    After all the troubles, we decide to copy the official `pynvml`
    module to our codebase, and use it directly.
    """
    import vllm.third_party.pynvml as pynvml

    return pynvml


def import_from_path(module_name: str, file_path: str | os.PathLike):
    """
    Import a Python file according to its file path.

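Aside (not part of the diff): a minimal usage sketch of the vendored wrapper now exposed from `vllm.utils.import_utils`; it queries GPU state via NVML without initializing a CUDA context in the current process:

from vllm.utils.import_utils import import_pynvml

pynvml = import_pynvml()
pynvml.nvmlInit()
try:
    print("GPUs visible to NVML:", pynvml.nvmlDeviceGetCount())
finally:
    pynvml.nvmlShutdown()
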
@@ -4,19 +4,21 @@
from __future__ import annotations

import contextlib
import multiprocessing
import os
import signal
import sys
from collections.abc import Callable, Iterator
from pathlib import Path
from typing import TextIO

try:
    import setproctitle
except ImportError:
    setproctitle = None  # type: ignore[assignment]
import psutil

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.ray.lazy_utils import is_in_ray_actor

from .platform_utils import cuda_is_initialized, xpu_is_initialized

logger = init_logger(__name__)

@@ -75,14 +77,66 @@ def unique_filepath(fn: Callable[[int], Path]) -> Path:
# Process management utilities


def _maybe_force_spawn():
    """Check if we need to force the use of the `spawn` multiprocessing start
    method.
    """
    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
        return

    reasons = []
    if is_in_ray_actor():
        # even if we choose to spawn, we need to pass the ray address
        # to the subprocess so that it knows how to connect to the ray cluster.
        # env vars are inherited by subprocesses, even if we use spawn.
        import ray

        os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
        reasons.append("In a Ray actor and can only be spawned")

    if cuda_is_initialized():
        reasons.append("CUDA is initialized")
    elif xpu_is_initialized():
        reasons.append("XPU is initialized")

    if reasons:
        logger.warning(
            "We must use the `spawn` multiprocessing start method. "
            "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
            "See https://docs.vllm.ai/en/latest/usage/"
            "troubleshooting.html#python-multiprocessing "
            "for more information. Reasons: %s",
            "; ".join(reasons),
        )
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


def get_mp_context():
    """Get a multiprocessing context with a particular method (spawn or fork).
    By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
    determine the multiprocessing method (default is fork). However, under
    certain conditions, we may enforce spawn and override the value of
    VLLM_WORKER_MULTIPROC_METHOD.
    """
    _maybe_force_spawn()
    mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
    return multiprocessing.get_context(mp_method)


def set_process_title(
    name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX
    name: str,
    suffix: str = "",
    prefix: str = envs.VLLM_PROCESS_NAME_PREFIX,
) -> None:
    """Set the current process title with optional suffix."""
    if setproctitle is None:
        try:
            import setproctitle
        except ImportError:
            return

    if suffix:
        name = f"{name}_{suffix}"

    setproctitle.setproctitle(f"{prefix}::{name}")

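Aside (not part of the diff): a usage sketch of the relocated process helpers:

from vllm.utils.system_utils import get_mp_context, set_process_title

ctx = get_mp_context()  # honors VLLM_WORKER_MULTIPROC_METHOD; may force "spawn"
proc = ctx.Process(target=print, args=("hello from a worker process",))
proc.start()
proc.join()

set_process_title("demo_worker", suffix="rank0")
# process title becomes "<VLLM_PROCESS_NAME_PREFIX>::demo_worker_rank0"
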
@@ -114,10 +168,62 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:

def decorate_logs(process_name: str | None = None) -> None:
    """Decorate stdout/stderr with process name and PID prefix."""
    from vllm.utils import get_mp_context

    if process_name is None:
        process_name = get_mp_context().current_process().name

    pid = os.getpid()
    _add_prefix(sys.stdout, process_name, pid)
    _add_prefix(sys.stderr, process_name, pid)


def kill_process_tree(pid: int):
    """
    Kills all descendant processes of the given pid by sending SIGKILL.

    Args:
        pid (int): Process ID of the parent process
    """
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return

    # Get all children recursively
    children = parent.children(recursive=True)

    # Send SIGKILL to all children first
    for child in children:
        with contextlib.suppress(ProcessLookupError):
            os.kill(child.pid, signal.SIGKILL)

    # Finally kill the parent
    with contextlib.suppress(ProcessLookupError):
        os.kill(pid, signal.SIGKILL)


# Resource utilities


# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630
def set_ulimit(target_soft_limit: int = 65535):
    if sys.platform.startswith("win"):
        logger.info("Windows detected, skipping ulimit adjustment.")
        return

    import resource

    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except ValueError as e:
            logger.warning(
                "Found ulimit of %s and failed to automatically increase "
                "with error %s. This can cause fd limit errors like "
                "`OSError: [Errno 24] Too many open files`. Consider "
                "increasing with ulimit -n",
                current_soft,
                e,
            )

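Aside (not part of the diff): a usage sketch for the relocated process/resource helpers:

import os

from vllm.utils.system_utils import kill_process_tree, set_ulimit

set_ulimit(65535)  # raises the soft RLIMIT_NOFILE if below target; no-op on Windows
# kill_process_tree(child_pid)  # would SIGKILL child_pid and all of its descendants
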
@@ -10,9 +10,8 @@ import zmq

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_mp_context
from vllm.utils.network_utils import make_zmq_socket
from vllm.utils.system_utils import set_process_title
from vllm.utils.system_utils import get_mp_context, set_process_title
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
from vllm.v1.serial_utils import MsgpackDecoder
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown

@@ -20,8 +20,8 @@ from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.utils import get_mp_context
from vllm.utils.network_utils import get_open_zmq_ipc_path, zmq_socket_ctx
from vllm.utils.system_utils import get_mp_context
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.executor import Executor
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown

@@ -35,13 +35,17 @@ from vllm.distributed.parallel_state import (
)
from vllm.envs import enable_envs_cache
from vllm.logger import init_logger
from vllm.utils import _maybe_force_spawn, get_mp_context
from vllm.utils.network_utils import (
    get_distributed_init_method,
    get_loopback_ip,
    get_open_port,
)
from vllm.utils.system_utils import decorate_logs, set_process_title
from vllm.utils.system_utils import (
    _maybe_force_spawn,
    decorate_logs,
    get_mp_context,
    set_process_title,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput

@@ -12,11 +12,11 @@ import torch.distributed as dist

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import run_method
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.serial_utils import run_method
from vllm.v1.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)

@@ -5,6 +5,7 @@ import dataclasses
import importlib
import pickle
from collections.abc import Callable, Sequence
from functools import partial
from inspect import isclass
from types import FunctionType
from typing import Any, TypeAlias

@@ -429,3 +430,30 @@ class MsgpackDecoder:
            return cloudpickle.loads(data)

        raise NotImplementedError(f"Extension type code {code} is not supported")


def run_method(
    obj: Any,
    method: str | bytes | Callable,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Any:
    """
    Run a method of an object with the given arguments and keyword arguments.
    If the method is a string, it will be converted to a method using getattr.
    If the method is serialized bytes, it will be deserialized using
    cloudpickle.
    If the method is a callable, it will be called directly.
    """
    if isinstance(method, bytes):
        func = partial(cloudpickle.loads(method), obj)
    elif isinstance(method, str):
        try:
            func = getattr(obj, method)
        except AttributeError:
            raise NotImplementedError(
                f"Method {method!r} is not implemented."
            ) from None
    else:
        func = partial(method, obj)  # type: ignore
    return func(*args, **kwargs)

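Aside (not part of the diff): `run_method` now lives next to the serializers it depends on. A toy sketch of its three dispatch modes:

import cloudpickle

from vllm.v1.serial_utils import run_method

class Toy:
    def greet(self, who: str) -> str:
        return f"hello {who}"

toy = Toy()
assert run_method(toy, "greet", ("world",), {}) == "hello world"    # by name
assert run_method(toy, Toy.greet, ("world",), {}) == "hello world"  # callable
blob = cloudpickle.dumps(lambda self, who: f"hi {who}")
assert run_method(toy, blob, ("world",), {}) == "hi world"          # cloudpickled
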
@@ -25,8 +25,8 @@ from torch.autograd.profiler import record_function
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message
from vllm.utils import kill_process_tree
from vllm.utils.network_utils import get_open_port, get_open_zmq_ipc_path, get_tcp_uri
from vllm.utils.system_utils import kill_process_tree

if TYPE_CHECKING:
    import numpy as np

@@ -69,10 +69,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (
    check_use_alibi,
    length_from_prompt_token_ids_or_embeds,
)
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.mem_constants import GiB_bytes

@@ -266,7 +263,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self.hidden_size = model_config.get_hidden_size()
        self.attention_chunk_size = model_config.attention_chunk_size
        # Only relevant for models using ALiBi (e.g., MPT)
        self.use_alibi = check_use_alibi(model_config)
        self.use_alibi = model_config.uses_alibi

        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

@@ -72,7 +72,7 @@ class Worker(WorkerBase):

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            from vllm.utils.import_utils import init_cached_hf_modules

            init_cached_hf_modules()

@@ -89,7 +89,7 @@ class TPUWorker:

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            from vllm.utils.import_utils import init_cached_hf_modules

            init_cached_hf_modules()

@@ -13,14 +13,11 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import (
    enable_trace_function_call_for_thread,
    run_method,
    warn_for_unimplemented_methods,
)
from vllm.utils import warn_for_unimplemented_methods
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.system_utils import update_environment_variables
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.serial_utils import run_method

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

@@ -182,19 +179,20 @@ class WorkerWrapperBase:
        """
        self.rpc_rank = rpc_rank
        self.worker: WorkerBase | None = None
        self.vllm_config: VllmConfig | None = None
        # do not store this `vllm_config`, `init_worker` will set the final
        # one. TODO: investigate if we can remove this field in
        # `WorkerWrapperBase`, `init_cached_hf_modules` should be
        # unnecessary now.
        if vllm_config.model_config is not None:
            # it can be None in tests
            trust_remote_code = vllm_config.model_config.trust_remote_code
            if trust_remote_code:
                # note: lazy import to avoid importing torch before initializing
                from vllm.utils import init_cached_hf_modules

                init_cached_hf_modules()
        # do not store this `vllm_config`, `init_worker` will set the final
        # one.
        # TODO: investigate if we can remove this field in `WorkerWrapperBase`,
        # `init_cached_hf_modules` should be unnecessary now.
        self.vllm_config: VllmConfig | None = None

        # `model_config` can be None in tests
        model_config = vllm_config.model_config
        if model_config and model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils.import_utils import init_cached_hf_modules

            init_cached_hf_modules()

    def shutdown(self) -> None:
        if self.worker is not None:

@@ -231,7 +229,7 @@ class WorkerWrapperBase:
        assert self.vllm_config is not None, (
            "vllm_config is required to initialize the worker"
        )
        enable_trace_function_call_for_thread(self.vllm_config)
        self.vllm_config.enable_trace_function_call_for_thread()

        from vllm.plugins import load_general_plugins