[core][executor] simplify instance id (#10976)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 78029b34ed
commit 1b62745b1d
vllm/config.py
@@ -27,7 +27,8 @@ from vllm.transformers_utils.config import (
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        print_warning_once, resolve_obj_by_qualname)
+                        print_warning_once, random_uuid,
+                        resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -2408,6 +2409,7 @@ class VllmConfig:
                                                  init=True)  # type: ignore
     kv_transfer_config: KVTransferConfig = field(default=None,
                                                  init=True)  # type: ignore
+    instance_id: str = ""
 
     @staticmethod
     def get_graph_batch_size(batch_size: int) -> int:
@@ -2573,6 +2575,9 @@ class VllmConfig:
 
         current_platform.check_and_update_config(self)
 
+        if not self.instance_id:
+            self.instance_id = random_uuid()[:5]
+
     def __str__(self):
         return ("model=%r, speculative_config=%r, tokenizer=%r, "
                 "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
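The new field defaults to an empty string and `__post_init__` backfills it with the first five hex characters of a random UUID, so every engine gets a short id unless the caller supplies one. A minimal standalone sketch of the same pattern (the `Config` class here is illustrative, not the real `VllmConfig`):

import uuid
from dataclasses import dataclass

def random_uuid() -> str:
    # Mirrors vllm.utils.random_uuid: a 32-character hex string.
    return str(uuid.uuid4().hex)

@dataclass
class Config:
    instance_id: str = ""  # empty string means "generate one for me"

    def __post_init__(self):
        if not self.instance_id:
            # Five hex chars: readable in log paths, unique enough per host.
            self.instance_id = random_uuid()[:5]

print(Config().instance_id)                    # e.g. "3fa9c"
print(Config(instance_id="dev").instance_id)   # explicit ids are kept: "dev"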
vllm/envs.py
@@ -8,7 +8,6 @@ if TYPE_CHECKING:
     VLLM_RPC_BASE_PATH: str = tempfile.gettempdir()
     VLLM_USE_MODELSCOPE: bool = False
     VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
-    VLLM_INSTANCE_ID: Optional[str] = None
     VLLM_NCCL_SO_PATH: Optional[str] = None
     LD_LIBRARY_PATH: Optional[str] = None
     VLLM_USE_TRITON_FLASH_ATTN: bool = False
@@ -175,11 +174,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_USE_MODELSCOPE":
     lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true",
 
-    # Instance id represents an instance of the VLLM. All processes in the same
-    # instance should have the same instance id.
-    "VLLM_INSTANCE_ID":
-    lambda: os.environ.get("VLLM_INSTANCE_ID", None),
-
     # Interval in seconds to log a warning message when the ring buffer is full
     "VLLM_RINGBUFFER_WARNING_INTERVAL":
     lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")),
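For context, vllm/envs.py registers each variable as a name-to-lambda entry and resolves attribute access lazily, so deleting the entry above removes VLLM_INSTANCE_ID from the module's surface entirely. A reduced sketch of that registry pattern (trimmed to two entries; the real file has many more):

import os
from typing import Any, Callable, Dict

environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_USE_MODELSCOPE":
    lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true",
    "VLLM_RINGBUFFER_WARNING_INTERVAL":
    lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")),
}

def __getattr__(name: str) -> Any:
    # PEP 562 module-level __getattr__: each access re-reads the process
    # environment, so values are never stale.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")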
vllm/executor/cpu_executor.py
@@ -10,8 +10,7 @@ from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import (get_distributed_init_method, get_open_port,
-                        get_vllm_instance_id, make_async)
+from vllm.utils import get_distributed_init_method, get_open_port, make_async
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -31,9 +30,6 @@ class CPUExecutor(ExecutorBase):
         # Environment variables for CPU executor
         #
 
-        # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
-        os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
-
         # Disable torch async compiling which won't work with daemonic processes
         os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
 
vllm/executor/multiproc_gpu_executor.py
@@ -16,7 +16,7 @@ from vllm.sequence import ExecuteModelRequest
 from vllm.triton_utils.importing import HAS_TRITON
 from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
                         cuda_is_initialized, get_distributed_init_method,
-                        get_open_port, get_vllm_instance_id, make_async,
+                        get_open_port, make_async,
                         update_environment_variables)
 
 if HAS_TRITON:
@@ -37,9 +37,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
         world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
 
-        # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
-        os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
-
         # Disable torch async compiling which won't work with daemonic processes
         os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
 
vllm/executor/ray_gpu_executor.py
@@ -15,8 +15,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
-                        get_ip, get_open_port, get_vllm_instance_id,
-                        make_async)
+                        get_ip, get_open_port, make_async)
 
 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -220,14 +219,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
                 " environment variable, make sure it is unique for"
                 " each node.")
 
-        VLLM_INSTANCE_ID = get_vllm_instance_id()
-
         # Set environment variables for the driver and workers.
         all_args_to_update_environment_variables = [({
             "CUDA_VISIBLE_DEVICES":
             ",".join(map(str, node_gpus[node_id])),
-            "VLLM_INSTANCE_ID":
-            VLLM_INSTANCE_ID,
             "VLLM_TRACE_FUNCTION":
             str(envs.VLLM_TRACE_FUNCTION),
             **({
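What survives this hunk is the fan-out pattern: the driver builds one environment dict per Ray worker and pushes it out, so a variable is inherited only if it appears in these dicts. With the instance id now carried inside VllmConfig, it no longer has to ride along. A simplified sketch of the fan-out (the helper names below are stand-ins for vLLM's _run_workers plumbing):

import os
from typing import Dict, List

def update_environment_variables(envs_dict: Dict[str, str]) -> None:
    # Executed inside each worker process.
    os.environ.update(envs_dict)

def build_worker_envs(node_gpus: List[List[int]]) -> List[Dict[str, str]]:
    # One dict per worker; note there is no VLLM_INSTANCE_ID entry anymore.
    return [{
        "CUDA_VISIBLE_DEVICES": ",".join(map(str, gpu_ids)),
        "VLLM_TRACE_FUNCTION": os.environ.get("VLLM_TRACE_FUNCTION", "0"),
    } for gpu_ids in node_gpus]

for worker_env in build_worker_envs([[0, 1], [2, 3]]):
    update_environment_variables(worker_env)  # one RPC per worker in vLLM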
vllm/executor/ray_hpu_executor.py
@@ -15,8 +15,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
-                        get_ip, get_open_port, get_vllm_instance_id,
-                        make_async)
+                        get_ip, get_open_port, make_async)
 
 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -196,12 +195,8 @@ class RayHPUExecutor(DistributedGPUExecutor):
                 "environment variable, make sure it is unique for"
                 " each node.")
 
-        VLLM_INSTANCE_ID = get_vllm_instance_id()
-
         # Set environment variables for the driver and workers.
         all_args_to_update_environment_variables = [({
-            "VLLM_INSTANCE_ID":
-            VLLM_INSTANCE_ID,
             "VLLM_TRACE_FUNCTION":
             str(envs.VLLM_TRACE_FUNCTION),
         }, ) for (node_id, _) in worker_node_and_gpu_ids]
vllm/executor/ray_tpu_executor.py
@@ -13,7 +13,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
-                        get_vllm_instance_id, make_async)
+                        make_async)
 
 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -144,12 +144,8 @@ class RayTPUExecutor(TPUExecutor):
         for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
             node_workers[node_id].append(i)
 
-        VLLM_INSTANCE_ID = get_vllm_instance_id()
-
         # Set environment variables for the driver and workers.
         all_args_to_update_environment_variables = [({
-            "VLLM_INSTANCE_ID":
-            VLLM_INSTANCE_ID,
             "VLLM_TRACE_FUNCTION":
             str(envs.VLLM_TRACE_FUNCTION),
         }, ) for _ in worker_node_and_gpu_ids]
vllm/executor/ray_xpu_executor.py
@@ -5,7 +5,7 @@ import vllm.envs as envs
 from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync
 from vllm.executor.xpu_executor import XPUExecutor
 from vllm.logger import init_logger
-from vllm.utils import get_vllm_instance_id, make_async
+from vllm.utils import make_async
 
 logger = init_logger(__name__)
 
@@ -17,12 +17,8 @@ class RayXPUExecutor(RayGPUExecutor, XPUExecutor):
         worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                     use_dummy_driver=True)
 
-        VLLM_INSTANCE_ID = get_vllm_instance_id()
-
         # Set environment variables for the driver and workers.
         all_args_to_update_environment_variables = [({
-            "VLLM_INSTANCE_ID":
-            VLLM_INSTANCE_ID,
             "VLLM_TRACE_FUNCTION":
             str(envs.VLLM_TRACE_FUNCTION),
         }, ) for (_, _) in worker_node_and_gpu_ids]
vllm/utils.py
@@ -24,9 +24,9 @@ from collections import UserDict, defaultdict
 from collections.abc import Iterable, Mapping
 from functools import lru_cache, partial, wraps
 from platform import uname
-from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
-                    Hashable, List, Literal, Optional, OrderedDict, Set, Tuple,
-                    Type, TypeVar, Union, overload)
+from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
+                    Dict, Generic, Hashable, List, Literal, Optional,
+                    OrderedDict, Set, Tuple, Type, TypeVar, Union, overload)
 from uuid import uuid4
 
 import numpy as np
@@ -43,6 +43,9 @@ import vllm.envs as envs
 from vllm.logger import enable_trace_function_call, init_logger
 from vllm.platforms import current_platform
 
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+
 logger = init_logger(__name__)
 
 # Exception strings for non-implemented encoder/decoder scenarios
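The TYPE_CHECKING guard lets utils.py annotate parameters with VllmConfig without importing vllm.config at runtime, which would be circular (config.py itself imports from utils). The standard shape of that pattern:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never in a running process.
    from vllm.config import VllmConfig

def describe(vllm_config: "VllmConfig") -> str:
    # The quoted (forward-reference) annotation keeps this valid even
    # though the import above is skipped at runtime.
    return f"instance {vllm_config.instance_id}"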
@@ -335,17 +338,6 @@ def random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
 
-@lru_cache(maxsize=None)
-def get_vllm_instance_id() -> str:
-    """
-    If the environment variable VLLM_INSTANCE_ID is set, return it.
-    Otherwise, return a random UUID.
-    Instance id represents an instance of the VLLM. All processes in the same
-    instance should have the same instance id.
-    """
-    return envs.VLLM_INSTANCE_ID or f"vllm-instance-{random_uuid()}"
-
-
 @lru_cache(maxsize=None)
 def in_wsl() -> bool:
     # Reference: https://github.com/microsoft/WSL/issues/4071
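With this deletion the instance id stops being a process-global, env-overridable, per-process-cached value and becomes plain state on the config object. A condensed before/after sketch of the two behaviors (not a drop-in for either API):

import os
import uuid
from functools import lru_cache

# Before: read VLLM_INSTANCE_ID if set, else mint one; cached per process,
# so every process had to see the same env var to agree on the id.
@lru_cache(maxsize=None)
def get_vllm_instance_id_old() -> str:
    return (os.environ.get("VLLM_INSTANCE_ID")
            or f"vllm-instance-{uuid.uuid4().hex}")

# After: minted once when VllmConfig is finalized, then passed explicitly
# to whatever needs it, e.g. the trace-log path builder.
def make_instance_id() -> str:
    return uuid.uuid4().hex[:5]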
@@ -997,7 +989,7 @@ def find_nccl_library() -> str:
     return so_file
 
 
-def enable_trace_function_call_for_thread() -> None:
+def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None:
     """Set up function tracing for the current thread,
     if enabled via the VLLM_TRACE_FUNCTION environment variable
     """
@@ -1009,7 +1001,8 @@ def enable_trace_function_call_for_thread() -> None:
         filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
                     f"_thread_{threading.get_ident()}_"
                     f"at_{datetime.datetime.now()}.log").replace(" ", "_")
-        log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
+        log_path = os.path.join(tmp_dir, "vllm",
+                                f"vllm-instance-{vllm_config.instance_id}",
                                 filename)
         os.makedirs(os.path.dirname(log_path), exist_ok=True)
         enable_trace_function_call(log_path)
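The trace logs for one engine therefore collect under a directory keyed by the config's instance id instead of the old env-derived id. A sketch of the resulting path layout (directory only; the real code appends a per-process, per-thread filename):

import os
import tempfile

def trace_log_dir(instance_id: str) -> str:
    # e.g. /tmp/vllm/vllm-instance-3fa9c on Linux
    return os.path.join(tempfile.gettempdir(), "vllm",
                        f"vllm-instance-{instance_id}")

print(trace_log_dir("3fa9c"))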
vllm/worker/worker_base.py
@@ -439,7 +439,7 @@ class WorkerWrapperBase:
         Here we inject some common logic before initializing the worker.
         Arguments are passed to the worker class constructor.
         """
-        enable_trace_function_call_for_thread()
+        enable_trace_function_call_for_thread(self.vllm_config)
 
         # see https://github.com/NVIDIA/nccl/issues/1234
         os.environ['NCCL_CUMEM_ENABLE'] = '0'
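The only call-site change: the wrapper already owns a full vllm_config, so it hands that to the trace helper rather than relying on an inherited environment variable. A toy version of the call shape (the helper body is a stub, not vLLM's implementation):

def enable_trace_function_call_for_thread(vllm_config) -> None:
    # Stub for vllm.utils.enable_trace_function_call_for_thread.
    print(f"tracing under instance {vllm_config.instance_id}")

class WorkerWrapperBase:
    def __init__(self, vllm_config) -> None:
        self.vllm_config = vllm_config  # carries instance_id to workers

    def init_worker(self) -> None:
        enable_trace_function_call_for_thread(self.vllm_config)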