mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 05:15:42 +08:00
[Misc] centralize all usage of environment variables (#4548)
This commit is contained in:
parent
1ff0c73a79
commit
5b8a7c1cb0
@ -1,10 +1,10 @@
|
|||||||
"""Attention layer ROCm GPUs."""
|
"""Attention layer ROCm GPUs."""
|
||||||
import os
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple, Type
|
from typing import Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
AttentionMetadataPerStage)
|
AttentionMetadataPerStage)
|
||||||
@ -156,8 +156,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
|
|
||||||
self.use_naive_attn = False
|
self.use_naive_attn = False
|
||||||
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
|
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
|
||||||
self.use_triton_flash_attn = (os.environ.get(
|
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
|
||||||
"VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1"))
|
|
||||||
if self.use_triton_flash_attn:
|
if self.use_triton_flash_attn:
|
||||||
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
|
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
|
||||||
triton_attention)
|
triton_attention)
|
||||||
|
|||||||
@ -1,18 +1,16 @@
|
|||||||
import enum
|
import enum
|
||||||
import os
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import is_cpu, is_hip
|
from vllm.utils import is_cpu, is_hip
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
|
|
||||||
|
|
||||||
|
|
||||||
class _Backend(enum.Enum):
|
class _Backend(enum.Enum):
|
||||||
FLASH_ATTN = enum.auto()
|
FLASH_ATTN = enum.auto()
|
||||||
@ -79,7 +77,7 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
|
|||||||
"package is not found. Please install it for better performance.")
|
"package is not found. Please install it for better performance.")
|
||||||
return _Backend.XFORMERS
|
return _Backend.XFORMERS
|
||||||
|
|
||||||
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
|
backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
|
||||||
if backend_by_env_var is not None:
|
if backend_by_env_var is not None:
|
||||||
return _Backend[backend_by_env_var]
|
return _Backend[backend_by_env_var]
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import enum
|
import enum
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field, fields
|
||||||
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
|
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
|
||||||
|
|
||||||
@ -24,10 +23,6 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
# If true, will load models from ModelScope instead of Hugging Face Hub.
|
|
||||||
VLLM_USE_MODELSCOPE = os.environ.get("VLLM_USE_MODELSCOPE",
|
|
||||||
"False").lower() == "true"
|
|
||||||
|
|
||||||
_GB = 1 << 30
|
_GB = 1 << 30
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
import os
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -54,9 +54,9 @@ def init_custom_ar() -> None:
|
|||||||
return
|
return
|
||||||
# test nvlink first, this will filter out most of the cases
|
# test nvlink first, this will filter out most of the cases
|
||||||
# where custom allreduce is not supported
|
# where custom allreduce is not supported
|
||||||
if "CUDA_VISIBLE_DEVICES" in os.environ:
|
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||||
device_ids = list(
|
if cuda_visible_devices:
|
||||||
map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
|
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
||||||
else:
|
else:
|
||||||
device_ids = list(range(num_dev))
|
device_ids = list(range(num_dev))
|
||||||
# this checks hardware and driver support for NVLink
|
# this checks hardware and driver support for NVLink
|
||||||
|
|||||||
@ -4,11 +4,11 @@
|
|||||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
"""Tensor and pipeline parallel groups."""
|
"""Tensor and pipeline parallel groups."""
|
||||||
import contextlib
|
import contextlib
|
||||||
import os
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -80,7 +80,7 @@ def init_distributed_environment(
|
|||||||
# local_rank is not available in torch ProcessGroup,
|
# local_rank is not available in torch ProcessGroup,
|
||||||
# see https://github.com/pytorch/pytorch/issues/122816
|
# see https://github.com/pytorch/pytorch/issues/122816
|
||||||
if local_rank == -1 and distributed_init_method == "env://":
|
if local_rank == -1 and distributed_init_method == "env://":
|
||||||
local_rank = int(os.environ['LOCAL_RANK'])
|
local_rank = envs.LOCAL_RANK
|
||||||
global _LOCAL_RANK
|
global _LOCAL_RANK
|
||||||
_LOCAL_RANK = local_rank
|
_LOCAL_RANK = local_rank
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from typing import Dict, Optional, Sequence
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
from .parallel_state import get_cpu_world_group, get_local_rank
|
from .parallel_state import get_cpu_world_group, get_local_rank
|
||||||
@ -102,11 +103,13 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
|
|||||||
is_distributed = dist.is_initialized()
|
is_distributed = dist.is_initialized()
|
||||||
|
|
||||||
num_dev = torch.cuda.device_count()
|
num_dev = torch.cuda.device_count()
|
||||||
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||||
if cuda_visible_devices is None:
|
if cuda_visible_devices is None:
|
||||||
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
|
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
|
||||||
|
VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
|
||||||
path = os.path.expanduser(
|
path = os.path.expanduser(
|
||||||
f"~/.config/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
|
f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
|
||||||
|
)
|
||||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
if (not is_distributed or get_local_rank() == 0) \
|
if (not is_distributed or get_local_rank() == 0) \
|
||||||
and (not os.path.exists(path)):
|
and (not os.path.exists(path)):
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
|
from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
|
||||||
@ -7,6 +6,7 @@ from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
|
|||||||
|
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.config import DecodingConfig, ModelConfig
|
from vllm.config import DecodingConfig, ModelConfig
|
||||||
from vllm.core.scheduler import SchedulerOutputs
|
from vllm.core.scheduler import SchedulerOutputs
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
@ -20,8 +20,7 @@ from vllm.sequence import MultiModalData, SamplerOutput
|
|||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
ENGINE_ITERATION_TIMEOUT_S = int(
|
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
|
||||||
os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60"))
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncEngineDeadError(RuntimeError):
|
class AsyncEngineDeadError(RuntimeError):
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
@ -16,6 +15,7 @@ from prometheus_client import make_asgi_app
|
|||||||
from starlette.routing import Mount
|
from starlette.routing import Mount
|
||||||
|
|
||||||
import vllm
|
import vllm
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||||
@ -129,7 +129,7 @@ if __name__ == "__main__":
|
|||||||
allow_headers=args.allowed_headers,
|
allow_headers=args.allowed_headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
if token := os.environ.get("VLLM_API_KEY") or args.api_key:
|
if token := envs.VLLM_API_KEY or args.api_key:
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def authentication(request: Request, call_next):
|
async def authentication(request: Request, call_next):
|
||||||
|
|||||||
160
vllm/envs.py
Normal file
160
vllm/envs.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
import os
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
    # Static-only declarations: this branch never runs at runtime. It exists
    # so that type checkers and IDEs know the name and type of every
    # attribute that the module-level ``__getattr__`` resolves lazily from
    # ``environment_variables``. Every key registered there needs a matching
    # declaration here; the values mirror the runtime defaults.
    VLLM_HOST_IP: str = ""
    VLLM_USE_MODELSCOPE: bool = False
    VLLM_INSTANCE_ID: Optional[str] = None
    # CUDA_HOME is registered in ``environment_variables`` and therefore
    # must be declared here as well, otherwise ``envs.CUDA_HOME`` fails
    # static type checking.
    CUDA_HOME: Optional[str] = None
    VLLM_NCCL_SO_PATH: Optional[str] = None
    LD_LIBRARY_PATH: Optional[str] = None
    # The runtime getter defaults to "True" when the variable is unset.
    VLLM_USE_TRITON_FLASH_ATTN: bool = True
    LOCAL_RANK: int = 0
    CUDA_VISIBLE_DEVICES: Optional[str] = None
    VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
    VLLM_API_KEY: Optional[str] = None
    S3_ACCESS_KEY_ID: Optional[str] = None
    S3_SECRET_ACCESS_KEY: Optional[str] = None
    S3_ENDPOINT_URL: Optional[str] = None
    VLLM_CONFIG_ROOT: str = ""
    VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai"
    VLLM_NO_USAGE_STATS: bool = False
    VLLM_DO_NOT_TRACK: bool = False
    # The runtime getter defaults to "production" when the variable is unset.
    VLLM_USAGE_SOURCE: str = "production"
    VLLM_CONFIGURE_LOGGING: int = 1
    VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
    VLLM_TRACE_FUNCTION: int = 0
    VLLM_ATTENTION_BACKEND: Optional[str] = None
    VLLM_CPU_KVCACHE_SPACE: int = 0
    VLLM_USE_RAY_COMPILED_DAG: bool = False
    VLLM_WORKER_MULTIPROC_METHOD: str = "spawn"
|
||||||
|
|
||||||
|
# Registry of every environment variable vLLM reads. Each key is the attribute
# name exposed on this module and each value is a zero-argument callable that
# reads (and parses) the variable from ``os.environ`` at the time it is called,
# so changes to the environment are always observed on the next access.
environment_variables: Dict[str, Callable[[], Any]] = {
    # used in distributed environment to determine the master address
    'VLLM_HOST_IP':
    lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""),

    # If true, will load models from ModelScope instead of Hugging Face Hub.
    # note that the value is true or false, not numbers
    "VLLM_USE_MODELSCOPE":
    lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true",

    # Instance id represents an instance of the VLLM. All processes in the same
    # instance should have the same instance id.
    "VLLM_INSTANCE_ID":
    lambda: os.environ.get("VLLM_INSTANCE_ID", None),

    # path to cudatoolkit home directory, under which should be bin, include,
    # and lib directories.
    "CUDA_HOME":
    lambda: os.environ.get("CUDA_HOME", None),

    # Path to the NCCL library file. It is needed because nccl>=2.19 brought
    # by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234
    "VLLM_NCCL_SO_PATH":
    lambda: os.environ.get("VLLM_NCCL_SO_PATH", None),

    # when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl
    # library file in the locations specified by `LD_LIBRARY_PATH`
    "LD_LIBRARY_PATH":
    lambda: os.environ.get("LD_LIBRARY_PATH", None),

    # flag to control if vllm should use triton flash attention
    "VLLM_USE_TRITON_FLASH_ATTN":
    lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
             ("true", "1")),

    # local rank of the process in the distributed setting, used to determine
    # the GPU device id
    "LOCAL_RANK":
    lambda: int(os.environ.get("LOCAL_RANK", "0")),

    # used to control the visible devices in the distributed setting
    "CUDA_VISIBLE_DEVICES":
    lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None),

    # timeout for each iteration in the engine
    "VLLM_ENGINE_ITERATION_TIMEOUT_S":
    lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")),

    # API key for VLLM API server
    "VLLM_API_KEY":
    lambda: os.environ.get("VLLM_API_KEY", None),

    # S3 access information, used for tensorizer to load model from S3
    # NOTE: the variable read must be S3_ACCESS_KEY_ID (matching the key),
    # not S3_ACCESS_KEY.
    "S3_ACCESS_KEY_ID":
    lambda: os.environ.get("S3_ACCESS_KEY_ID", None),
    "S3_SECRET_ACCESS_KEY":
    lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None),
    "S3_ENDPOINT_URL":
    lambda: os.environ.get("S3_ENDPOINT_URL", None),

    # Root directory for VLLM configuration files
    # Note that this not only affects how vllm finds its configuration files
    # during runtime, but also affects how vllm installs its configuration
    # files during **installation**.
    # Resolution order: VLLM_CONFIG_ROOT, then XDG_CONFIG_HOME, then ~/.config
    "VLLM_CONFIG_ROOT":
    lambda: os.environ.get("VLLM_CONFIG_ROOT", None) or os.getenv(
        "XDG_CONFIG_HOME", None) or os.path.expanduser("~/.config"),

    # Usage stats collection
    "VLLM_USAGE_STATS_SERVER":
    lambda: os.environ.get("VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai"),
    "VLLM_NO_USAGE_STATS":
    lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1",
    # VLLM_DO_NOT_TRACK takes precedence; the generic DO_NOT_TRACK is used
    # as a fallback when it is unset or empty.
    "VLLM_DO_NOT_TRACK":
    lambda: (os.environ.get("VLLM_DO_NOT_TRACK", None) or os.environ.get(
        "DO_NOT_TRACK", None) or "0") == "1",
    "VLLM_USAGE_SOURCE":
    lambda: os.environ.get("VLLM_USAGE_SOURCE", "production"),

    # Logging configuration
    # If set to 0, vllm will not configure logging
    # If set to 1, vllm will configure logging using the default configuration
    #    or the configuration file specified by VLLM_LOGGING_CONFIG_PATH
    "VLLM_CONFIGURE_LOGGING":
    lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")),
    "VLLM_LOGGING_CONFIG_PATH":
    lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"),

    # Trace function calls
    # If set to 1, vllm will trace function calls
    # Useful for debugging
    "VLLM_TRACE_FUNCTION":
    lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")),

    # Backend for attention computation
    # Available options:
    # - "TORCH_SDPA": use torch.nn.MultiheadAttention
    # - "FLASH_ATTN": use FlashAttention
    # - "XFORMERS": use XFormers
    # - "ROCM_FLASH": use ROCmFlashAttention
    "VLLM_ATTENTION_BACKEND":
    lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),

    # CPU key-value cache space
    # default is 4GB
    # NOTE(review): the raw default returned here is 0; presumably the CPU
    # executor maps 0 to the 4GB default -- confirm against the caller.
    "VLLM_CPU_KVCACHE_SPACE":
    lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")),

    # If the env var is set, it uses the Ray's compiled DAG API
    # which optimizes the control plane overhead.
    # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
    # An unset/empty value, "0" or "false" disables it; any other value
    # enables it. (A plain bool() on the string would treat "0" as enabled.)
    "VLLM_USE_RAY_COMPILED_DAG":
    lambda: (os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0").strip().lower()
             not in ("", "0", "false")),

    # Use dedicated multiprocess context for workers.
    # Both spawn and fork work
    "VLLM_WORKER_MULTIPROC_METHOD":
    lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"),
}
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name):
    """Resolve ``envs.<name>`` by evaluating its registered getter.

    Called by Python only when normal module attribute lookup fails. The
    getter is invoked on every access, so the value reflects the current
    state of the environment rather than a snapshot taken at import time.

    Raises:
        AttributeError: if ``name`` is not a key of ``environment_variables``.
    """
    getter = environment_variables.get(name)
    if getter is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    # lazy evaluation of environment variables
    return getter()
