[Perf] API-server scaleout with many-to-many server-engine comms (#17546)

Nick Hill 2025-05-30 08:17:00 -07:00 committed by GitHub
parent 84ec470fca
commit 2dbe8c0774
26 changed files with 1828 additions and 436 deletions
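
For illustration, a hypothetical launch using the new flag introduced here (model and sizes are taken from the added tests, not a recommendation):

vllm serve ibm-research/PowerMoE-3b --api-server-count 4 --data_parallel_size 2

With --api-server-count > 1, vllm serve spawns that many API server worker processes, each connected to every engine over dedicated ZMQ input/output sockets (the many-to-many server-engine comms referred to in the title).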

View File

@ -618,9 +618,11 @@ steps:
- vllm/worker/model_runner.py
- entrypoints/llm/test_collective_rpc.py
- tests/v1/test_async_llm_dp.py
- tests/v1/entrypoints/openai/test_multi_api_servers.py
- vllm/v1/engine/
commands:
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py

View File

@ -0,0 +1,268 @@
# SPDX-License-Identifier: Apache-2.0
import multiprocessing
import socket
import threading
import time
from typing import Optional
from unittest.mock import patch
import pytest
from vllm.v1.utils import (APIServerProcessManager,
wait_for_completion_or_failure)
# Global variables to control worker behavior
WORKER_RUNTIME_SECONDS = 0.5
# Mock implementation of run_api_server_worker
def mock_run_api_server_worker(listen_address, sock, args, client_config=None):
"""Mock run_api_server_worker that runs for a specific time."""
print(f"Mock worker started with client_config: {client_config}")
time.sleep(WORKER_RUNTIME_SECONDS)
print("Mock worker completed successfully")
@pytest.fixture
def api_server_args():
"""Fixture to provide arguments for APIServerProcessManager."""
sock = socket.socket()
return {
"target_server_fn":
mock_run_api_server_worker,
"listen_address":
"localhost:8000",
"sock":
sock,
"args":
"test_args", # Simple string to avoid pickling issues
"num_servers":
3,
"input_addresses": [
"tcp://127.0.0.1:5001", "tcp://127.0.0.1:5002",
"tcp://127.0.0.1:5003"
],
"output_addresses": [
"tcp://127.0.0.1:6001", "tcp://127.0.0.1:6002",
"tcp://127.0.0.1:6003"
],
"stats_update_address":
"tcp://127.0.0.1:7000",
}
@pytest.mark.parametrize("with_stats_update", [True, False])
def test_api_server_process_manager_init(api_server_args, with_stats_update):
"""Test initializing the APIServerProcessManager."""
# Set the worker runtime to ensure tests complete in reasonable time
global WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS = 0.5
# Copy the args to avoid mutating the fixture-provided dict
args = api_server_args.copy()
if not with_stats_update:
args.pop("stats_update_address")
manager = APIServerProcessManager(**args)
try:
# Verify the manager was initialized correctly
assert len(manager.processes) == 3
# Verify all processes are running
for proc in manager.processes:
assert proc.is_alive()
print("Waiting for processes to run...")
time.sleep(WORKER_RUNTIME_SECONDS / 2)
# They should still be alive at this point
for proc in manager.processes:
assert proc.is_alive()
finally:
# Always clean up the processes
print("Cleaning up processes...")
manager.close()
# Give processes time to terminate
time.sleep(0.2)
# Verify all processes were terminated
for proc in manager.processes:
assert not proc.is_alive()
@patch("vllm.entrypoints.cli.serve.run_api_server_worker",
mock_run_api_server_worker)
def test_wait_for_completion_or_failure(api_server_args):
"""Test that wait_for_completion_or_failure works with failures."""
global WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS = 1.0
# Create the manager
manager = APIServerProcessManager(**api_server_args)
try:
assert len(manager.processes) == 3
# Create a result capture for the thread
result: dict[str, Optional[Exception]] = {"exception": None}
def run_with_exception_capture():
try:
wait_for_completion_or_failure(api_server_manager=manager)
except Exception as e:
result["exception"] = e
# Start a thread to run wait_for_completion_or_failure
wait_thread = threading.Thread(target=run_with_exception_capture,
daemon=True)
wait_thread.start()
# Let all processes run for a short time
time.sleep(0.2)
# All processes should still be running
assert all(proc.is_alive() for proc in manager.processes)
# Now simulate a process failure
print("Simulating process failure...")
manager.processes[0].terminate()
# Wait for wait_for_completion_or_failure to detect and handle the
# failure; this should cause it to terminate all other processes.
wait_thread.join(timeout=1.0)
# The wait thread should have exited
assert not wait_thread.is_alive()
# Verify that an exception was raised with appropriate error message
assert result["exception"] is not None
assert "died with exit code" in str(result["exception"])
# All processes should now be terminated
for i, proc in enumerate(manager.processes):
assert not proc.is_alive(), f"Process {i} should not be alive"
finally:
manager.close()
time.sleep(0.2)
@pytest.mark.timeout(30)
def test_normal_completion(api_server_args):
"""Test that wait_for_completion_or_failure works in normal completion."""
global WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS = 0.1
# Create the manager
manager = APIServerProcessManager(**api_server_args)
try:
# Wait for the short-lived worker processes to complete on their own.
remaining_processes = manager.processes.copy()
while remaining_processes:
for proc in remaining_processes:
if not proc.is_alive():
remaining_processes.remove(proc)
time.sleep(0.1)
# Verify all processes have terminated
for i, proc in enumerate(manager.processes):
assert not proc.is_alive(
), f"Process {i} still alive after terminate()"
# Now call wait_for_completion_or_failure
# since all processes have already
# terminated, it should return immediately
# with no error
wait_for_completion_or_failure(api_server_manager=manager)
finally:
# Clean up just in case
manager.close()
time.sleep(0.2)
@pytest.mark.timeout(30)
def test_external_process_monitoring(api_server_args):
"""Test that wait_for_completion_or_failure handles additional processes."""
global WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS = 100
# Create and start the external process
# (simulates local_engine_manager or coordinator)
spawn_context = multiprocessing.get_context("spawn")
external_proc = spawn_context.Process(target=mock_run_api_server_worker,
name="MockExternalProcess")
external_proc.start()
# Create the class to simulate a coordinator
class MockCoordinator:
def __init__(self, proc):
self.proc = proc
def close(self):
if self.proc.is_alive():
self.proc.terminate()
self.proc.join(timeout=0.5)
# Create a mock coordinator with the external process
mock_coordinator = MockCoordinator(external_proc)
# Create the API server manager
manager = APIServerProcessManager(**api_server_args)
try:
# Verify manager initialization
assert len(manager.processes) == 3
# Create a result capture for the thread
result: dict[str, Optional[Exception]] = {"exception": None}
def run_with_exception_capture():
try:
wait_for_completion_or_failure(api_server_manager=manager,
coordinator=mock_coordinator)
except Exception as e:
result["exception"] = e
# Start a thread to run wait_for_completion_or_failure
wait_thread = threading.Thread(target=run_with_exception_capture,
daemon=True)
wait_thread.start()
# Terminate the external process to trigger a failure
time.sleep(0.2)
external_proc.terminate()
# Wait for the thread to detect the failure
wait_thread.join(timeout=1.0)
# The wait thread should have completed
assert not wait_thread.is_alive(
), "wait_for_completion_or_failure thread still running"
# Verify that an exception was raised with appropriate error message
assert result["exception"] is not None, "No exception was raised"
error_message = str(result["exception"])
assert "died with exit code" in error_message, \
f"Unexpected error message: {error_message}"
assert "MockExternalProcess" in error_message, \
f"Error doesn't mention external process: {error_message}"
# Verify that all API server processes were terminated as a result
for i, proc in enumerate(manager.processes):
assert not proc.is_alive(
), f"API server process {i} was not terminated"
finally:
# Clean up
manager.close()
mock_coordinator.close()
time.sleep(0.2)

View File

@ -28,7 +28,7 @@ from tests.models.utils import TextTextLogprobs
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.cli.serve import ServeSubcommand
from vllm.model_executor.model_loader import get_model_loader
from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer
@ -99,7 +99,8 @@ class RemoteOpenAIServer:
parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
subparsers = parser.add_subparsers(required=False, dest="subparser")
parser = ServeSubcommand().subparser_init(subparsers)
args = parser.parse_args(["--model", model, *vllm_serve_args])
self.host = str(args.host or 'localhost')
self.port = int(args.port)

View File

@ -45,7 +45,6 @@ def make_request(request_id,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100,
arrival_time=0,
lora_request=None,
cache_salt=cache_salt,
)

View File

@ -38,7 +38,6 @@ def make_request(request_id,
sampling_params=SamplingParams(max_tokens=17,
prompt_logprobs=prompt_logprobs),
eos_token_id=100,
arrival_time=0,
lora_request=None,
cache_salt=cache_salt,
)

View File

@ -138,7 +138,6 @@ def create_requests(num_requests: int,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=0,
)
requests.append(request)
return requests
@ -744,7 +743,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
# No draft or accepted tokens counted yet
assert engine_core_outputs.scheduler_stats.spec_decoding_stats is None
assert not engine_core_outputs or (
engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None)
# Schedule the speculated tokens for validation
output = scheduler.schedule()
@ -772,7 +772,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
scheduler_stats = engine_core_outputs.scheduler_stats
scheduler_stats = engine_core_outputs[0].scheduler_stats \
if engine_core_outputs else None
if expected[0] == 0:
assert scheduler_stats.spec_decoding_stats is None
else:
@ -843,7 +844,7 @@ def _step_until_done(
# We should be in the decode phase now.
assert num_scheduled_tokens == 1
assert len(output.kv_connector_metadata.requests) == 0
ecos = scheduler.update_from_output(output, model_runner_output)
ecos = scheduler.update_from_output(output, model_runner_output)[0]
all_done = True
for eco in ecos.outputs:
if eco.finish_reason is None:

View File

@ -88,7 +88,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
assert len(engine_core.scheduler.running) == 4
# Loop through until they are all done.
while len(engine_core.step()[0].outputs) > 0:
while (outs := engine_core.step()[0].get(0)) and outs.outputs:
pass
assert len(engine_core.scheduler.waiting) == 0
@ -163,11 +163,11 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
req0.request_id = req1.request_id = "test"
engine_core.add_request(req0)
while len(engine_core.step()[0].outputs) > 0:
while (outs := engine_core.step()[0].get(0)) and outs.outputs:
pass
engine_core.add_request(req1)
while len(engine_core.step()[0].outputs) > 0:
while (outs := engine_core.step()[0].get(0)) and outs.outputs:
pass
assert len(engine_core.scheduler.waiting) == 0
@ -207,7 +207,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 0
# Loop through until they are all done.
while len(engine_core.step()[0].outputs) > 0:
while (outs := engine_core.step()[0].get(0)) and outs.outputs:
pass
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 0
@ -327,7 +327,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert scheduler_output.num_scheduled_tokens[1] == 4
# Batch queue is full. Finish Batch 2. Get first token of req0.
output = engine_core.step_with_batch_queue()[0]
output = engine_core.step_with_batch_queue()[0].get(0)
assert output is not None
assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
@ -339,7 +339,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert scheduler_output.num_scheduled_tokens[0] == 1
# Batch queue is full. Finish Batch 3. Get first token of req1.
output = engine_core.step_with_batch_queue()[0]
output = engine_core.step_with_batch_queue()[0].get(0)
assert output is not None
assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
@ -362,7 +362,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
if step % 2 == 0:
# Even steps consume an output.
assert output is not None
assert len(output.outputs) == 1
assert len(output[0].outputs) == 1
if req_id in engine_core.scheduler.requests:
assert engine_core.scheduler.requests[
req_id].num_tokens == expected_num_tokens[req_id]

View File

@ -0,0 +1,171 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
MODEL_NAME = "ibm-research/PowerMoE-3b"
DP_SIZE = os.getenv("DP_SIZE", "1")
@pytest.fixture(scope="module")
def default_server_args():
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",
"128",
"--enforce-eager",
"--api-server-count",
"4",
"--data_parallel_size",
DP_SIZE,
]
@pytest.fixture(scope="module")
def server(default_server_args):
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_single_completion(client: openai.AsyncOpenAI,
model_name: str) -> None:
async def make_request():
completion = await client.completions.create(
model=model_name,
prompt="Hello, my name is",
max_tokens=10,
temperature=1.0)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
choice = completion.choices[0]
# The exact number of tokens can vary slightly with temperature=1.0,
# so we check for a reasonable minimum length.
assert len(choice.text) >= 1
# Finish reason might not always be 'length' if the model finishes early
# or due to other reasons, especially with high temperature.
# So, we'll accept 'length' or 'stop'.
assert choice.finish_reason in ("length", "stop")
# Token counts can also vary, so we check they are positive.
assert completion.usage.completion_tokens > 0
assert completion.usage.prompt_tokens > 0
assert completion.usage.total_tokens > 0
return completion
# Test single request
result = await make_request()
assert result is not None
await asyncio.sleep(0.5)
# Send two bursts of requests
num_requests = 100
tasks = [make_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)
assert len(results) == num_requests
assert all(completion is not None for completion in results)
await asyncio.sleep(0.5)
tasks = [make_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)
assert len(results) == num_requests
assert all(completion is not None for completion in results)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
model_name: str) -> None:
prompt = "What is an LLM?"
async def make_streaming_request():
# Perform a non-streaming request to get the expected full output
single_completion = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
)
single_output = single_completion.choices[0].text
# Perform the streaming request
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True)
chunks: list[str] = []
finish_reason_count = 0
last_chunk = None
async for chunk in stream:
chunks.append(chunk.choices[0].text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
last_chunk = chunk # Keep track of the last chunk
# finish reason should only return in the last block for OpenAI API
assert finish_reason_count == 1, (
"Finish reason should appear exactly once.")
assert last_chunk is not None, (
"Stream should have yielded at least one chunk.")
assert last_chunk.choices[
0].finish_reason == "length", "Finish reason should be 'length'."
# Check that the combined text matches the non-streamed version.
assert "".join(
chunks
) == single_output, "Streamed output should match non-streamed output."
return True # Indicate success for this request
# Test single request
result = await make_streaming_request()
assert result is not None
await asyncio.sleep(0.5)
# Send two bursts of requests
num_requests = 100
tasks = [make_streaming_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)
assert len(
results
) == num_requests, f"Expected {num_requests} results, got {len(results)}"
assert all(results), "Not all streaming requests completed successfully."
await asyncio.sleep(0.5)
tasks = [make_streaming_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)
assert len(
results
) == num_requests, f"Expected {num_requests} results, got {len(results)}"
assert all(results), "Not all streaming requests completed successfully."

View File

@ -43,7 +43,7 @@ def test_basic_lifecycle():
# Ensure the request is finished after 1 token.
assert request.is_finished()
assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
output = engine_core_outputs.outputs[0]
output = engine_core_outputs[0].outputs[0]
assert output.finish_reason == FinishReason.LENGTH
assert output.kv_transfer_params is not None
@ -165,7 +165,7 @@ def test_prefix_cache_lifecycle():
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote])
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
kv_transfer_params = eco.outputs[0].kv_transfer_params
kv_transfer_params = eco[0].outputs[0].kv_transfer_params
# Ensure we send all block ids, even if there is a cache hit.
assert (len(

View File

@ -61,7 +61,7 @@ def test_basic_lifecycle():
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
assert len(engine_core_outputs.outputs) == 0
assert not engine_core_outputs or not engine_core_outputs[0].outputs
# STEP (2):
# (2a): schedule(): nothing happens!
@ -112,7 +112,7 @@ def test_basic_lifecycle():
model_runner_output)
scheduler.schedule()
outputs = engine_core_outputs.outputs
outputs = engine_core_outputs[0].outputs
assert len(outputs) == 1
output = outputs[0]
assert output.finish_reason == FinishReason.STOP
@ -335,7 +335,7 @@ def test_full_block_prompt():
model_runner_output)
scheduler.schedule()
outputs = engine_core_outputs.outputs
outputs = engine_core_outputs[0].outputs
assert len(outputs) == 1
output = outputs[0]
assert output.finish_reason == FinishReason.STOP

View File

@ -153,7 +153,6 @@ def create_request(
multi_modal_placeholders=None,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=0,
)
req.kv_transfer_params = kv_transfer_params
return req

View File

@ -1,24 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
import signal
import sys
import uvloop
import zmq
import vllm.envs as envs
from vllm import AsyncEngineArgs
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.api_server import (run_server, run_server_worker,
setup_server)
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
from vllm.entrypoints.utils import (VLLM_SERVE_PARSER_EPILOG,
show_filtered_argument_or_group_from_help)
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_tcp_uri
from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.core_client import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
EngineZmqAddresses, get_engine_client_zmq_addr,
wait_for_completion_or_failure,
wait_for_engine_startup)
logger = init_logger(__name__)
@ -36,9 +47,12 @@ class ServeSubcommand(CLISubcommand):
if hasattr(args, 'model_tag') and args.model_tag is not None:
args.model = args.model_tag
if args.headless:
if args.headless or args.api_server_count < 1:
run_headless(args)
elif args.api_server_count > 1:
run_multi_api_server(args)
else:
# Single API server (this process).
uvloop.run(run_server(args))
def validate(self, args: argparse.Namespace) -> None:
@ -69,6 +83,11 @@ class ServeSubcommand(CLISubcommand):
type=int,
default=0,
help='Starting data parallel rank for secondary nodes.')
serve_parser.add_argument('--api-server-count',
'-asc',
type=int,
default=1,
help='How many API server processes to run.')
serve_parser.add_argument(
"--config",
type=str,
@ -91,23 +110,26 @@ def cmd_init() -> list[CLISubcommand]:
def run_headless(args: argparse.Namespace):
if args.api_server_count > 1:
raise ValueError("api_server_count can't be set in headless mode")
# Create the EngineConfig.
engine_args = AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
if not envs.VLLM_USE_V1:
raise RuntimeError("Headless mode is only supported for V1")
raise ValueError("Headless mode is only supported for V1")
parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip
port = engine_args.data_parallel_rpc_port # add to config too
input_address = get_tcp_uri(host, port)
handshake_address = get_tcp_uri(host, port)
if local_engine_count <= 0:
raise RuntimeError("data_parallel_size_local must be > 0 in "
"headless mode")
raise ValueError("data_parallel_size_local must be > 0 in "
"headless mode")
# Catch SIGTERM and SIGINT to allow graceful shutdown.
def signal_handler(signum, frame):
@ -119,7 +141,7 @@ def run_headless(args: argparse.Namespace):
logger.info(
"Launching %d data parallel engine(s) in headless mode, "
"with head node address %s.", local_engine_count, input_address)
"with head node address %s.", local_engine_count, handshake_address)
# Create the engines.
engine_manager = CoreEngineProcManager(
@ -129,7 +151,7 @@ def run_headless(args: argparse.Namespace):
local_start_index=0,
vllm_config=vllm_config,
on_head_node=False,
input_address=input_address,
handshake_address=handshake_address,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
)
@ -139,3 +161,142 @@ def run_headless(args: argparse.Namespace):
finally:
logger.info("Shutting down.")
engine_manager.close()
def run_multi_api_server(args: argparse.Namespace):
assert not args.headless
num_api_servers = args.api_server_count
assert num_api_servers > 0
if num_api_servers > 1:
setup_multiprocess_prometheus()
listen_address, sock = setup_server(args)
engine_args = AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
model_config = vllm_config.model_config
if num_api_servers > 1:
if not envs.VLLM_USE_V1:
raise ValueError("api_server_count > 1 is only supported for V1")
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
"with api_server_count > 1")
if model_config.is_multimodal_model and not (
model_config.disable_mm_preprocessor_cache):
logger.warning(
"Multi-model preprocessor cache will be disabled for"
" api_server_count > 1")
model_config.disable_mm_preprocessor_cache = True
parallel_config = vllm_config.parallel_config
assert parallel_config.data_parallel_rank == 0
dp_size = parallel_config.data_parallel_size
local_engine_count = parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip
local_only = local_engine_count == dp_size
# Set up input and output addresses.
input_addresses = [
get_engine_client_zmq_addr(local_only, host)
for _ in range(num_api_servers)
]
output_addresses = [
get_engine_client_zmq_addr(local_only, host)
for _ in range(num_api_servers)
]
addresses = EngineZmqAddresses(
inputs=input_addresses,
outputs=output_addresses,
)
# Set up coordinator for dp > 1.
coordinator = None
stats_update_address = None
if dp_size > 1:
coordinator = DPCoordinator(parallel_config)
addresses.coordinator_input, addresses.coordinator_output = (
coordinator.get_engine_socket_addresses())
stats_update_address = coordinator.get_stats_publish_address()
logger.info("Started DP Coordinator process (PID: %d)",
coordinator.proc.pid)
handshake_address = get_engine_client_zmq_addr(
local_only, host, parallel_config.data_parallel_rpc_port)
with zmq_socket_ctx(handshake_address, zmq.ROUTER,
bind=True) as handshake_socket:
# Start local engines.
if not local_engine_count:
local_engine_manager = None
else:
local_engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
handshake_address=handshake_address,
on_head_node=True,
local_engine_count=local_engine_count,
start_index=0,
local_start_index=0)
# Start API servers using the manager
api_server_manager = APIServerProcessManager(
target_server_fn=run_api_server_worker_proc,
listen_address=listen_address,
sock=sock,
args=args,
num_servers=num_api_servers,
input_addresses=input_addresses,
output_addresses=output_addresses,
stats_update_address=stats_update_address)
# Wait for engine handshakes to complete.
core_engines = [
CoreEngine(index=i, local=(i < local_engine_count))
for i in range(dp_size)
]
wait_for_engine_startup(
handshake_socket,
addresses,
core_engines,
parallel_config,
vllm_config.cache_config,
local_engine_manager,
coordinator.proc if coordinator else None,
)
# Wait for API servers
wait_for_completion_or_failure(
api_server_manager=api_server_manager,
local_engine_manager=local_engine_manager,
coordinator=coordinator)
def run_api_server_worker_proc(listen_address,
sock,
args,
client_config=None,
**uvicorn_kwargs) -> None:
"""Entrypoint for individual API server worker processes."""
# Add process-specific prefix to stdout and stderr.
from multiprocessing import current_process
process_name = current_process().name
pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
uvloop.run(
run_server_worker(listen_address, sock, args, client_config,
**uvicorn_kwargs))

View File

@ -17,7 +17,7 @@ from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from json import JSONDecodeError
from typing import Annotated, Optional
from typing import Annotated, Any, Optional
import prometheus_client
import regex as re
@ -26,6 +26,8 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import State
from starlette.routing import Mount
@ -97,6 +99,7 @@ from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address, set_ulimit)
from vllm.v1.metrics.prometheus import get_prometheus_registry
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
@ -142,14 +145,17 @@ async def lifespan(app: FastAPI):
@asynccontextmanager
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[EngineClient]:
args: Namespace,
client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]:
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args)
async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine:
engine_args, args.disable_frontend_multiprocessing,
client_config) as engine:
yield engine
@ -157,6 +163,7 @@ async def build_async_engine_client(
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]:
"""
Create EngineClient, either:
@ -179,12 +186,16 @@ async def build_async_engine_client_from_engine_args(
from vllm.v1.engine.async_llm import AsyncLLM
async_llm: Optional[AsyncLLM] = None
client_index = client_config.pop(
"client_index") if client_config else 0
try:
async_llm = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
disable_log_requests=engine_args.disable_log_requests,
disable_log_stats=engine_args.disable_log_stats)
disable_log_stats=engine_args.disable_log_stats,
client_addresses=client_config,
client_index=client_index)
# Don't keep the dummy data in memory
await async_llm.reset_mm_cache()
@ -318,22 +329,9 @@ class PrometheusResponse(Response):
def mount_metrics(app: FastAPI):
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (REGISTRY, CollectorRegistry, make_asgi_app,
multiprocess)
from prometheus_fastapi_instrumentator import Instrumentator
"""Mount prometheus metrics to a FastAPI app."""
registry = REGISTRY
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
if prometheus_multiproc_dir_path is not None:
logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
prometheus_multiproc_dir_path)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
registry = get_prometheus_registry()
# `response_class=PrometheusResponse` is needed to return an HTTP response
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
@ -1256,16 +1254,10 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket:
return sock
async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
log_non_default_args(args)
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
def validate_api_server_args(args):
valid_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
and args.tool_call_parser not in valid_tool_parses:
and args.tool_call_parser not in valid_tool_parses:
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
f"(chose from {{ {','.join(valid_tool_parses)} }})")
@ -1276,6 +1268,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
f"invalid reasoning parser: {args.reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parses)} }})")
def setup_server(args):
"""Validate API server args, set up signal handler, create socket
ready to serve."""
logger.info("vLLM API server version %s", VLLM_VERSION)
log_non_default_args(args)
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
validate_api_server_args(args)
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204
@ -1292,22 +1297,41 @@ async def run_server(args, **uvicorn_kwargs) -> None:
signal.signal(signal.SIGTERM, signal_handler)
async with build_async_engine_client(args) as engine_client:
addr, port = sock_addr
is_ssl = args.ssl_keyfile and args.ssl_certfile
host_part = f"[{addr}]" if is_valid_ipv6_address(
addr) else addr or "0.0.0.0"
listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
return listen_address, sock
async def run_server(args, **uvicorn_kwargs) -> None:
"""Run a single-worker API server."""
listen_address, sock = setup_server(args)
await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
async def run_server_worker(listen_address,
sock,
args,
client_config=None,
**uvicorn_kwargs) -> None:
"""Run a single API server worker."""
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
server_index = client_config.get("client_index", 0) if client_config else 0
async with build_async_engine_client(args, client_config) as engine_client:
app = build_app(args)
vllm_config = await engine_client.get_vllm_config()
await init_app_state(engine_client, vllm_config, app.state, args)
def _listen_addr(a: str) -> str:
if is_valid_ipv6_address(a):
return '[' + a + ']'
return a or "0.0.0.0"
is_ssl = args.ssl_keyfile and args.ssl_certfile
logger.info("Starting vLLM API server on http%s://%s:%d",
"s" if is_ssl else "", _listen_addr(sock_addr[0]),
sock_addr[1])
logger.info("Starting vLLM API server %d on %s", server_index,
listen_address)
shutdown_task = await serve_http(
app,
sock=sock,

View File

@ -229,6 +229,11 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
self.add_adapter(lora)
def add_adapter(self, lora_request: LoRARequest) -> bool:
# Note that this method is not thread-safe. It may be invoked multiple
# times for the same adapter when using multiple API servers.
# This is ok because it's currently only called from
# the single-threaded core engine loop.
if lora_request.lora_int_id not in self.list_adapters():
# Load the new adapter first to ensure it is actually valid, before
# evicting any existing adapters.

View File

@ -2420,6 +2420,7 @@ def make_zmq_socket(
socket_type: Any,
bind: Optional[bool] = None,
identity: Optional[bytes] = None,
linger: Optional[int] = None,
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics."""
@ -2439,7 +2440,7 @@ def make_zmq_socket(
buf_size = -1 # Use system default buffer size
if bind is None:
bind = socket_type != zmq.PUSH
bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB)
if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
socket.setsockopt(zmq.RCVHWM, 0)
@ -2452,6 +2453,9 @@ def make_zmq_socket(
if identity is not None:
socket.setsockopt(zmq.IDENTITY, identity)
if linger is not None:
socket.setsockopt(zmq.LINGER, linger)
# Determine if the path is a TCP socket with an IPv6 address.
# Enable IPv6 on the zmq socket if so.
scheme, host, _ = split_zmq_path(path)
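
The new linger option is exercised later in this change when EngineCoreProc opens its handshake socket; a minimal sketch of that call shape, with names taken from the core.py hunk further below:

with make_zmq_socket(input_ctx, handshake_address, zmq.DEALER,
                     identity=identity, linger=5000,
                     bind=False) as handshake_socket:
    ...  # perform the startup handshake, then let the socket close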

View File

@ -45,7 +45,7 @@ class SchedulerInterface(ABC):
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
) -> "EngineCoreOutputs":
) -> dict[int, "EngineCoreOutputs"]:
"""Update the scheduler state based on the model runner output.
This method is called after the model runner has processed the scheduled
@ -55,7 +55,8 @@ class SchedulerInterface(ABC):
for each request.
Returns:
A EngineCoreOutputs object containing the outputs for each request.
A dict of client index to EngineCoreOutputs object containing the
outputs for each request originating from that client.
"""
raise NotImplementedError
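
A minimal sketch of how a caller consumes the new per-client return value, mirroring what EngineCoreProc does later in this diff (output_queue is assumed, not part of this interface):

outputs_by_client = scheduler.update_from_output(scheduler_output,
                                                 model_runner_output)
for client_index, engine_core_outputs in outputs_by_client.items():
    # Each client's outputs are queued separately so they can be routed back
    # over that client's own ZMQ output socket.
    output_queue.put_nowait((client_index, engine_core_outputs))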
@ -126,6 +127,11 @@ class SchedulerInterface(ABC):
"""
raise NotImplementedError
@abstractmethod
def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
raise NotImplementedError
@abstractmethod
def make_stats(self) -> Optional["SchedulerStats"]:
"""Make a SchedulerStats object for logging.

View File

@ -58,7 +58,8 @@ class Scheduler(SchedulerInterface):
# request ids should be included in the EngineCoreOutputs returned
# by update_from_output(). This is currently used in the multi-engine
# case to track request lifetimes efficiently.
self.include_finished_set = include_finished_set
self.finished_req_ids_dict: Optional[dict[int, set[str]]] = (
defaultdict(set) if include_finished_set else None)
# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
@ -693,7 +694,7 @@ class Scheduler(SchedulerInterface):
self,
scheduler_output: SchedulerOutput,
model_runner_output: ModelRunnerOutput,
) -> EngineCoreOutputs:
) -> dict[int, EngineCoreOutputs]:
sampled_token_ids = model_runner_output.sampled_token_ids
spec_token_ids = model_runner_output.spec_token_ids
logprobs = model_runner_output.logprobs
@ -701,7 +702,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: list[Request] = []
outputs: list[EngineCoreOutput] = []
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: Optional[SpecDecodingStats] = None
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
@ -797,7 +798,7 @@ class Scheduler(SchedulerInterface):
if new_token_ids or kv_transfer_params:
# Add EngineCoreOutput for this Request.
outputs.append(
outputs[request.client_index].append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
@ -828,17 +829,38 @@ class Scheduler(SchedulerInterface):
self._cached_reqs_data[req_data.req_id].append(req_data)
self.running = new_running
engine_core_outputs = EngineCoreOutputs(
outputs=outputs,
scheduler_stats=self.make_stats(spec_decoding_stats),
)
if self.include_finished_set:
#TODO currently sending duplicates here, improve this
engine_core_outputs.finished_requests = (
scheduler_output.finished_req_ids | self.finished_req_ids)
# Create EngineCoreOutputs for all clients that have requests with
# outputs in this step.
engine_core_outputs = {
client_index: EngineCoreOutputs(outputs=outs)
for client_index, outs in outputs.items()
}
finished_req_ids = self.finished_req_ids_dict
if finished_req_ids is not None:
# Include ids of requests that finished since last outputs
# were sent.
for client_index, finished_set in finished_req_ids.items():
# Set finished request set in EngineCoreOutputs for this client.
if (eco := engine_core_outputs.get(client_index)) is not None:
eco.finished_requests = finished_set
else:
engine_core_outputs[client_index] = EngineCoreOutputs(
finished_requests=finished_set)
finished_req_ids.clear()
if engine_core_outputs:
# Return stats to only one of the front-ends.
next(iter(engine_core_outputs.values())).scheduler_stats = (
self.make_stats(spec_decoding_stats))
return engine_core_outputs
def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
return len(self.running), len(self.waiting)
def add_request(self, request: Request) -> None:
self.waiting.append(request)
self.requests[request.request_id] = request
@ -880,8 +902,11 @@ class Scheduler(SchedulerInterface):
delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request)
self._cached_reqs_data.pop(request.request_id, None)
self.finished_req_ids.add(request.request_id)
request_id = request.request_id
self._cached_reqs_data.pop(request_id, None)
self.finished_req_ids.add(request_id)
if self.finished_req_ids_dict is not None:
self.finished_req_ids_dict[request.client_index].add(request_id)
if not delay_free_blocks:
self._free_blocks(request)

View File

@ -44,10 +44,6 @@ class EngineCoreRequest(
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
# but this object is currently not playing well with msgspec
# due to circular imports and typing we have in data.py
request_id: str
prompt_token_ids: list[int]
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
@ -59,6 +55,10 @@ class EngineCoreRequest(
lora_request: Optional[LoRARequest]
cache_salt: Optional[str]
# Index of the client, used to ensure outputs are sent back to the same
# client for this request when scaling out the front-end.
client_index: int = 0
# Used in DP case to indicate which wave of requests this is expected to
# belong to, to cover a race condition where the request is sent before
# a wave finished notification is received.

View File

@ -36,6 +36,7 @@ from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
setup_default_loggers)
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
logger = init_logger(__name__)
@ -54,6 +55,8 @@ class AsyncLLM(EngineClient):
log_requests: bool = True,
start_engine_loop: bool = True,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
) -> None:
"""
Create an AsyncLLM.
@ -124,6 +127,8 @@ class AsyncLLM(EngineClient):
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
client_addresses=client_addresses,
client_index=client_index,
)
if self.stat_loggers:
for stat_logger in self.stat_loggers[0]:
@ -145,6 +150,8 @@ class AsyncLLM(EngineClient):
stat_loggers: Optional[list[StatLoggerFactory]] = None,
disable_log_requests: bool = False,
disable_log_stats: bool = False,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
) -> "AsyncLLM":
if not envs.VLLM_USE_V1:
raise ValueError(
@ -162,6 +169,8 @@ class AsyncLLM(EngineClient):
log_requests=not disable_log_requests,
log_stats=not disable_log_stats,
usage_context=usage_context,
client_addresses=client_addresses,
client_index=client_index,
)
@classmethod
@ -195,6 +204,8 @@ class AsyncLLM(EngineClient):
def shutdown(self):
"""Shutdown, cleaning up the background proc and IPC."""
shutdown_prometheus()
if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()
@ -398,7 +409,6 @@ class AsyncLLM(EngineClient):
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
if stat_loggers:
assert outputs.scheduler_stats is not None
AsyncLLM._record_stats(
stat_loggers[outputs.engine_index],
scheduler_stats=outputs.scheduler_stats,
@ -422,7 +432,7 @@ class AsyncLLM(EngineClient):
@staticmethod
def _record_stats(
stat_loggers: list[StatLoggerBase],
scheduler_stats: SchedulerStats,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
):
"""static so that it can be used from the output_handler task

View File

@ -0,0 +1,252 @@
# SPDX-License-Identifier: Apache-2.0
import multiprocessing
import time
import weakref
from typing import Optional
import msgspec.msgpack
import zmq
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
from vllm.v1.serial_utils import MsgpackDecoder
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
logger = init_logger(__name__)
class DPCoordinator:
"""Coordinator process used for data-parallel deployments (DP>1).
Intermediates between multiple DP engine rank processes and one or more
front-end API server processes.
* Collects stats from each DP engine (currently just waiting and running
queue lengths), and publishes these to all front-ends for use in
load-balancing decisions.
* Keeps track of the current DP "request wave" number and running state
of the engines. This is received from the DP rank 0 engine and published
to the front-end processes along with the current load stats.
The engines alternate between a global running/paused state. The global
"request wave" number is a count of the number of times that the workers
collectively move from a running state to a paused state. This transition
is synchronized via the all-reduce operation performed in the
DPEngineCoreProc._has_global_unfinished_reqs method.
* Broadcasts the START_DP_WAVE message to engines to move them from paused
to running state when one engine receives a new request. This can happen
in two cases:
1) A front-end sending a new request while the engines are paused will
concurrently notify the coordinator.
2) An engine receiving a request for a stale request wave while in paused
state will notify the coordinator.
Engines will move into running state when receiving a new request or
START_DP_WAVE message.
"""
def __init__(self, parallel_config: ParallelConfig):
# Assume coordinator is colocated with front-end procs.
front_publish_address = get_open_zmq_ipc_path()
dp_size = parallel_config.data_parallel_size
assert dp_size > 1, "Coordinator only used for data parallel"
local_only = dp_size == parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip
back_publish_address = get_engine_client_zmq_addr(local_only, host)
back_output_address = get_engine_client_zmq_addr(local_only, host)
context = get_mp_context()
self.proc: multiprocessing.Process = context.Process(
target=CoordinatorProc.run_coordinator,
name="VLLM_DP_Coordinator",
kwargs={
"engine_count": parallel_config.data_parallel_size,
"front_publish_address": front_publish_address,
"back_output_address": back_output_address,
"back_publish_address": back_publish_address,
},
daemon=True)
self.proc.start()
self.stats_publish_address = front_publish_address
self.coord_in_address = back_publish_address
self.coord_out_address = back_output_address
self._finalizer = weakref.finalize(self, shutdown, [self.proc])
def get_stats_publish_address(self) -> str:
return self.stats_publish_address
def get_engine_socket_addresses(self) -> tuple[str, str]:
"""Returns tuple of ZMQ input address, output address."""
return self.coord_in_address, self.coord_out_address
def close(self):
self._finalizer()
class EngineState:
def __init__(self):
self.request_counts = [0, 0] # [waiting, running]
class CoordinatorProc:
def __init__(self, engine_count: int):
self.ctx = zmq.Context()
self.engines = [EngineState() for _ in range(engine_count)]
self.current_wave = 0
self.engines_running = False
self.stats_changed = False
@staticmethod
def run_coordinator(
engine_count: int,
front_publish_address: str,
back_output_address: str,
back_publish_address: str,
):
coordinator = CoordinatorProc(engine_count=engine_count)
try:
coordinator.process_input_socket(
front_publish_address,
back_output_address,
back_publish_address,
)
except KeyboardInterrupt:
logger.info("DP Coordinator process exiting")
def process_input_socket(self, front_publish_address: str,
back_output_address: str,
back_publish_address: str):
decoder = MsgpackDecoder(EngineCoreOutputs)
with make_zmq_socket(
path=front_publish_address, # IPC
ctx=self.ctx,
socket_type=zmq.XPUB,
bind=True,
) as publish_front, make_zmq_socket(
path=back_output_address, # IPC or TCP
ctx=self.ctx,
socket_type=zmq.PULL,
bind=True,
) as output_back, make_zmq_socket(
path=back_publish_address, # IPC or TCP
ctx=self.ctx,
socket_type=zmq.XPUB,
bind=True,
) as publish_back:
poller = zmq.Poller()
poller.register(publish_front, zmq.POLLIN)
poller.register(output_back, zmq.POLLIN)
last_publish_time = 0
while True:
elapsed = int(time.time() * 1000) - last_publish_time
# Send at 100 ms interval if the stats have changed,
# or otherwise every 3 seconds.
wait_for = 100 if self.stats_changed else 3000
events = poller.poll(timeout=max(0, wait_for - elapsed))
if not events:
# Poller timeout - publish current stats to front-ends.
engine_req_counts_list = self._get_engine_counts()
to_publish = (engine_req_counts_list, self.current_wave,
self.engines_running)
publish_front.send(msgspec.msgpack.encode(to_publish))
last_publish_time = int(time.time() * 1000)
self.stats_changed = False
continue
events = dict(events)
if publish_front in events:
buffer = publish_front.recv()
if buffer == b'\x01':
# Ignore subscription messages.
continue
# We received a message on the front-end XPUB socket,
# from an API server sending a new request while the
# engines are paused, so that we can wake the other
# engines.
engine_to_exclude, wave = msgspec.msgpack.decode(buffer)
if wave < self.current_wave:
# If the wave number is stale, ensure the message is
# handled by all the engines.
engine_to_exclude = None
if not self.engines_running:
self.engines_running = True
self.stats_changed = True
self._send_start_wave(publish_back, self.current_wave,
engine_to_exclude)
if output_back in events:
# We received a message from one of the engines.
buffer = output_back.recv()
outputs: EngineCoreOutputs = decoder.decode(buffer)
assert not outputs.outputs
assert outputs.utility_output is None
eng_index = outputs.engine_index
if outputs.scheduler_stats:
# 1. Updated request load stats - update our local
# state with these.
stats = self.engines[eng_index].request_counts
stats[0] = outputs.scheduler_stats.num_waiting_reqs
stats[1] = outputs.scheduler_stats.num_running_reqs
self.stats_changed = True
if (wave := outputs.wave_complete) is not None:
# 2. Notification from rank 0 engine that we've
# moved into the global paused state
# (engines_running==False)
if self.current_wave <= wave:
logger.debug("Moving DP wave from %d to %d.",
self.current_wave, wave)
self.current_wave = wave + 1
self.engines_running = False
self.stats_changed = True
elif (wave := outputs.start_wave) is not None and (
wave > self.current_wave or
(wave == self.current_wave
and not self.engines_running)):
# 3. The engine received request for a non-current wave
# so we must ensure that other engines progress to the
# next wave (race condition handling).
logger.debug(
"Starting wave %d after notification of "
"stale wave request from engine.", wave)
self.current_wave = wave
self.engines_running = True
self.stats_changed = True
self._send_start_wave(publish_back, wave, eng_index)
@staticmethod
def _send_start_wave(socket: zmq.Socket, wave: int,
exclude_engine_index: Optional[int]):
"""Broadcast the START_DP_WAVE message to all the engines.
It includes the current wave number and the index of the engine that
has already received a request with this wave number and therefore
doesn't require additional notification.
"""
wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index))
socket.send_multipart(
(EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
def _get_engine_counts(self) -> list[list[int]]:
"""Return list of [waiting, running] count lists for each engine."""
return [e.request_counts for e in self.engines]
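
A hypothetical front-end subscriber to the stats feed published above, assuming a plain SUB socket and the (counts, wave, running) tuple encoded in process_input_socket; stats_publish_address would come from DPCoordinator.get_stats_publish_address():

import msgspec.msgpack
import zmq

ctx = zmq.Context()
sub = ctx.socket(zmq.SUB)
sub.setsockopt(zmq.SUBSCRIBE, b"")
sub.connect(stats_publish_address)  # assumed to be provided by the caller
counts, current_wave, engines_running = msgspec.msgpack.decode(sub.recv())
# counts is a list of [waiting, running] pairs, one entry per DP engine rank.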

View File

@ -7,6 +7,7 @@ import threading
import time
from collections import deque
from concurrent.futures import Future
from contextlib import ExitStack
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, Callable, Optional, TypeVar, Union
@ -22,7 +23,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
@ -33,10 +34,12 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@ -211,7 +214,7 @@ class EngineCore:
# Re-raise exception
raise err
def step(self) -> tuple[EngineCoreOutputs, bool]:
def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
"""Schedule, execute, and make output.
Returns tuple of outputs and a flag indicating whether the model
@ -221,10 +224,7 @@ class EngineCore:
# Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
return EngineCoreOutputs(
outputs=[],
scheduler_stats=self.scheduler.make_stats(),
), False
return {}, False
scheduler_output = self.scheduler.schedule()
model_output = self.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
@ -234,7 +234,7 @@ class EngineCore:
scheduler_output.total_num_scheduled_tokens > 0)
def step_with_batch_queue(
self) -> tuple[Optional[EngineCoreOutputs], bool]:
self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
"""Schedule and execute batches with the batch queue.
Note that if there is nothing to output in this step, None is returned.
@ -276,8 +276,8 @@ class EngineCore:
# Blocking until the first result is available.
model_output = future.result()
self.batch_queue.task_done()
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output)
engine_core_outputs = (self.scheduler.update_from_output(
scheduler_output, model_output))
return engine_core_outputs, scheduled_batch
@ -362,7 +362,7 @@ class EngineCoreProc(EngineCore):
self,
vllm_config: VllmConfig,
on_head_node: bool,
input_address: str,
handshake_address: str,
executor_class: type[Executor],
log_stats: bool,
engine_index: int = 0,
@ -375,65 +375,70 @@ class EngineCoreProc(EngineCore):
# Create input socket.
input_ctx = zmq.Context()
identity = engine_index.to_bytes(length=2, byteorder="little")
input_socket = make_zmq_socket(input_ctx,
input_address,
zmq.DEALER,
identity=identity,
bind=False)
try:
with make_zmq_socket(input_ctx,
handshake_address,
zmq.DEALER,
identity=identity,
linger=5000,
bind=False) as handshake_socket:
# Register engine with front-end.
output_address = self.startup_handshake(
input_socket, on_head_node, vllm_config.parallel_config)
addresses = self.startup_handshake(handshake_socket, on_head_node,
vllm_config.parallel_config)
self.client_count = len(addresses.outputs)
# Update config which may have changed from the handshake.
vllm_config.__post_init__()
# Set up data parallel environment.
self.has_coordinator = addresses.coordinator_output is not None
self._init_data_parallel(vllm_config)
# Initialize engine core and model.
super().__init__(vllm_config, executor_class, log_stats,
executor_fail_callback)
self.engine_index = engine_index
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
self.engines_running = False
self.last_counts = (0, 0)
# Send ready message.
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
input_socket.send(
handshake_socket.send(
msgspec.msgpack.encode({
"status": "READY",
"local": on_head_node,
"num_gpu_blocks": num_gpu_blocks,
}))
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue = input_queue
self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
threading.Thread(target=self.process_input_socket,
args=(input_socket, ),
daemon=True).start()
input_socket = None
self.output_thread = threading.Thread(
target=self.process_output_socket,
args=(output_address, engine_index),
daemon=True)
self.output_thread.start()
finally:
if input_socket is not None:
input_socket.close(linger=0)
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue = input_queue
self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
bytes]]()
threading.Thread(target=self.process_input_sockets,
args=(addresses.inputs, addresses.coordinator_input,
identity),
daemon=True).start()
self.output_thread = threading.Thread(
target=self.process_output_sockets,
args=(addresses.outputs, addresses.coordinator_output,
engine_index),
daemon=True)
self.output_thread.start()
@staticmethod
def startup_handshake(input_socket: zmq.Socket, on_head_node: bool,
parallel_config: ParallelConfig) -> str:
def startup_handshake(
handshake_socket: zmq.Socket, on_head_node: bool,
parallel_config: ParallelConfig) -> EngineZmqAddresses:
# Send registration message.
input_socket.send(
handshake_socket.send(
msgspec.msgpack.encode({
"status": "HELLO",
"local": on_head_node,
@@ -441,22 +446,20 @@ class EngineCoreProc(EngineCore):
# Receive initialization message.
logger.info("Waiting for init message from front-end.")
if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000):
if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
raise RuntimeError("Did not receive response from front-end "
f"process within {HANDSHAKE_TIMEOUT_MINS} "
f"minutes")
init_bytes = input_socket.recv()
init_message = msgspec.msgpack.decode(init_bytes)
init_bytes = handshake_socket.recv()
init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
init_bytes, type=EngineHandshakeMetadata)
logger.debug("Received init message: %s", init_message)
output_socket_address = init_message["output_socket_address"]
#TBD(nick) maybe replace IP with configured head node address
received_parallel_config = init_message["parallel_config"]
received_parallel_config = init_message.parallel_config
for key, value in received_parallel_config.items():
setattr(parallel_config, key, value)
return output_socket_address
return init_message.addresses
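For orientation, here is a minimal standalone sketch (not part of this diff; values are illustrative) of the engine-side handshake sequence used above: the engine's DEALER socket sends HELLO, receives the front-end's handshake metadata containing the per-client input/output addresses, and later reports READY with its GPU block count.
import msgspec
import zmq
def toy_engine_handshake(handshake_address: str, engine_index: int) -> dict:
    ctx = zmq.Context()
    sock = ctx.socket(zmq.DEALER)
    # Identity encodes the DP rank, matching what the front-end expects.
    sock.setsockopt(zmq.IDENTITY, engine_index.to_bytes(2, "little"))
    sock.connect(handshake_address)
    # 1) Announce ourselves to the front-end ROUTER socket.
    sock.send(msgspec.msgpack.encode({"status": "HELLO", "local": True}))
    # 2) Receive handshake metadata (client addresses, parallel config).
    metadata = msgspec.msgpack.decode(sock.recv())
    # 3) After model init, report READY with the number of GPU KV blocks.
    sock.send(msgspec.msgpack.encode(
        {"status": "READY", "local": True, "num_gpu_blocks": 1024}))
    sock.close(linger=0)
    ctx.term()
    return metadata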
@staticmethod
def run_engine_core(*args,
@@ -528,7 +531,7 @@ class EngineCoreProc(EngineCore):
"""Exits when an engine step needs to be performed."""
waited = False
while not self.engines_running and not (self.scheduler.has_requests()):
while not self.engines_running and not self.scheduler.has_requests():
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
logger.debug("EngineCore waiting for work.")
waited = True
@@ -549,8 +552,8 @@ class EngineCoreProc(EngineCore):
# Step the engine core.
outputs, model_executed = self.step_fn()
# Put EngineCoreOutputs into the output queue.
if outputs is not None:
self.output_queue.put_nowait(outputs)
for output in (outputs.items() if outputs else ()):
self.output_queue.put_nowait(output)
return model_executed
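A toy restatement (not vLLM code) of the convention the loop above relies on: step_fn now returns outputs keyed by front-end client index, each entry is queued as a (client_index, outputs) pair, and index -1 is reserved for messages addressed to the DP coordinator.
import queue
output_queue: queue.Queue = queue.Queue()
step_outputs = {0: "outputs for client 0", 2: "outputs for client 2"}
for client_idx, outs in step_outputs.items():
    output_queue.put_nowait((client_idx, outs))
# Coordinator-bound messages (e.g. request-count stats) use index -1.
output_queue.put_nowait((-1, "scheduler stats"))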
@@ -563,7 +566,7 @@ class EngineCoreProc(EngineCore):
elif request_type == EngineCoreRequestType.ABORT:
self.abort_requests(request)
elif request_type == EngineCoreRequestType.UTILITY:
call_id, method_name, args = request
client_idx, call_id, method_name, args = request
output = UtilityOutput(call_id)
try:
method = getattr(self, method_name)
@@ -574,7 +577,7 @@ class EngineCoreProc(EngineCore):
output.failure_message = (f"Call to {method_name} method"
f" failed: {str(e)}")
self.output_queue.put_nowait(
EngineCoreOutputs(utility_output=output))
(client_idx, EngineCoreOutputs(utility_output=output)))
elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
raise RuntimeError("Executor failed.")
else:
@@ -607,27 +610,68 @@ class EngineCoreProc(EngineCore):
logger.fatal("vLLM shutdown signal from EngineCore failed "
"to send. Please report this issue.")
def process_input_socket(self, input_socket: zmq.Socket):
def process_input_sockets(self, input_addresses: list[str],
coord_input_address: Optional[str],
identity: bytes):
"""Input socket IO thread."""
# Msgpack serialization decoding.
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
generic_decoder = MsgpackDecoder()
while True:
# (RequestType, RequestData)
type_frame, *data_frames = input_socket.recv_multipart(copy=False)
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
with ExitStack() as stack, zmq.Context() as ctx:
input_sockets = [
stack.enter_context(
make_zmq_socket(ctx,
input_address,
zmq.DEALER,
identity=identity,
bind=False))
for input_address in input_addresses
]
if coord_input_address is None:
coord_socket = None
else:
coord_socket = stack.enter_context(
make_zmq_socket(ctx,
coord_input_address,
zmq.XSUB,
identity=identity,
bind=False))
# Send subscription message to coordinator.
coord_socket.send(b'\x01')
# Deserialize the request data.
decoder = add_request_decoder if (
request_type == EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frames)
# Register sockets with poller.
poller = zmq.Poller()
for input_socket in input_sockets:
# Send initial message to each input socket - this is required
# before the front-end ROUTER socket can send input messages
# back to us.
input_socket.send(b'')
poller.register(input_socket, zmq.POLLIN)
if coord_socket is not None:
poller.register(coord_socket, zmq.POLLIN)
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
while True:
for input_socket, _ in poller.poll():
# (RequestType, RequestData)
type_frame, *data_frames = input_socket.recv_multipart(
copy=False)
request_type = EngineCoreRequestType(
bytes(type_frame.buffer))
def process_output_socket(self, output_path: str, engine_index: int):
# Deserialize the request data.
decoder = add_request_decoder if (
request_type
== EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frames)
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
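The empty initial send above matters because a ROUTER socket can only address a peer after it has received at least one frame from it. A minimal pyzmq illustration of that pattern (standalone, not part of the diff):
import zmq
ctx = zmq.Context()
router = ctx.socket(zmq.ROUTER)
port = router.bind_to_random_port("tcp://127.0.0.1")
dealer = ctx.socket(zmq.DEALER)
dealer.setsockopt(zmq.IDENTITY, (1).to_bytes(2, "little"))
dealer.connect(f"tcp://127.0.0.1:{port}")
dealer.send(b"")  # announce our identity to the ROUTER
identity, _ = router.recv_multipart()          # ROUTER now knows this peer
router.send_multipart([identity, b"request"])  # and can route to it
print(dealer.recv())                           # b"request"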
def process_output_sockets(self, output_paths: list[str],
coord_output_path: Optional[str],
engine_index: int):
"""Output socket IO thread."""
# Msgpack serialization encoding.
@@ -641,30 +685,49 @@ class EngineCoreProc(EngineCore):
# We must set linger to ensure the ENGINE_CORE_DEAD
# message is sent prior to closing the socket.
with zmq_socket_ctx(output_path, zmq.constants.PUSH,
linger=4000) as socket:
with ExitStack() as stack, zmq.Context() as ctx:
sockets = [
stack.enter_context(
make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000))
for output_path in output_paths
]
coord_socket = stack.enter_context(
make_zmq_socket(
ctx, coord_output_path, zmq.PUSH, bind=False,
linger=4000)) if coord_output_path is not None else None
max_reuse_bufs = len(sockets) + 1
while True:
outputs = self.output_queue.get()
if outputs == EngineCoreProc.ENGINE_CORE_DEAD:
socket.send(outputs, copy=False)
output = self.output_queue.get()
if output == EngineCoreProc.ENGINE_CORE_DEAD:
for socket in sockets:
socket.send(output)
break
assert not isinstance(outputs, bytes)
assert not isinstance(output, bytes)
client_index, outputs = output
outputs.engine_index = engine_index
if client_index == -1:
# Don't reuse buffer for coordinator message
# which will be very small.
assert coord_socket is not None
coord_socket.send_multipart(encoder.encode(outputs))
continue
# Reclaim buffers that zmq is finished with.
while pending and pending[-1][0].done:
reuse_buffers.append(pending.pop()[2])
buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
buffers = encoder.encode_into(outputs, buffer)
tracker = socket.send_multipart(buffers,
copy=False,
track=True)
tracker = sockets[client_index].send_multipart(buffers,
copy=False,
track=True)
if not tracker.done:
ref = outputs if len(buffers) > 1 else None
pending.appendleft((tracker, ref, buffer))
elif len(reuse_buffers) < 2:
# Keep at most 2 buffers to reuse.
elif len(reuse_buffers) < max_reuse_bufs:
# Limit the number of buffers to reuse.
reuse_buffers.append(buffer)
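A standalone sketch of the zero-copy send-and-recycle pattern above, assuming pyzmq's tracked sends (send_multipart(..., copy=False, track=True) returns a MessageTracker); the queue bookkeeping is simplified and the buffer handling is illustrative only.
from collections import deque
import zmq
def send_and_recycle(sock: zmq.Socket, frames: list, pending: deque,
                     reuse_buffers: list, max_reuse: int) -> None:
    # Reclaim buffers from earlier sends that zmq has finished with.
    while pending and pending[-1][0].done:
        reuse_buffers.append(pending.pop()[1])
    tracker = sock.send_multipart(frames, copy=False, track=True)
    if not tracker.done:
        # zmq still references the serialized buffer; keep it alive.
        pending.appendleft((tracker, frames[-1]))
    elif len(reuse_buffers) < max_reuse:
        reuse_buffers.append(frames[-1])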
@@ -676,7 +739,7 @@ class DPEngineCoreProc(EngineCoreProc):
self,
vllm_config: VllmConfig,
on_head_node: bool,
input_address: str,
handshake_address: str,
executor_class: type[Executor],
log_stats: bool,
):
@@ -691,10 +754,11 @@ class DPEngineCoreProc(EngineCoreProc):
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self.counter = 0
self.current_wave = 0
# Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__(vllm_config, on_head_node, input_address,
super().__init__(vllm_config, on_head_node, handshake_address,
executor_class, log_stats, dp_rank)
def _init_data_parallel(self, vllm_config: VllmConfig):
@@ -726,7 +790,6 @@ class DPEngineCoreProc(EngineCoreProc):
self.dp_rank = dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
self.current_wave = 0
def shutdown(self):
super().shutdown()
@@ -734,22 +797,23 @@ class DPEngineCoreProc(EngineCoreProc):
stateless_destroy_torch_distributed_process_group(dp_group)
def add_request(self, request: EngineCoreRequest):
if request.current_wave != self.current_wave:
if self.has_coordinator and request.current_wave != self.current_wave:
if request.current_wave > self.current_wave:
self.current_wave = request.current_wave
elif not self.engines_running:
# Request received for an already-completed wave, notify
# front-end that we need to start the next one.
self.output_queue.put_nowait(
EngineCoreOutputs(start_wave=self.current_wave))
(-1, EngineCoreOutputs(start_wave=self.current_wave)))
super().add_request(request)
def _handle_client_request(self, request_type: EngineCoreRequestType,
request: Any) -> None:
if request_type == EngineCoreRequestType.START_DP_WAVE:
new_wave: int = request
if new_wave >= self.current_wave:
new_wave, exclude_eng_index = request
if exclude_eng_index != self.engine_index and (
new_wave >= self.current_wave):
self.current_wave = new_wave
if not self.engines_running:
logger.debug("EngineCore starting idle loop for wave %d.",
@@ -758,6 +822,18 @@ class DPEngineCoreProc(EngineCoreProc):
else:
super()._handle_client_request(request_type, request)
def _maybe_publish_request_counts(self):
if not self.has_coordinator:
return
# Publish our request counts (if they've changed).
counts = self.scheduler.get_request_counts()
if counts != self.last_counts:
self.last_counts = counts
stats = SchedulerStats(*counts)
self.output_queue.put_nowait(
(-1, EngineCoreOutputs(scheduler_stats=stats)))
def run_busy_loop(self):
"""Core busy loop of the EngineCore for data parallel case."""
@@ -768,6 +844,8 @@ class DPEngineCoreProc(EngineCoreProc):
# 2) Step the engine core.
executed = self._process_engine_step()
self._maybe_publish_request_counts()
local_unfinished_reqs = self.scheduler.has_unfinished_requests()
if not executed:
if not local_unfinished_reqs and not self.engines_running:
@@ -788,7 +866,8 @@ class DPEngineCoreProc(EngineCoreProc):
logger.debug("Wave %d finished, pausing engine loop.",
self.current_wave)
self.output_queue.put_nowait(
EngineCoreOutputs(wave_complete=self.current_wave))
(-1,
EngineCoreOutputs(wave_complete=self.current_wave)))
self.current_wave += 1
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:


@@ -2,6 +2,7 @@
import asyncio
import contextlib
import queue
import sys
import uuid
import weakref
from abc import ABC, abstractmethod
@@ -9,26 +10,28 @@ from collections import deque
from collections.abc import Awaitable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass
from enum import Enum, auto
from threading import Thread
from typing import Any, Callable, Optional, TypeVar, Union
import msgspec
import msgspec.msgpack
import zmq
import zmq.asyncio
from vllm.config import ParallelConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import (get_open_port, get_open_zmq_inproc_path,
get_open_zmq_ipc_path, get_tcp_uri, make_zmq_socket)
from vllm.utils import (get_open_zmq_inproc_path, make_zmq_socket,
zmq_socket_ctx)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
from vllm.v1.utils import CoreEngineProcManager
from vllm.v1.utils import (CoreEngine, CoreEngineProcManager,
EngineZmqAddresses, get_engine_client_zmq_addr,
wait_for_engine_startup)
logger = init_logger(__name__)
@@ -36,8 +39,6 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]]
_R = TypeVar('_R') # Return type for collective_rpc
STARTUP_POLL_PERIOD_MS = 10000
class EngineCoreClient(ABC):
"""
@@ -207,7 +208,7 @@ class InprocClient(EngineCoreClient):
def get_output(self) -> EngineCoreOutputs:
outputs, _ = self.engine_core.step()
return outputs
return outputs.get(0) or EngineCoreOutputs()
def add_request(self, request: EngineCoreRequest) -> None:
self.engine_core.add_request(request)
@@ -266,24 +267,6 @@ class InprocClient(EngineCoreClient):
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
class CoreEngineState(Enum):
NEW = auto()
CONNECTED = auto()
READY = auto()
class CoreEngine:
"""One per data parallel rank."""
def __init__(self, index: int = 0, local: bool = True):
self.local = local
self.index = index
self.identity = index.to_bytes(length=2, byteorder="little")
self.state = CoreEngineState.NEW
self.num_reqs_in_flight = 0
@dataclass
class BackgroundResources:
"""Used as a finalizer for clean shutdown, avoiding
@@ -291,9 +274,12 @@ class BackgroundResources:
ctx: Union[zmq.Context]
local_engine_manager: Optional[CoreEngineProcManager] = None
coordinator: Optional[DPCoordinator] = None
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
first_req_send_socket: Optional[zmq.asyncio.Socket] = None
output_queue_task: Optional[asyncio.Task] = None
stats_update_task: Optional[asyncio.Task] = None
shutdown_path: Optional[str] = None
# Set if any of the engines are dead. Here so that the output
@@ -306,16 +292,21 @@ class BackgroundResources:
self.engine_dead = True
if self.local_engine_manager is not None:
self.local_engine_manager.close()
if self.coordinator is not None:
self.coordinator.close()
if self.output_queue_task is not None:
self.output_queue_task.cancel()
if self.stats_update_task is not None:
self.stats_update_task.cancel()
# ZMQ context termination can hang if the sockets
# aren't explicitly closed first.
if self.output_socket is not None:
self.output_socket.close(linger=0)
if self.input_socket is not None:
self.input_socket.close(linger=0)
for socket in (self.output_socket, self.input_socket,
self.first_req_send_socket):
if socket is not None:
socket.close(linger=0)
if self.shutdown_path is not None:
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
@@ -350,6 +341,7 @@ class MPClient(EngineCoreClient):
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: Optional[dict[str, str]] = None,
):
self.vllm_config = vllm_config
# Serialization setup.
@@ -369,8 +361,8 @@ class MPClient(EngineCoreClient):
try:
parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local
start_index = parallel_config.data_parallel_rank
local_start_index = parallel_config.data_parallel_rank_local
dp_size = parallel_config.data_parallel_size
# SPMD mode is where there is an LLM instance per DP rank and
# one core engine per LLM, see
@@ -382,42 +374,53 @@ class MPClient(EngineCoreClient):
CoreEngine(index=local_start_index, local=True)
]
else:
assert start_index == 0
assert parallel_config.data_parallel_rank == 0
local_start_index = 0
self.core_engines = [
CoreEngine(index=i, local=(i < local_engine_count))
for i in range(parallel_config.data_parallel_size)
for i in range(dp_size)
]
input_address, output_address = self._get_zmq_addresses(
parallel_config, spmd_mode)
local_only = spmd_mode or local_engine_count == dp_size
self.stats_update_address: Optional[str] = None
if client_addresses is not None:
input_address = client_addresses["input_address"]
output_address = client_addresses["output_address"]
self.stats_update_address = client_addresses.get(
"stats_update_address")
else:
host = parallel_config.data_parallel_master_ip
input_address = get_engine_client_zmq_addr(local_only, host)
output_address = get_engine_client_zmq_addr(local_only, host)
# Create input and output sockets.
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, input_address, zmq.ROUTER, bind=True)
self.resources.output_socket = make_zmq_socket(
self.ctx, output_address, zmq.constants.PULL)
# Start local engines.
if local_engine_count:
# In server mode, start_index and local_start_index will
# both be 0.
self.resources.local_engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=log_stats,
input_address=input_address,
on_head_node=True,
local_engine_count=local_engine_count,
start_index=start_index,
local_start_index=local_start_index)
self.ctx, output_address, zmq.PULL)
if client_addresses is None:
self._init_engines_direct(vllm_config, local_only,
local_start_index, input_address,
output_address, executor_class,
log_stats)
coordinator = self.resources.coordinator
if coordinator:
self.stats_update_address = (
coordinator.get_stats_publish_address())
# Wait for ready messages from each engine on the input socket.
identities = set(e.identity for e in self.core_engines)
sync_input_socket = zmq.Socket.shadow(self.input_socket)
while identities:
if not sync_input_socket.poll(timeout=600_000):
raise TimeoutError("Timed out waiting for engines to send "
"initial message on input socket.")
identity, _ = sync_input_socket.recv_multipart()
identities.remove(identity)
self.core_engine = self.core_engines[0]
# Wait for engine core process(es) to start.
self._wait_for_engine_startup(output_address, parallel_config)
self.utility_results: dict[int, AnyFuture] = {}
# Request objects which may contain pytorch-allocated tensors
@@ -430,116 +433,67 @@ class MPClient(EngineCoreClient):
if not success:
self._finalizer()
@staticmethod
def _get_zmq_addresses(parallel_config: ParallelConfig,
spmd_mode: bool) -> tuple[str, str]:
"""Returns (input_address, output_address)."""
dp_size = parallel_config.data_parallel_size
def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
local_start_index: int, input_address: str,
output_address: str,
executor_class: type[Executor], log_stats: bool):
"""Self-contained client mode, launch engine and coordinator process
as needed."""
parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local
start_index = parallel_config.data_parallel_rank
host = parallel_config.data_parallel_master_ip
if local_engine_count == dp_size or spmd_mode:
input_address = get_open_zmq_ipc_path()
output_address = get_open_zmq_ipc_path()
else:
host = parallel_config.data_parallel_master_ip
input_port = parallel_config.data_parallel_rpc_port
output_port = get_open_port()
input_address = get_tcp_uri(host, input_port)
output_address = get_tcp_uri(host, output_port)
if len(self.core_engines) > 1:
self.resources.coordinator = DPCoordinator(parallel_config)
return input_address, output_address
handshake_address = get_engine_client_zmq_addr(
local_only, host, parallel_config.data_parallel_rpc_port)
def _wait_for_engine_startup(self, output_address: str,
parallel_config: ParallelConfig):
# Get a sync handle to the socket which can be sync or async.
sync_input_socket = zmq.Socket.shadow(self.input_socket)
with zmq_socket_ctx(handshake_address, zmq.ROUTER,
bind=True) as handshake_socket:
# Wait for engine core process(es) to send ready messages.
local_count = parallel_config.data_parallel_size_local
remote_count = len(self.core_engines) - local_count
# [local, remote] counts
conn_pending, start_pending = [local_count, remote_count], [0, 0]
# Start local engines.
if local_engine_count:
# In server mode, start_index and local_start_index will
# both be 0.
self.resources.local_engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=log_stats,
handshake_address=handshake_address,
on_head_node=True,
local_engine_count=local_engine_count,
start_index=start_index,
local_start_index=local_start_index)
poller = zmq.Poller()
poller.register(sync_input_socket, zmq.POLLIN)
proc_manager = self.resources.local_engine_manager
if proc_manager is not None:
for sentinel in proc_manager.sentinels():
poller.register(sentinel, zmq.POLLIN)
while any(conn_pending) or any(start_pending):
events = poller.poll(STARTUP_POLL_PERIOD_MS)
if not events:
if any(conn_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to connect.", *conn_pending)
if any(start_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to start.", *start_pending)
continue
if len(events) > 1 or events[0][0] != sync_input_socket:
# One of the local core processes exited.
finished = proc_manager.finished_procs(
) if proc_manager else {}
raise RuntimeError("Engine core initialization failed. "
"See root cause above. "
f"Failed core proc(s): {finished}")
# Wait for engine core process(es) to start.
self._wait_for_engine_startup(handshake_socket, input_address,
output_address)
# Receive HELLO and READY messages from the input socket.
eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart()
eng_index = int.from_bytes(eng_identity, byteorder="little")
engine = next(
(e for e in self.core_engines if e.identity == eng_identity),
None)
if engine is None:
raise RuntimeError(f"Message from engine with unexpected data "
f"parallel rank: {eng_index}")
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local = msg["status"], msg["local"]
if local != engine.local:
raise RuntimeError(f"{status} message from "
f"{'local' if local else 'remote'} "
f"engine {eng_index}, expected it to be "
f"{'local' if engine.local else 'remote'}")
def _wait_for_engine_startup(self, handshake_socket: zmq.Socket,
input_address: str, output_address: str):
addresses = EngineZmqAddresses(
inputs=[input_address],
outputs=[output_address],
)
if status == "HELLO" and engine.state == CoreEngineState.NEW:
coordinator = self.resources.coordinator
if coordinator is not None:
addresses.coordinator_input, addresses.coordinator_output = (
coordinator.get_engine_socket_addresses())
# Send init message with DP config info.
init_message = self.encoder.encode({
"output_socket_address": output_address,
"parallel_config": {
"data_parallel_master_ip":
parallel_config.data_parallel_master_ip,
"data_parallel_master_port":
parallel_config.data_parallel_master_port,
"data_parallel_size":
parallel_config.data_parallel_size,
},
})
sync_input_socket.send_multipart((eng_identity, *init_message),
copy=False)
conn_pending[0 if local else 1] -= 1
start_pending[0 if local else 1] += 1
engine.state = CoreEngineState.CONNECTED
elif status == "READY" and (engine.state
== CoreEngineState.CONNECTED):
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
cache_config = self.vllm_config.cache_config
num_gpu_blocks = cache_config.num_gpu_blocks or 0
num_gpu_blocks += msg['num_gpu_blocks']
cache_config.num_gpu_blocks = num_gpu_blocks
start_pending[0 if local else 1] -= 1
engine.state = CoreEngineState.READY
else:
raise RuntimeError(f"Unexpected {status} message for "
f"{'local' if local else 'remote'} engine "
f"{eng_index} in {engine.state} state.")
logger.debug("%s from %s core engine process %s.", status,
"local" if local else "remote", eng_index)
wait_for_engine_startup(
handshake_socket,
addresses,
self.core_engines,
self.vllm_config.parallel_config,
self.vllm_config.cache_config,
self.resources.local_engine_manager,
coordinator.proc if coordinator else None,
)
def shutdown(self):
# Terminate background resources.
@@ -605,8 +559,8 @@ class SyncMPClient(MPClient):
try:
shutdown_socket.bind(shutdown_path)
poller = zmq.Poller()
poller.register(shutdown_socket)
poller.register(out_socket)
poller.register(shutdown_socket, zmq.POLLIN)
poller.register(out_socket, zmq.POLLIN)
while True:
socks = poller.poll()
if not socks:
@@ -668,7 +622,7 @@ class SyncMPClient(MPClient):
future: Future[Any] = Future()
self.utility_results[call_id] = future
self._send_input(EngineCoreRequestType.UTILITY,
(call_id, method, args))
(0, call_id, method, args))
return future.result()
@@ -730,15 +684,21 @@ class SyncMPClient(MPClient):
class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
log_stats: bool):
def __init__(self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0):
super().__init__(
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=log_stats,
client_addresses=client_addresses,
)
self.client_index = client_index
self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs,
Exception]]()
try:
@@ -854,12 +814,13 @@ class AsyncMPClient(MPClient):
future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future
message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
(call_id, method, args)))
(self.client_index, call_id, method, args)))
await self._send_input_message(message, engine, args)
self._ensure_output_queue_task()
return await future
async def add_request_async(self, request: EngineCoreRequest) -> None:
request.client_index = self.client_index
await self._send_input(EngineCoreRequestType.ADD, request)
self._ensure_output_queue_task()
@@ -921,17 +882,120 @@ class DPAsyncMPClient(AsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
EngineCore."""
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
log_stats: bool):
def __init__(self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0):
self.current_wave = 0
self.engines_running = False
# To route aborts to the correct engine.
self.reqs_in_flight: dict[str, CoreEngine] = {}
super().__init__(vllm_config, executor_class, log_stats)
super().__init__(vllm_config, executor_class, log_stats,
client_addresses, client_index)
assert len(self.core_engines) > 1
# List of [waiting, running] pair per engine.
self.lb_engines: list[list[int]] = []
self.first_req_sock_addr = get_open_zmq_inproc_path()
self.first_req_send_socket = self.resources.first_req_send_socket = (
make_zmq_socket(self.ctx,
self.first_req_sock_addr,
zmq.PAIR,
bind=True))
try:
# If we are running in an asyncio event loop, start the stats task.
# Otherwise, it will be started lazily.
asyncio.get_running_loop()
self._ensure_stats_update_task()
except RuntimeError:
pass
def _ensure_stats_update_task(self):
resources = self.resources
if resources.stats_update_task is not None:
return
assert self.stats_update_address is not None
async def run_engine_stats_update_task():
with make_zmq_socket(self.ctx, self.stats_update_address,
zmq.XSUB) as socket, make_zmq_socket(
self.ctx,
self.first_req_sock_addr,
zmq.PAIR,
bind=False) as first_req_rcv_socket:
# Send subscription message.
await socket.send(b'\x01')
poller = zmq.asyncio.Poller()
poller.register(socket, zmq.POLLIN)
poller.register(first_req_rcv_socket, zmq.POLLIN)
while True:
events = await poller.poll()
if not self.engines_running and len(events) == 2 or (
events[0][0] == first_req_rcv_socket):
# Send a message to notify the coordinator that
# we're sending a request while the engines are
# paused, so that it can wake the others up
# (to run dummy EP loop).
self.engines_running = True
buf = first_req_rcv_socket.recv(
flags=zmq.NOBLOCK).result()
target_eng_index = int.from_bytes(buf, "little")
msg = msgspec.msgpack.encode(
(target_eng_index, self.current_wave))
await socket.send(msg)
buf = None
while True:
# Drain all stats events (we only care about latest).
future: asyncio.Future[bytes] = socket.recv(
flags=zmq.NOBLOCK)
if isinstance(future.exception(), zmq.Again):
break
buf = future.result()
if buf is None:
continue
# Update local load-balancing state.
counts, wave, running = msgspec.msgpack.decode(buf)
self.current_wave = wave
self.engines_running = running
self.lb_engines = counts
resources.stats_update_task = asyncio.create_task(
run_engine_stats_update_task())
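The b'\x01' frames sent above are raw XSUB subscription messages. A minimal pyzmq illustration of that handshake (standalone, local addresses only):
import zmq
ctx = zmq.Context()
xpub = ctx.socket(zmq.XPUB)
port = xpub.bind_to_random_port("tcp://127.0.0.1")
xsub = ctx.socket(zmq.XSUB)
xsub.connect(f"tcp://127.0.0.1:{port}")
xsub.send(b"\x01")        # subscribe to all topics (a leading \x00 unsubscribes)
print(xpub.recv())        # the subscription frame arrives at the publisher
xpub.send(b"stats update")
print(xsub.recv())        # b"stats update"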
def get_core_engine_for_request(self) -> CoreEngine:
if not self.lb_engines:
return self.core_engines[0]
# TODO use P2C alg for larger DP sizes
num_engines = len(self.lb_engines)
min_counts = [sys.maxsize, sys.maxsize]
eng_index = 0
for i in range(num_engines):
# Start from client_index to help with balancing when engines
# are empty.
idx = (self.client_index + i) % num_engines
counts = self.lb_engines[idx]
if counts < min_counts:
min_counts = counts
eng_index = idx
# Adjust local counts for better balancing between stats updates
# from the coordinator (which happen every 100ms).
if min_counts[0]:
min_counts[0] += 1
else:
min_counts[1] += 1
return self.core_engines[eng_index]
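The TODO above mentions a power-of-two-choices (P2C) policy for larger DP sizes; a hedged sketch of what that could look like (illustrative only, not part of this change):
import random
def p2c_pick(lb_engines: list[list[int]]) -> int:
    # Pick the less-loaded of two randomly sampled engines.
    if len(lb_engines) < 2:
        return 0
    a, b = random.sample(range(len(lb_engines)), 2)
    # Compare [waiting, running] pairs lexicographically, like the
    # full scan above.
    return a if lb_engines[a] <= lb_engines[b] else b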
async def call_utility_async(self, method: str, *args) -> Any:
# Only the result from the first engine is returned.
return (await asyncio.gather(*[
@@ -940,62 +1004,30 @@ class DPAsyncMPClient(AsyncMPClient):
]))[0]
async def add_request_async(self, request: EngineCoreRequest) -> None:
self._ensure_stats_update_task()
request.current_wave = self.current_wave
request.client_index = self.client_index
chosen_engine = self.get_core_engine_for_request()
self.reqs_in_flight[request.request_id] = chosen_engine
chosen_engine.num_reqs_in_flight += 1
to_await = self._send_input(EngineCoreRequestType.ADD, request,
chosen_engine)
if not self.engines_running:
# Send request to chosen engine and dp start loop
# control message to all other engines.
self.engines_running = True
to_await = asyncio.gather(
to_await, # type: ignore[assignment]
*self._start_wave_coros(exclude_index=chosen_engine.index))
# Notify coordinator that we're sending a request
await self.first_req_send_socket.send(chosen_engine.identity)
await to_await
self._ensure_output_queue_task()
def get_core_engine_for_request(self) -> CoreEngine:
return min(self.core_engines, key=lambda e: e.num_reqs_in_flight)
@staticmethod
async def process_engine_outputs(self: "DPAsyncMPClient",
outputs: EngineCoreOutputs):
if self.reqs_in_flight:
for req_id in outputs.finished_requests or ():
if engine := self.reqs_in_flight.pop(req_id, None):
engine.num_reqs_in_flight -= 1
if outputs.wave_complete is not None:
# Current wave is complete, move to next wave number
# and mark engines as paused.
if self.current_wave <= outputs.wave_complete:
self.current_wave = outputs.wave_complete + 1
self.engines_running = False
elif outputs.start_wave is not None and (
outputs.start_wave > self.current_wave or
(outputs.start_wave == self.current_wave
and not self.engines_running)):
# Engine received request for a non-current wave so we must ensure
# that other engines progress to the next wave.
self.current_wave = outputs.start_wave
self.engines_running = True
await asyncio.gather(*self._start_wave_coros(
exclude_index=outputs.engine_index))
def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]:
logger.debug("Sending start DP wave %d.", self.current_wave)
return [
self._send_input(EngineCoreRequestType.START_DP_WAVE,
self.current_wave, engine)
for engine in self.core_engines if engine.index != exclude_index
]
if outputs.finished_requests and self.reqs_in_flight:
for req_id in outputs.finished_requests:
self.reqs_in_flight.pop(req_id, None)
async def abort_requests_async(self, request_ids: list[str]) -> None:
if not request_ids:


@@ -12,13 +12,12 @@ from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.prometheus import unregister_vllm_metrics
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5.0
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
@@ -35,7 +34,7 @@ class StatLoggerBase(ABC):
...
@abstractmethod
def record(self, scheduler_stats: SchedulerStats,
def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]):
...
@@ -78,20 +77,22 @@ class LoggingStatLogger(StatLoggerBase):
# Compute summary metrics for tracked stats
return float(np.sum(tracked_stats) / (now - self.last_log_time))
def record(self, scheduler_stats: SchedulerStats,
def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]):
"""Log Stats to standard output."""
if iteration_stats:
self._track_iteration_stats(iteration_stats)
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
if scheduler_stats is not None:
self.prefix_caching_metrics.observe(
scheduler_stats.prefix_cache_stats)
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_logging.observe(
scheduler_stats.spec_decoding_stats)
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_logging.observe(
scheduler_stats.spec_decoding_stats)
self.last_scheduler_stats = scheduler_stats
self.last_scheduler_stats = scheduler_stats
def log(self):
now = time.monotonic()
@@ -131,10 +132,11 @@ class LoggingStatLogger(StatLoggerBase):
self.spec_decoding_logging.log(log_fn=log_fn)
def log_engine_initialized(self):
logger.info(
"vllm cache_config_info with initialization " \
"after num_gpu_blocks is: %d",
self.vllm_config.cache_config.num_gpu_blocks)
if self.vllm_config.cache_config.num_gpu_blocks:
logger.info(
"Engine %03d: vllm cache_config_info with initialization "
"after num_gpu_blocks is: %d", self.engine_index,
self.vllm_config.cache_config.num_gpu_blocks)
class PrometheusStatLogger(StatLoggerBase):
@@ -144,7 +146,8 @@ class PrometheusStatLogger(StatLoggerBase):
_spec_decoding_cls = SpecDecodingProm
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self._unregister_vllm_metrics()
unregister_vllm_metrics()
self.vllm_config = vllm_config
self.engine_index = engine_index
# Use this flag to hide metrics that were deprecated in
@@ -169,11 +172,13 @@ class PrometheusStatLogger(StatLoggerBase):
self.gauge_scheduler_running = self._gauge_cls(
name="vllm:num_requests_running",
documentation="Number of requests in model execution batches.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues)
self.gauge_scheduler_waiting = self._gauge_cls(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues)
#
@@ -182,6 +187,7 @@ class PrometheusStatLogger(StatLoggerBase):
self.gauge_gpu_cache_usage = self._gauge_cls(
name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues)
self.counter_gpu_prefix_cache_queries = self._counter_cls(
@@ -242,6 +248,9 @@ class PrometheusStatLogger(StatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames).labels(*labelvalues)
# TODO: This metric might be incorrect when multiple API server
# processes are used with prometheus multiprocessing.
# See: https://github.com/vllm-project/vllm/pull/18053
self.histogram_iteration_tokens = \
self._histogram_cls(
name="vllm:iteration_tokens_total",
@@ -340,6 +349,9 @@ class PrometheusStatLogger(StatLoggerBase):
#
# LoRA metrics
#
# TODO: This metric might be incorrect when multiple API server
# processes are used with prometheus multiprocessing.
self.gauge_lora_info: Optional[prometheus_client.Gauge] = None
if vllm_config.lora_config is not None:
self.labelname_max_lora = "max_lora"
@@ -350,13 +362,16 @@ class PrometheusStatLogger(StatLoggerBase):
self._gauge_cls(
name="vllm:lora_requests_info",
documentation="Running stats on lora requests.",
multiprocess_mode="sum",
labelnames=[
self.labelname_max_lora,
self.labelname_waiting_lora_adapters,
self.labelname_running_lora_adapters,
])
],
)
def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):
metrics_info = config_obj.metrics_info()
metrics_info["engine"] = self.engine_index
@@ -372,25 +387,28 @@ class PrometheusStatLogger(StatLoggerBase):
info_gauge = self._gauge_cls(
name=name,
documentation=documentation,
labelnames=metrics_info.keys()).labels(**metrics_info)
multiprocess_mode="mostrecent",
labelnames=metrics_info.keys(),
).labels(**metrics_info)
info_gauge.set(1)
def record(self, scheduler_stats: SchedulerStats,
def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]):
"""Log to prometheus."""
self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)
if scheduler_stats is not None:
self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)
self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage)
self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage)
self.counter_gpu_prefix_cache_queries.inc(
scheduler_stats.prefix_cache_stats.queries)
self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits)
self.counter_gpu_prefix_cache_queries.inc(
scheduler_stats.prefix_cache_stats.queries)
self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits)
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_prom.observe(
scheduler_stats.spec_decoding_stats)
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_prom.observe(
scheduler_stats.spec_decoding_stats)
if iteration_stats is None:
return
@@ -445,13 +463,6 @@ class PrometheusStatLogger(StatLoggerBase):
self.gauge_lora_info.labels(**lora_info_labels)\
.set_to_current_time()
@staticmethod
def _unregister_vllm_metrics():
# Unregister any existing vLLM collectors (for CI/CD
for collector in list(prometheus_client.REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
prometheus_client.REGISTRY.unregister(collector)
def log_engine_initialized(self):
self.log_metrics_info("cache_config", self.vllm_config.cache_config)


@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
import os
import tempfile
from typing import Optional
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
from vllm.logger import init_logger
logger = init_logger(__name__)
# Global temporary directory for prometheus multiprocessing
_prometheus_multiproc_dir: Optional[tempfile.TemporaryDirectory] = None
def setup_multiprocess_prometheus():
"""Set up prometheus multiprocessing directory if not already configured.
"""
global _prometheus_multiproc_dir
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing
# Note: global TemporaryDirectory will be automatically
# cleaned up upon exit.
_prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name
logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s",
_prometheus_multiproc_dir.name)
else:
logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. "
"This directory must be wiped between vLLM runs or "
"you will find inaccurate metrics. Unset the variable "
"and vLLM will properly handle cleanup.")
def get_prometheus_registry():
"""Get the appropriate prometheus registry based on multiprocessing
configuration.
Returns:
Registry: A prometheus registry
"""
if os.getenv("PROMETHEUS_MULTIPROC_DIR") is not None:
logger.debug("Using multiprocess registry for prometheus metrics")
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
return registry
return REGISTRY
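Illustrative use of the two helpers above (metric content is whatever vLLM has registered): with PROMETHEUS_MULTIPROC_DIR set, each API-server process writes samples into that directory and a MultiProcessCollector merges them at scrape time.
from prometheus_client import generate_latest
setup_multiprocess_prometheus()       # ensures PROMETHEUS_MULTIPROC_DIR is set
registry = get_prometheus_registry()  # multiprocess-aware registry
print(generate_latest(registry).decode())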
def unregister_vllm_metrics():
"""Unregister any existing vLLM collectors from the prometheus registry.
This is useful for testing and CI/CD where metrics may be registered
multiple times across test runs.
Also, in case of multiprocess, we need to unregister the metrics from the
global registry.
"""
registry = REGISTRY
# Unregister any existing vLLM collectors
for collector in list(registry._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
registry.unregister(collector)
def shutdown_prometheus():
"""Shutdown prometheus metrics."""
try:
pid = os.getpid()
multiprocess.mark_process_dead(pid)
logger.debug("Marked Prometheus metrics for process %d as dead", pid)
except Exception as e:
logger.error("Error during metrics cleanup: %s", str(e))


@@ -26,12 +26,13 @@ class Request:
multi_modal_placeholders: Optional[list[PlaceholderRange]],
sampling_params: SamplingParams,
eos_token_id: Optional[int],
arrival_time: float,
client_index: int = 0,
lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None,
) -> None:
self.request_id = request_id
self.client_index = client_index
self.sampling_params = sampling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
@@ -90,13 +91,13 @@ class Request:
return cls(
request_id=request.request_id,
client_index=request.client_index,
prompt_token_ids=request.prompt_token_ids,
multi_modal_inputs=request.mm_inputs,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,
sampling_params=request.sampling_params,
eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,
lora_request=request.lora_request,
structured_output_request=StructuredOutputRequest(
sampling_params=request.sampling_params),


@@ -1,31 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
import os
import argparse
import multiprocessing
import time
import weakref
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection
from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union,
overload)
from multiprocessing.process import BaseProcess
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
Union, overload)
import msgspec
import torch
import zmq
from vllm.config import VllmConfig
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import get_mp_context, kill_process_tree
from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
get_tcp_uri, kill_process_tree)
from vllm.v1.executor.abstract import Executor
if TYPE_CHECKING:
from vllm.attention.layer import Attention
from vllm.v1.engine.coordinator import DPCoordinator
logger = init_logger(__name__)
T = TypeVar("T")
STARTUP_POLL_PERIOD_MS = 10000
class ConstantList(Generic[T], Sequence):
@@ -95,6 +105,78 @@ class ConstantList(Generic[T], Sequence):
return f"ConstantList({self._x})"
def get_engine_client_zmq_addr(local_only: bool,
host: str,
port: int = 0) -> str:
return get_open_zmq_ipc_path() if local_only else (get_tcp_uri(
host, port or get_open_port()))
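For reference, the helper above returns either an IPC path or a TCP URI (example values only):
get_engine_client_zmq_addr(local_only=True, host="10.1.2.3")
# -> "ipc:///tmp/<random>"  (host is ignored for local-only)
get_engine_client_zmq_addr(local_only=False, host="10.1.2.3", port=13345)
# -> "tcp://10.1.2.3:13345"
get_engine_client_zmq_addr(local_only=False, host="10.1.2.3")
# -> "tcp://10.1.2.3:<free port>"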
class APIServerProcessManager:
"""Manages a group of API server processes.
Handles creation, monitoring, and termination of API server worker
processes. Also monitors extra processes to check if they are healthy.
"""
def __init__(
self,
target_server_fn: Callable,
listen_address: str,
sock: Any,
args: argparse.Namespace,
num_servers: int,
input_addresses: list[str],
output_addresses: list[str],
stats_update_address: Optional[str] = None,
):
"""Initialize and start API server worker processes.
Args:
target_server_fn: Function to call for each API server process
listen_address: Address to listen for client connections
sock: Socket for client connections
args: Command line arguments
num_servers: Number of API server processes to start
input_addresses: Input addresses for each API server
output_addresses: Output addresses for each API server
stats_update_address: Optional stats update address
"""
self.listen_address = listen_address
self.sock = sock
self.args = args
# Start API servers
spawn_context = multiprocessing.get_context("spawn")
self.processes: list[BaseProcess] = []
for i, in_addr, out_addr in zip(range(num_servers), input_addresses,
output_addresses):
client_config = {
"input_address": in_addr,
"output_address": out_addr,
"client_index": i
}
if stats_update_address is not None:
client_config["stats_update_address"] = stats_update_address
proc = spawn_context.Process(target=target_server_fn,
name=f"ApiServer_{i}",
args=(listen_address, sock, args,
client_config))
self.processes.append(proc)
proc.start()
logger.info("Started %d API server processes", len(self.processes))
# Shutdown only the API server processes on garbage collection
# The extra processes are managed by their owners
self._finalizer = weakref.finalize(self, shutdown, self.processes)
def close(self) -> None:
self._finalizer()
class CoreEngineProcManager:
"""
Utility class to handle creation, readiness, and shutdown
@@ -109,7 +191,7 @@ class CoreEngineProcManager:
local_start_index: int,
vllm_config: VllmConfig,
on_head_node: bool,
input_address: str,
handshake_address: str,
executor_class: type[Executor],
log_stats: bool,
):
@@ -117,12 +199,12 @@ class CoreEngineProcManager:
common_kwargs = {
"vllm_config": vllm_config,
"on_head_node": on_head_node,
"input_address": input_address,
"handshake_address": handshake_address,
"executor_class": executor_class,
"log_stats": log_stats,
}
self.processes: list[Process] = []
self.processes: list[BaseProcess] = []
for index in range(local_engine_count):
local_index = local_start_index + index
global_index = start_index + index
@@ -135,8 +217,7 @@ class CoreEngineProcManager:
"local_dp_rank": local_index,
}))
self._finalizer = weakref.finalize(self, shutdown, self.processes,
input_address)
self._finalizer = weakref.finalize(self, shutdown, self.processes)
try:
for proc in self.processes:
proc.start()
@@ -164,9 +245,199 @@ class CoreEngineProcManager:
}
class CoreEngineState(Enum):
NEW = auto()
CONNECTED = auto()
READY = auto()
class CoreEngine:
"""One per data parallel rank."""
def __init__(self, index: int = 0, local: bool = True):
self.local = local
self.index = index
self.identity = index.to_bytes(2, "little")
self.state = CoreEngineState.NEW
@dataclass
class EngineZmqAddresses:
# ZMQ input socket addresses for each front-end client (requests)
inputs: list[str]
# ZMQ output socket addresses for each front-end client (responses)
outputs: list[str]
# ZMQ input socket address of DP coordinator if applicable
coordinator_input: Optional[str] = None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output: Optional[str] = None
@dataclass
class EngineHandshakeMetadata:
"""Metadata sent to each engine process during startup handshake,
including addresses of the front-end ZMQ queues that they should
connect to.
"""
addresses: EngineZmqAddresses
parallel_config: dict[str, Union[int, str]]
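A hypothetical instance of the handshake payload defined above, for two API-server clients plus a DP coordinator (all addresses made up):
addresses = EngineZmqAddresses(
    inputs=["tcp://10.1.2.3:5001", "tcp://10.1.2.3:5002"],
    outputs=["tcp://10.1.2.3:6001", "tcp://10.1.2.3:6002"],
    coordinator_input="tcp://10.1.2.3:7001",
    coordinator_output="tcp://10.1.2.3:7002",
)
metadata = EngineHandshakeMetadata(
    addresses=addresses,
    parallel_config={
        "data_parallel_master_ip": "10.1.2.3",
        "data_parallel_master_port": 13345,
        "data_parallel_size": 2,
    },
)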
def wait_for_engine_startup(
handshake_socket: zmq.Socket,
addresses: EngineZmqAddresses,
core_engines: list[CoreEngine],
parallel_config: ParallelConfig,
cache_config: CacheConfig,
proc_manager: Optional[CoreEngineProcManager],
coord_process: Optional[Process],
):
# Wait for engine core process(es) to send ready messages.
local_count = parallel_config.data_parallel_size_local
remote_count = len(core_engines) - local_count
# [local, remote] counts
conn_pending, start_pending = [local_count, remote_count], [0, 0]
poller = zmq.Poller()
poller.register(handshake_socket, zmq.POLLIN)
if proc_manager is not None:
for sentinel in proc_manager.sentinels():
poller.register(sentinel, zmq.POLLIN)
if coord_process is not None:
poller.register(coord_process.sentinel, zmq.POLLIN)
while any(conn_pending) or any(start_pending):
events = poller.poll(STARTUP_POLL_PERIOD_MS)
if not events:
if any(conn_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to connect.", *conn_pending)
if any(start_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to start.", *start_pending)
continue
if len(events) > 1 or events[0][0] != handshake_socket:
# One of the local core processes exited.
finished = proc_manager.finished_procs() if proc_manager else {}
if coord_process is not None and coord_process.exitcode is not None:
finished[coord_process.name] = coord_process.exitcode
raise RuntimeError("Engine core initialization failed. "
"See root cause above. "
f"Failed core proc(s): {finished}")
# Receive HELLO and READY messages from the input socket.
eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
eng_index = int.from_bytes(eng_identity, "little")
engine = next((e for e in core_engines if e.identity == eng_identity),
None)
if engine is None:
raise RuntimeError(f"Message from engine with unexpected data "
f"parallel rank: {eng_index}")
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local = msg["status"], msg["local"]
if local != engine.local:
raise RuntimeError(f"{status} message from "
f"{'local' if local else 'remote'} "
f"engine {eng_index}, expected it to be "
f"{'local' if engine.local else 'remote'}")
if status == "HELLO" and engine.state == CoreEngineState.NEW:
# Send init message with DP config info.
init_message = msgspec.msgpack.encode(
EngineHandshakeMetadata(
addresses=addresses,
parallel_config={
"data_parallel_master_ip":
parallel_config.data_parallel_master_ip,
"data_parallel_master_port":
parallel_config.data_parallel_master_port,
"data_parallel_size":
parallel_config.data_parallel_size,
}))
handshake_socket.send_multipart((eng_identity, init_message),
copy=False)
conn_pending[0 if local else 1] -= 1
start_pending[0 if local else 1] += 1
engine.state = CoreEngineState.CONNECTED
elif status == "READY" and (engine.state == CoreEngineState.CONNECTED):
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
num_gpu_blocks = cache_config.num_gpu_blocks or 0
num_gpu_blocks += msg["num_gpu_blocks"]
cache_config.num_gpu_blocks = num_gpu_blocks
start_pending[0 if local else 1] -= 1
engine.state = CoreEngineState.READY
else:
raise RuntimeError(f"Unexpected {status} message for "
f"{'local' if local else 'remote'} engine "
f"{eng_index} in {engine.state} state.")
logger.debug("%s from %s core engine process %s.", status,
"local" if local else "remote", eng_index)
def wait_for_completion_or_failure(
api_server_manager: APIServerProcessManager,
local_engine_manager: Optional[CoreEngineProcManager] = None,
coordinator: Optional["DPCoordinator"] = None) -> None:
"""Wait for all processes to complete or detect if any fail.
Raises an exception if any process exits with a non-zero status.
"""
try:
logger.info("Waiting for API servers to complete ...")
# Create a mapping of sentinels to their corresponding processes
# for efficient lookup
sentinel_to_proc: dict[Any, BaseProcess] = {
proc.sentinel: proc
for proc in api_server_manager.processes
}
if coordinator:
sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc
if local_engine_manager:
for proc in local_engine_manager.processes:
sentinel_to_proc[proc.sentinel] = proc
# Check if any process terminates
while sentinel_to_proc:
# Wait for any process to terminate
ready_sentinels: list[Any] = connection.wait(sentinel_to_proc)
# Process any terminated processes
for sentinel in ready_sentinels:
proc = sentinel_to_proc.pop(sentinel)
# Check if process exited with error
if proc.exitcode != 0:
raise RuntimeError(
f"Process {proc.name} (PID: {proc.pid}) "
f"died with exit code {proc.exitcode}")
except KeyboardInterrupt:
logger.info("Received KeyboardInterrupt, shutting down API servers...")
except Exception as e:
logger.exception("Exception occurred while running API servers: %s",
str(e))
raise
finally:
logger.info("Terminating remaining processes ...")
api_server_manager.close()
if coordinator:
coordinator.close()
if local_engine_manager:
local_engine_manager.close()
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the objedecoupct.
def shutdown(procs: list[Process], input_address: str):
# else the gc cannot collect the object.
def shutdown(procs: list[BaseProcess]):
# Shutdown the process.
for proc in procs:
if proc.is_alive():
@@ -185,12 +456,6 @@ def shutdown(procs: list[Process], input_address: str):
if proc.is_alive() and (pid := proc.pid) is not None:
kill_process_tree(pid)
# Remove zmq ipc socket files.
if input_address.startswith("ipc://"):
socket_file = input_address[len("ipc://"):]
if os and os.path.exists(socket_file):
os.remove(socket_file)
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],