[core][executor] simplify instance id (#10976)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-12-07 09:33:45 -08:00 committed by GitHub
parent 78029b34ed
commit 1b62745b1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 22 additions and 55 deletions

View File

@ -27,7 +27,8 @@ from vllm.transformers_utils.config import (
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
print_warning_once, resolve_obj_by_qualname)
print_warning_once, random_uuid,
resolve_obj_by_qualname)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@ -2408,6 +2409,7 @@ class VllmConfig:
init=True) # type: ignore
kv_transfer_config: KVTransferConfig = field(default=None,
init=True) # type: ignore
instance_id: str = ""
@staticmethod
def get_graph_batch_size(batch_size: int) -> int:
@ -2573,6 +2575,9 @@ class VllmConfig:
current_platform.check_and_update_config(self)
if not self.instance_id:
self.instance_id = random_uuid()[:5]
def __str__(self):
return ("model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "

View File

@ -8,7 +8,6 @@ if TYPE_CHECKING:
VLLM_RPC_BASE_PATH: str = tempfile.gettempdir()
VLLM_USE_MODELSCOPE: bool = False
VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
VLLM_INSTANCE_ID: Optional[str] = None
VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False
@ -175,11 +174,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_USE_MODELSCOPE":
lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true",
# Instance id represents an instance of the VLLM. All processes in the same
# instance should have the same instance id.
"VLLM_INSTANCE_ID":
lambda: os.environ.get("VLLM_INSTANCE_ID", None),
# Interval in seconds to log a warning message when the ring buffer is full
"VLLM_RINGBUFFER_WARNING_INTERVAL":
lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")),

View File

@ -10,8 +10,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_open_port,
get_vllm_instance_id, make_async)
from vllm.utils import get_distributed_init_method, get_open_port, make_async
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
@ -31,9 +30,6 @@ class CPUExecutor(ExecutorBase):
# Environment variables for CPU executor
#
# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

View File

@ -16,7 +16,7 @@ from vllm.sequence import ExecuteModelRequest
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
cuda_is_initialized, get_distributed_init_method,
get_open_port, get_vllm_instance_id, make_async,
get_open_port, make_async,
update_environment_variables)
if HAS_TRITON:
@ -37,9 +37,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

View File

@ -15,8 +15,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, get_vllm_instance_id,
make_async)
get_ip, get_open_port, make_async)
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@ -220,14 +219,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
" environment variable, make sure it is unique for"
" each node.")
VLLM_INSTANCE_ID = get_vllm_instance_id()
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
"CUDA_VISIBLE_DEVICES":
",".join(map(str, node_gpus[node_id])),
"VLLM_INSTANCE_ID":
VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
**({

View File

@ -15,8 +15,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, get_vllm_instance_id,
make_async)
get_ip, get_open_port, make_async)
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@ -196,12 +195,8 @@ class RayHPUExecutor(DistributedGPUExecutor):
"environment variable, make sure it is unique for"
" each node.")
VLLM_INSTANCE_ID = get_vllm_instance_id()
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
"VLLM_INSTANCE_ID":
VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
}, ) for (node_id, _) in worker_node_and_gpu_ids]

View File

@ -13,7 +13,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)
make_async)
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@ -144,12 +144,8 @@ class RayTPUExecutor(TPUExecutor):
for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
VLLM_INSTANCE_ID = get_vllm_instance_id()
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
"VLLM_INSTANCE_ID":
VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
}, ) for _ in worker_node_and_gpu_ids]

View File

@ -5,7 +5,7 @@ import vllm.envs as envs
from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync
from vllm.executor.xpu_executor import XPUExecutor
from vllm.logger import init_logger
from vllm.utils import get_vllm_instance_id, make_async
from vllm.utils import make_async
logger = init_logger(__name__)
@ -17,12 +17,8 @@ class RayXPUExecutor(RayGPUExecutor, XPUExecutor):
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
use_dummy_driver=True)
VLLM_INSTANCE_ID = get_vllm_instance_id()
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
"VLLM_INSTANCE_ID":
VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
}, ) for (_, _) in worker_node_and_gpu_ids]

View File

@ -24,9 +24,9 @@ from collections import UserDict, defaultdict
from collections.abc import Iterable, Mapping
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
Hashable, List, Literal, Optional, OrderedDict, Set, Tuple,
Type, TypeVar, Union, overload)
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
Dict, Generic, Hashable, List, Literal, Optional,
OrderedDict, Set, Tuple, Type, TypeVar, Union, overload)
from uuid import uuid4
import numpy as np
@ -43,6 +43,9 @@ import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
from vllm.platforms import current_platform
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
# Exception strings for non-implemented encoder/decoder scenarios
@ -335,17 +338,6 @@ def random_uuid() -> str:
return str(uuid.uuid4().hex)
@lru_cache(maxsize=None)
def get_vllm_instance_id() -> str:
"""
If the environment variable VLLM_INSTANCE_ID is set, return it.
Otherwise, return a random UUID.
Instance id represents an instance of the VLLM. All processes in the same
instance should have the same instance id.
"""
return envs.VLLM_INSTANCE_ID or f"vllm-instance-{random_uuid()}"
@lru_cache(maxsize=None)
def in_wsl() -> bool:
# Reference: https://github.com/microsoft/WSL/issues/4071
@ -997,7 +989,7 @@ def find_nccl_library() -> str:
return so_file
def enable_trace_function_call_for_thread() -> None:
def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None:
"""Set up function tracing for the current thread,
if enabled via the VLLM_TRACE_FUNCTION environment variable
"""
@ -1009,7 +1001,8 @@ def enable_trace_function_call_for_thread() -> None:
filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
f"_thread_{threading.get_ident()}_"
f"at_{datetime.datetime.now()}.log").replace(" ", "_")
log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
log_path = os.path.join(tmp_dir, "vllm",
f"vllm-instance-{vllm_config.instance_id}",
filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)

View File

@ -439,7 +439,7 @@ class WorkerWrapperBase:
Here we inject some common logic before initializing the worker.
Arguments are passed to the worker class constructor.
"""
enable_trace_function_call_for_thread()
enable_trace_function_call_for_thread(self.vllm_config)
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'