|
||||||
|
|
||||||
|
|
||||||
|
def __dir__():
    """Return the names of all registered environment variables.

    Used by ``dir()`` on this module, so the lazily resolved attributes are
    discoverable even though they are not real module globals.
    """
    return [var_name for var_name in environment_variables]
|
||||||
@ -1,8 +1,8 @@
|
|||||||
import os
|
|
||||||
from typing import Dict, List, Set, Tuple
|
from typing import Dict, List, Set, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
|
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
|
||||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -152,8 +152,7 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
|
|||||||
logger.warning("Prefix caching is not supported on CPU, disable it.")
|
logger.warning("Prefix caching is not supported on CPU, disable it.")
|
||||||
config.enable_prefix_caching = False
|
config.enable_prefix_caching = False
|
||||||
|
|
||||||
kv_cache_space_str = os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")
|
kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
|
||||||
kv_cache_space = int(kv_cache_space_str)
|
|
||||||
|
|
||||||
if kv_cache_space >= 0:
|
if kv_cache_space >= 0:
|
||||||
if kv_cache_space == 0:
|
if kv_cache_space == 0:
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from multiprocessing.process import BaseProcess
|
|||||||
from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
|
from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
|
||||||
TypeVar, Union)
|
TypeVar, Union)
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -26,9 +27,7 @@ RESET = '\033[0;0m'
|
|||||||
|
|
||||||
JOIN_TIMEOUT_S = 2
|
JOIN_TIMEOUT_S = 2
|
||||||
|
|
||||||
# Use dedicated multiprocess context for workers.
|
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
|
||||||
# Both spawn and fork work
|
|
||||||
mp_method = os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
|
||||||
mp = multiprocessing.get_context(mp_method)
|
mp = multiprocessing.get_context(mp_method)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from collections import defaultdict
|
|||||||
from itertools import islice, repeat
|
from itertools import islice, repeat
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
|
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
|
||||||
DistributedGPUExecutor, DistributedGPUExecutorAsync)
|
DistributedGPUExecutor, DistributedGPUExecutorAsync)
|
||||||
from vllm.executor.ray_utils import RayWorkerWrapper, ray
|
from vllm.executor.ray_utils import RayWorkerWrapper, ray
|
||||||
@ -21,10 +22,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
# If the env var is set, it uses the Ray's compiled DAG API
|
USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
|
||||||
# which optimizes the control plane overhead.
|
|
||||||
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
|
|
||||||
USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
|
|
||||||
|
|
||||||
|
|
||||||
class RayGPUExecutor(DistributedGPUExecutor):
|
class RayGPUExecutor(DistributedGPUExecutor):
|
||||||
@ -145,7 +143,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
|||||||
"VLLM_INSTANCE_ID":
|
"VLLM_INSTANCE_ID":
|
||||||
VLLM_INSTANCE_ID,
|
VLLM_INSTANCE_ID,
|
||||||
"VLLM_TRACE_FUNCTION":
|
"VLLM_TRACE_FUNCTION":
|
||||||
os.getenv("VLLM_TRACE_FUNCTION", "0"),
|
str(envs.VLLM_TRACE_FUNCTION),
|
||||||
}, ) for (node_id, _) in worker_node_and_gpu_ids]
|
}, ) for (node_id, _) in worker_node_and_gpu_ids]
|
||||||
self._run_workers("update_environment_variables",
|
self._run_workers("update_environment_variables",
|
||||||
all_args=all_args_to_update_environment_variables)
|
all_args=all_args_to_update_environment_variables)
|
||||||
|
|||||||
@ -10,8 +10,10 @@ from logging.config import dictConfig
|
|||||||
from os import path
|
from os import path
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
|
import vllm.envs as envs
|
||||||
VLLM_LOGGING_CONFIG_PATH = os.getenv("VLLM_LOGGING_CONFIG_PATH")
|
|
||||||
|
VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
|
||||||
|
VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
|
||||||
|
|
||||||
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
|
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
|
||||||
_DATE_FORMAT = "%m-%d %H:%M:%S"
|
_DATE_FORMAT = "%m-%d %H:%M:%S"
|
||||||
|
|||||||
@ -9,9 +9,10 @@ import huggingface_hub
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from vllm.config import (VLLM_USE_MODELSCOPE, DeviceConfig, LoadConfig,
|
from vllm.config import (DeviceConfig, LoadConfig, LoadFormat, LoRAConfig,
|
||||||
LoadFormat, LoRAConfig, ModelConfig, ParallelConfig,
|
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||||
SchedulerConfig, VisionLanguageConfig)
|
VisionLanguageConfig)
|
||||||
|
from vllm.envs import VLLM_USE_MODELSCOPE
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.config import ModelConfig, ParallelConfig
|
from vllm.config import ModelConfig, ParallelConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
@ -142,13 +143,10 @@ class TensorizerArgs:
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.file_obj = self.tensorizer_uri
|
self.file_obj = self.tensorizer_uri
|
||||||
self.s3_access_key_id = (self.s3_access_key_id
|
self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID
|
||||||
or os.environ.get("S3_ACCESS_KEY_ID")) or None
|
self.s3_secret_access_key = (self.s3_secret_access_key
|
||||||
self.s3_secret_access_key = (
|
or envs.S3_SECRET_ACCESS_KEY)
|
||||||
self.s3_secret_access_key
|
self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL
|
||||||
or os.environ.get("S3_SECRET_ACCESS_KEY")) or None
|
|
||||||
self.s3_endpoint = (self.s3_endpoint
|
|
||||||
or os.environ.get("S3_ENDPOINT_URL")) or None
|
|
||||||
self.stream_params = {
|
self.stream_params = {
|
||||||
"s3_access_key_id": self.s3_access_key_id,
|
"s3_access_key_id": self.s3_access_key_id,
|
||||||
"s3_secret_access_key": self.s3_secret_access_key,
|
"s3_secret_access_key": self.s3_secret_access_key,
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import huggingface_hub
|
|||||||
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
||||||
PreTrainedTokenizerFast)
|
PreTrainedTokenizerFast)
|
||||||
|
|
||||||
from vllm.config import VLLM_USE_MODELSCOPE
|
from vllm.envs import VLLM_USE_MODELSCOPE
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
|
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
|
||||||
|
|||||||
@ -15,20 +15,22 @@ import psutil
|
|||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
_config_home = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
|
import vllm.envs as envs
|
||||||
|
|
||||||
|
_config_home = envs.VLLM_CONFIG_ROOT
|
||||||
_USAGE_STATS_JSON_PATH = os.path.join(_config_home, "vllm/usage_stats.json")
|
_USAGE_STATS_JSON_PATH = os.path.join(_config_home, "vllm/usage_stats.json")
|
||||||
_USAGE_STATS_DO_NOT_TRACK_PATH = os.path.join(_config_home,
|
_USAGE_STATS_DO_NOT_TRACK_PATH = os.path.join(_config_home,
|
||||||
"vllm/do_not_track")
|
"vllm/do_not_track")
|
||||||
_USAGE_STATS_ENABLED = None
|
_USAGE_STATS_ENABLED = None
|
||||||
_USAGE_STATS_SERVER = os.environ.get("VLLM_USAGE_STATS_SERVER",
|
_USAGE_STATS_SERVER = envs.VLLM_USAGE_STATS_SERVER
|
||||||
"https://stats.vllm.ai")
|
|
||||||
|
|
||||||
|
|
||||||
def is_usage_stats_enabled():
|
def is_usage_stats_enabled():
|
||||||
"""Determine whether or not we can send usage stats to the server.
|
"""Determine whether or not we can send usage stats to the server.
|
||||||
The logic is as follows:
|
The logic is as follows:
|
||||||
- By default, it should be enabled.
|
- By default, it should be enabled.
|
||||||
- Two environment variables can disable it:
|
- Three environment variables can disable it:
|
||||||
|
- VLLM_DO_NOT_TRACK=1
|
||||||
- DO_NOT_TRACK=1
|
- DO_NOT_TRACK=1
|
||||||
- VLLM_NO_USAGE_STATS=1
|
- VLLM_NO_USAGE_STATS=1
|
||||||
- A file in the home directory can disable it if it exists:
|
- A file in the home directory can disable it if it exists:
|
||||||
@ -36,8 +38,8 @@ def is_usage_stats_enabled():
|
|||||||
"""
|
"""
|
||||||
global _USAGE_STATS_ENABLED
|
global _USAGE_STATS_ENABLED
|
||||||
if _USAGE_STATS_ENABLED is None:
|
if _USAGE_STATS_ENABLED is None:
|
||||||
do_not_track = os.environ.get("DO_NOT_TRACK", "0") == "1"
|
do_not_track = envs.VLLM_DO_NOT_TRACK
|
||||||
no_usage_stats = os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1"
|
no_usage_stats = envs.VLLM_NO_USAGE_STATS
|
||||||
do_not_track_file = os.path.exists(_USAGE_STATS_DO_NOT_TRACK_PATH)
|
do_not_track_file = os.path.exists(_USAGE_STATS_DO_NOT_TRACK_PATH)
|
||||||
|
|
||||||
_USAGE_STATS_ENABLED = not (do_not_track or no_usage_stats
|
_USAGE_STATS_ENABLED = not (do_not_track or no_usage_stats
|
||||||
@ -167,7 +169,7 @@ class UsageMessage:
|
|||||||
|
|
||||||
# Metadata
|
# Metadata
|
||||||
self.log_time = _get_current_timestamp_ns()
|
self.log_time = _get_current_timestamp_ns()
|
||||||
self.source = os.environ.get("VLLM_USAGE_SOURCE", "production")
|
self.source = envs.VLLM_USAGE_SOURCE
|
||||||
|
|
||||||
data = vars(self)
|
data = vars(self)
|
||||||
if extra_kvs:
|
if extra_kvs:
|
||||||
|
|||||||
@ -21,6 +21,7 @@ import psutil
|
|||||||
import torch
|
import torch
|
||||||
from packaging.version import Version, parse
|
from packaging.version import Version, parse
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.logger import enable_trace_function_call, init_logger
|
from vllm.logger import enable_trace_function_call, init_logger
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
@ -174,7 +175,7 @@ def get_vllm_instance_id():
|
|||||||
Instance id represents an instance of the VLLM. All processes in the same
|
Instance id represents an instance of the VLLM. All processes in the same
|
||||||
instance should have the same instance id.
|
instance should have the same instance id.
|
||||||
"""
|
"""
|
||||||
return os.environ.get("VLLM_INSTANCE_ID", f"vllm-instance-{random_uuid()}")
|
return envs.VLLM_INSTANCE_ID or f"vllm-instance-{random_uuid()}"
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
@ -243,7 +244,7 @@ def merge_async_iterators(
|
|||||||
|
|
||||||
|
|
||||||
def get_ip() -> str:
|
def get_ip() -> str:
|
||||||
host_ip = os.environ.get("HOST_IP")
|
host_ip = envs.VLLM_HOST_IP
|
||||||
if host_ip:
|
if host_ip:
|
||||||
return host_ip
|
return host_ip
|
||||||
|
|
||||||
@ -269,7 +270,8 @@ def get_ip() -> str:
|
|||||||
|
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Failed to get the IP address, using 0.0.0.0 by default."
|
"Failed to get the IP address, using 0.0.0.0 by default."
|
||||||
"The value can be set by the environment variable HOST_IP.",
|
"The value can be set by the environment variable"
|
||||||
|
" VLLM_HOST_IP or HOST_IP.",
|
||||||
stacklevel=2)
|
stacklevel=2)
|
||||||
return "0.0.0.0"
|
return "0.0.0.0"
|
||||||
|
|
||||||
@ -314,7 +316,7 @@ def cdiv(a: int, b: int) -> int:
|
|||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def get_nvcc_cuda_version() -> Optional[Version]:
|
def get_nvcc_cuda_version() -> Optional[Version]:
|
||||||
cuda_home = os.environ.get('CUDA_HOME')
|
cuda_home = envs.CUDA_HOME
|
||||||
if not cuda_home:
|
if not cuda_home:
|
||||||
cuda_home = '/usr/local/cuda'
|
cuda_home = '/usr/local/cuda'
|
||||||
if os.path.isfile(cuda_home + '/bin/nvcc'):
|
if os.path.isfile(cuda_home + '/bin/nvcc'):
|
||||||
@ -581,7 +583,7 @@ def find_library(lib_name: str) -> str:
|
|||||||
# libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
|
# libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
|
||||||
locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
|
locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
|
||||||
# `LD_LIBRARY_PATH` searches the library in the user-defined paths
|
# `LD_LIBRARY_PATH` searches the library in the user-defined paths
|
||||||
env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
|
env_ld_library_path = envs.LD_LIBRARY_PATH
|
||||||
if not locs and env_ld_library_path:
|
if not locs and env_ld_library_path:
|
||||||
locs = [
|
locs = [
|
||||||
os.path.join(dir, lib_name)
|
os.path.join(dir, lib_name)
|
||||||
@ -594,14 +596,15 @@ def find_library(lib_name: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def find_nccl_library():
|
def find_nccl_library():
|
||||||
so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")
|
so_file = envs.VLLM_NCCL_SO_PATH
|
||||||
|
VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
|
||||||
|
|
||||||
# check if we have vllm-managed nccl
|
# check if we have vllm-managed nccl
|
||||||
vllm_nccl_path = None
|
vllm_nccl_path = None
|
||||||
if torch.version.cuda is not None:
|
if torch.version.cuda is not None:
|
||||||
cuda_major = torch.version.cuda.split(".")[0]
|
cuda_major = torch.version.cuda.split(".")[0]
|
||||||
path = os.path.expanduser(
|
path = os.path.expanduser(
|
||||||
f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*")
|
f"{VLLM_CONFIG_ROOT}/vllm/nccl/cu{cuda_major}/libnccl.so.*")
|
||||||
files = glob.glob(path)
|
files = glob.glob(path)
|
||||||
vllm_nccl_path = files[0] if files else None
|
vllm_nccl_path = files[0] if files else None
|
||||||
|
|
||||||
@ -626,7 +629,7 @@ def enable_trace_function_call_for_thread() -> None:
|
|||||||
if enabled via the VLLM_TRACE_FUNCTION environment variable
|
if enabled via the VLLM_TRACE_FUNCTION environment variable
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
|
if envs.VLLM_TRACE_FUNCTION:
|
||||||
tmp_dir = tempfile.gettempdir()
|
tmp_dir = tempfile.gettempdir()
|
||||||
filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
|
filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
|
||||||
f"_thread_{threading.get_ident()}_"
|
f"_thread_{threading.get_ident()}_"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user