# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import multiprocessing
import time
import weakref
from collections.abc import Sequence
from multiprocessing import connection
from multiprocessing.process import BaseProcess
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
                    Union, overload)

import torch

from vllm.logger import init_logger
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
                        kill_process_tree)

if TYPE_CHECKING:
    from vllm.v1.engine.coordinator import DPCoordinator
    from vllm.v1.engine.utils import (CoreEngineActorManager,
                                      CoreEngineProcManager)

logger = init_logger(__name__)

T = TypeVar("T")


class ConstantList(Generic[T], Sequence):
    """A read-only view over a list: all mutating operations raise."""

    def __init__(self, x: list[T]) -> None:
        self._x = x

    def append(self, item):
        raise Exception("Cannot append to a constant list")

    def extend(self, item):
        raise Exception("Cannot extend a constant list")

    def insert(self, item):
        raise Exception("Cannot insert into a constant list")

    def pop(self, item):
        raise Exception("Cannot pop from a constant list")

    def remove(self, item):
        raise Exception("Cannot remove from a constant list")

    def clear(self):
        raise Exception("Cannot clear a constant list")

    def index(self,
              item: T,
              start: int = 0,
              stop: Optional[int] = None) -> int:
        return self._x.index(item, start,
                             stop if stop is not None else len(self._x))

    @overload
    def __getitem__(self, item: int) -> T:
        ...

    @overload
    def __getitem__(self, s: slice, /) -> list[T]:
        ...

    def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]:
        return self._x[item]

    @overload
    def __setitem__(self, item: int, value: T):
        ...

    @overload
    def __setitem__(self, s: slice, value: T, /):
        ...

    def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]):
        raise Exception("Cannot set item in a constant list")

    def __delitem__(self, item):
        raise Exception("Cannot delete item from a constant list")

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)

    def __repr__(self):
        return f"ConstantList({self._x})"
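

# Illustrative sketch (not part of the original module): ConstantList hands
# out a read-only view of an internal list. Reads behave like a normal
# sequence, while any mutation attempt raises.
def _example_constant_list() -> None:
    cl = ConstantList([1, 2, 3])
    assert cl[0] == 1
    assert cl[1:] == [2, 3]
    assert 2 in cl and len(cl) == 3
    assert cl.index(3) == 2
    try:
        cl.append(4)
    except Exception:
        pass  # mutation is rejected by design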


def get_engine_client_zmq_addr(local_only: bool,
                               host: str,
                               port: int = 0) -> str:
    """Assign a new ZMQ socket address.

    If local_only is True, participants are colocated and so a unique IPC
    address will be returned.

    Otherwise, the provided host and port will be used to construct a TCP
    address (port == 0 means assign an available port)."""

    return get_open_zmq_ipc_path() if local_only else (get_tcp_uri(
        host, port or get_open_port()))
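

# Illustrative sketch (not part of the original module): colocated callers
# get a process-local IPC endpoint, while remote-capable callers get a TCP
# URI. The "ipc://" / "tcp://" prefixes assumed below match what the
# get_open_zmq_ipc_path / get_tcp_uri helpers produce.
def _example_zmq_addr() -> None:
    ipc_addr = get_engine_client_zmq_addr(local_only=True, host="")
    assert ipc_addr.startswith("ipc://")
    tcp_addr = get_engine_client_zmq_addr(local_only=False, host="127.0.0.1")
    assert tcp_addr.startswith("tcp://127.0.0.1:")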


class APIServerProcessManager:
    """Manages a group of API server processes.

    Handles creation, monitoring, and termination of API server worker
    processes. Also monitors extra processes to check if they are healthy.
    """

    def __init__(
        self,
        target_server_fn: Callable,
        listen_address: str,
        sock: Any,
        args: argparse.Namespace,
        num_servers: int,
        input_addresses: list[str],
        output_addresses: list[str],
        stats_update_address: Optional[str] = None,
    ):
        """Initialize and start API server worker processes.

        Args:
            target_server_fn: Function to call for each API server process
            listen_address: Address to listen for client connections
            sock: Socket for client connections
            args: Command line arguments
            num_servers: Number of API server processes to start
            input_addresses: Input addresses for each API server
            output_addresses: Output addresses for each API server
            stats_update_address: Optional stats update address
        """
        self.listen_address = listen_address
        self.sock = sock
        self.args = args

        # Start API servers
        spawn_context = multiprocessing.get_context("spawn")
        self.processes: list[BaseProcess] = []

        for i, in_addr, out_addr in zip(range(num_servers), input_addresses,
                                        output_addresses):
            client_config = {
                "input_address": in_addr,
                "output_address": out_addr,
                "client_index": i
            }
            if stats_update_address is not None:
                client_config["stats_update_address"] = stats_update_address

            proc = spawn_context.Process(target=target_server_fn,
                                         name=f"ApiServer_{i}",
                                         args=(listen_address, sock, args,
                                               client_config))
            self.processes.append(proc)
            proc.start()

        logger.info("Started %d API server processes", len(self.processes))

        # Shut down only the API server processes on garbage collection.
        # The extra processes are managed by their owners.
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

    def close(self) -> None:
        self._finalizer()
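

# Illustrative sketch (hypothetical setup, not part of the original module):
# wiring an APIServerProcessManager to a stand-in server function. In real
# use target_server_fn runs a vLLM API server; the stand-in below only has
# the expected signature, and lives at module level so the "spawn" context
# can pickle it.
def _dummy_api_server(listen_address, sock, args, client_config):
    # A real server would accept connections on `sock` / `listen_address`
    # and talk to the engine over the configured ZMQ addresses.
    print("serving", listen_address, "client", client_config["client_index"])


def _example_api_server_manager() -> APIServerProcessManager:
    num_servers = 2
    return APIServerProcessManager(
        target_server_fn=_dummy_api_server,
        listen_address="127.0.0.1:8000",
        sock=None,
        args=argparse.Namespace(),
        num_servers=num_servers,
        input_addresses=[
            get_engine_client_zmq_addr(local_only=True, host="")
            for _ in range(num_servers)
        ],
        output_addresses=[
            get_engine_client_zmq_addr(local_only=True, host="")
            for _ in range(num_servers)
        ],
    )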


def wait_for_completion_or_failure(
        api_server_manager: APIServerProcessManager,
        engine_manager: Optional[Union["CoreEngineProcManager",
                                       "CoreEngineActorManager"]] = None,
        coordinator: Optional["DPCoordinator"] = None) -> None:
    """Wait for all processes to complete or detect if any fail.

    Raises an exception if any process exits with a non-zero status.

    Args:
        api_server_manager: The manager for API servers.
        engine_manager: The manager for engine processes.
            If CoreEngineProcManager, it manages local engines;
            if CoreEngineActorManager, it manages all engines.
        coordinator: The coordinator for data parallel.
    """

    from vllm.v1.engine.utils import (CoreEngineActorManager,
                                      CoreEngineProcManager)

    try:
        logger.info("Waiting for API servers to complete ...")
        # Create a mapping of sentinels to their corresponding processes
        # for efficient lookup.
        sentinel_to_proc: dict[Any, BaseProcess] = {
            proc.sentinel: proc
            for proc in api_server_manager.processes
        }

        if coordinator:
            sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc

        actor_run_refs = []
        if isinstance(engine_manager, CoreEngineProcManager):
            for proc in engine_manager.processes:
                sentinel_to_proc[proc.sentinel] = proc
        elif isinstance(engine_manager, CoreEngineActorManager):
            actor_run_refs = engine_manager.get_run_refs()

        # Check if any process terminates.
        while sentinel_to_proc or actor_run_refs:
            # Wait for any process to terminate.
            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc,
                                                         timeout=5)

            # Process any terminated processes.
            for sentinel in ready_sentinels:
                proc = sentinel_to_proc.pop(sentinel)

                # Check if the process exited with an error.
                if proc.exitcode != 0:
                    raise RuntimeError(
                        f"Process {proc.name} (PID: {proc.pid}) "
                        f"died with exit code {proc.exitcode}")

            if actor_run_refs:
                import ray
                _, actor_run_refs = ray.wait(actor_run_refs, timeout=5)

    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down API servers...")
    except Exception as e:
        logger.exception("Exception occurred while running API servers: %s",
                         str(e))
        raise
    finally:
        logger.info("Terminating remaining processes ...")
        api_server_manager.close()
        if coordinator:
            coordinator.close()
        if engine_manager:
            engine_manager.close()
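

# Illustrative sketch (not part of the original module) of the sentinel
# mechanism the loop above relies on: multiprocessing exposes a `sentinel`
# handle per process that becomes ready when the process exits, so
# connection.wait() can block on many processes at once.
def _example_sentinel_wait() -> None:
    ctx = multiprocessing.get_context("spawn")
    procs = [ctx.Process(target=time.sleep, args=(0.1, )) for _ in range(3)]
    for p in procs:
        p.start()
    pending = {p.sentinel: p for p in procs}
    while pending:
        for sentinel in connection.wait(pending, timeout=5):
            p = pending.pop(sentinel)
            p.join()
            assert p.exitcode == 0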


# Note(rob): the shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def shutdown(procs: list[BaseProcess]):
    # Ask each process to terminate.
    for proc in procs:
        if proc.is_alive():
            proc.terminate()

    # Allow 5 seconds for the remaining procs to terminate.
    deadline = time.monotonic() + 5
    for proc in procs:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        if proc.is_alive():
            proc.join(remaining)

    # Force-kill anything that survived the grace period.
    for proc in procs:
        if proc.is_alive() and (pid := proc.pid) is not None:
            kill_process_tree(pid)
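

# Illustrative sketch (not part of the original module): shutdown() first
# sends SIGTERM to each process, grants a shared 5-second grace period,
# then falls back to kill_process_tree for anything still alive.
def _example_shutdown() -> None:
    ctx = multiprocessing.get_context("spawn")
    procs = [ctx.Process(target=time.sleep, args=(60, )) for _ in range(2)]
    for p in procs:
        p.start()
    shutdown(procs)  # returns promptly; the sleepers die on terminate()
    assert not any(p.is_alive() for p in procs)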


def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
               length: int) -> torch.Tensor:
    """
    Copy the first `length` elements of one tensor into another tensor in a
    non-blocking manner.

    Used to copy pinned CPU tensor data to pre-allocated GPU tensors.

    Returns the sliced target tensor.
    """
    return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
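

# Illustrative sketch (not part of the original module): staging data in a
# pinned CPU buffer and copying a prefix into a pre-allocated GPU tensor.
# Guarded on CUDA availability so it degrades to a plain CPU copy.
def _example_copy_slice() -> None:
    src = torch.arange(8, dtype=torch.int32)
    if torch.cuda.is_available():
        src = src.pin_memory()  # pinned memory enables async H2D copies
        dst = torch.zeros(8, dtype=torch.int32, device="cuda")
    else:
        dst = torch.zeros(8, dtype=torch.int32)
    out = copy_slice(src, dst, length=4)
    # The returned view aliases dst[:4]; synchronize before relying on the
    # data when the copy was issued asynchronously to a GPU.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    assert out.tolist() == [0, 1, 2, 3]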


def report_usage_stats(
        vllm_config,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None:
    """Report usage statistics if enabled."""

    if not is_usage_stats_enabled():
        return

    from vllm.model_executor.model_loader import get_architecture_class_name

    usage_message.report_usage(
        get_architecture_class_name(vllm_config.model_config),
        usage_context,
        extra_kvs={
            # Common configuration
            "dtype":
            str(vllm_config.model_config.dtype),
            "tensor_parallel_size":
            vllm_config.parallel_config.tensor_parallel_size,
            "block_size":
            vllm_config.cache_config.block_size,
            "gpu_memory_utilization":
            vllm_config.cache_config.gpu_memory_utilization,

            # Quantization
            "quantization":
            vllm_config.model_config.quantization,
            "kv_cache_dtype":
            str(vllm_config.cache_config.cache_dtype),

            # Feature flags
            "enable_lora":
            bool(vllm_config.lora_config),
            "enable_prompt_adapter":
            bool(vllm_config.prompt_adapter_config),
            "enable_prefix_caching":
            vllm_config.cache_config.enable_prefix_caching,
            "enforce_eager":
            vllm_config.model_config.enforce_eager,
            "disable_custom_all_reduce":
            vllm_config.parallel_config.disable_custom_all_reduce,
        })
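

# Illustrative sketch (hypothetical setup, not part of the original module):
# callers typically build a VllmConfig at engine startup and report once.
# The EngineArgs import path and model name below are assumptions.
def _example_report_usage() -> None:
    from vllm.engine.arg_utils import EngineArgs
    vllm_config = EngineArgs(model="facebook/opt-125m").create_engine_config()
    report_usage_stats(vllm_config, UsageContext.ENGINE_CONTEXT)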