diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 46785a8b3d500..bff2f69c17ba7 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -618,9 +618,11 @@ steps: - vllm/worker/model_runner.py - entrypoints/llm/test_collective_rpc.py - tests/v1/test_async_llm_dp.py + - tests/v1/entrypoints/openai/test_multi_api_servers.py - vllm/v1/engine/ commands: - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py + - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py - pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py diff --git a/tests/entrypoints/test_api_server_process_manager.py b/tests/entrypoints/test_api_server_process_manager.py new file mode 100644 index 0000000000000..0dd1fdd996948 --- /dev/null +++ b/tests/entrypoints/test_api_server_process_manager.py @@ -0,0 +1,268 @@ +# SPDX-License-Identifier: Apache-2.0 + +import multiprocessing +import socket +import threading +import time +from typing import Optional +from unittest.mock import patch + +import pytest + +from vllm.v1.utils import (APIServerProcessManager, + wait_for_completion_or_failure) + +# Global variables to control worker behavior +WORKER_RUNTIME_SECONDS = 0.5 + + +# Mock implementation of run_api_server_worker +def mock_run_api_server_worker(listen_address, sock, args, client_config=None): + """Mock run_api_server_worker that runs for a specific time.""" + print(f"Mock worker started with client_config: {client_config}") + time.sleep(WORKER_RUNTIME_SECONDS) + print("Mock worker completed successfully") + + +@pytest.fixture +def api_server_args(): + """Fixture to provide arguments for APIServerProcessManager.""" + sock = socket.socket() + return { + "target_server_fn": + mock_run_api_server_worker, + "listen_address": + "localhost:8000", + "sock": + sock, + "args": + "test_args", # Simple string to avoid pickling issues + "num_servers": + 3, + "input_addresses": [ + "tcp://127.0.0.1:5001", "tcp://127.0.0.1:5002", + "tcp://127.0.0.1:5003" + ], + "output_addresses": [ + "tcp://127.0.0.1:6001", "tcp://127.0.0.1:6002", + "tcp://127.0.0.1:6003" + ], + "stats_update_address": + "tcp://127.0.0.1:7000", + } + + +@pytest.mark.parametrize("with_stats_update", [True, False]) +def test_api_server_process_manager_init(api_server_args, with_stats_update): + """Test initializing the APIServerProcessManager.""" + # Set the worker runtime to ensure tests complete in reasonable time + global WORKER_RUNTIME_SECONDS + WORKER_RUNTIME_SECONDS = 0.5 + + # Copy the args to avoid mutating the + args = api_server_args.copy() + + if not with_stats_update: + args.pop("stats_update_address") + manager = APIServerProcessManager(**args) + + try: + # Verify the manager was initialized correctly + assert len(manager.processes) == 3 + + # Verify all processes are running + for proc in manager.processes: + assert proc.is_alive() + + print("Waiting for processes to run...") + time.sleep(WORKER_RUNTIME_SECONDS / 2) + + # They should still be alive at this point + for proc in manager.processes: + assert proc.is_alive() + + finally: + # Always clean up the processes + print("Cleaning up processes...") + manager.close() + + # Give processes time to terminate + time.sleep(0.2) + + # Verify all processes were terminated + for proc in manager.processes: + assert not proc.is_alive() + + +@patch("vllm.entrypoints.cli.serve.run_api_server_worker", + mock_run_api_server_worker) +def 
test_wait_for_completion_or_failure(api_server_args): + """Test that wait_for_completion_or_failure works with failures.""" + global WORKER_RUNTIME_SECONDS + WORKER_RUNTIME_SECONDS = 1.0 + + # Create the manager + manager = APIServerProcessManager(**api_server_args) + + try: + assert len(manager.processes) == 3 + + # Create a result capture for the thread + result: dict[str, Optional[Exception]] = {"exception": None} + + def run_with_exception_capture(): + try: + wait_for_completion_or_failure(api_server_manager=manager) + except Exception as e: + result["exception"] = e + + # Start a thread to run wait_for_completion_or_failure + wait_thread = threading.Thread(target=run_with_exception_capture, + daemon=True) + wait_thread.start() + + # Let all processes run for a short time + time.sleep(0.2) + + # All processes should still be running + assert all(proc.is_alive() for proc in manager.processes) + + # Now simulate a process failure + print("Simulating process failure...") + manager.processes[0].terminate() + + # Wait for the wait_for_completion_or_failure + # to detect and handle the failure + # This should trigger it to terminate all other processes + wait_thread.join(timeout=1.0) + + # The wait thread should have exited + assert not wait_thread.is_alive() + + # Verify that an exception was raised with appropriate error message + assert result["exception"] is not None + assert "died with exit code" in str(result["exception"]) + + # All processes should now be terminated + for i, proc in enumerate(manager.processes): + assert not proc.is_alive(), f"Process {i} should not be alive" + + finally: + manager.close() + time.sleep(0.2) + + +@pytest.mark.timeout(30) +def test_normal_completion(api_server_args): + """Test that wait_for_completion_or_failure works in normal completion.""" + global WORKER_RUNTIME_SECONDS + WORKER_RUNTIME_SECONDS = 0.1 + + # Create the manager + manager = APIServerProcessManager(**api_server_args) + + try: + # Give processes time to terminate + # wait for processes to complete + remaining_processes = manager.processes.copy() + while remaining_processes: + for proc in remaining_processes: + if not proc.is_alive(): + remaining_processes.remove(proc) + time.sleep(0.1) + + # Verify all processes have terminated + for i, proc in enumerate(manager.processes): + assert not proc.is_alive( + ), f"Process {i} still alive after terminate()" + + # Now call wait_for_completion_or_failure + # since all processes have already + # terminated, it should return immediately + # with no error + wait_for_completion_or_failure(api_server_manager=manager) + + finally: + # Clean up just in case + manager.close() + time.sleep(0.2) + + +@pytest.mark.timeout(30) +def test_external_process_monitoring(api_server_args): + """Test that wait_for_completion_or_failure handles additional processes.""" + global WORKER_RUNTIME_SECONDS + WORKER_RUNTIME_SECONDS = 100 + + # Create and start the external process + # (simulates local_engine_manager or coordinator) + spawn_context = multiprocessing.get_context("spawn") + external_proc = spawn_context.Process(target=mock_run_api_server_worker, + name="MockExternalProcess") + external_proc.start() + + # Create the class to simulate a coordinator + class MockCoordinator: + + def __init__(self, proc): + self.proc = proc + + def close(self): + if self.proc.is_alive(): + self.proc.terminate() + self.proc.join(timeout=0.5) + + # Create a mock coordinator with the external process + mock_coordinator = MockCoordinator(external_proc) + + # Create the API server manager 
+ manager = APIServerProcessManager(**api_server_args) + + try: + # Verify manager initialization + assert len(manager.processes) == 3 + + # Create a result capture for the thread + result: dict[str, Optional[Exception]] = {"exception": None} + + def run_with_exception_capture(): + try: + wait_for_completion_or_failure(api_server_manager=manager, + coordinator=mock_coordinator) + except Exception as e: + result["exception"] = e + + # Start a thread to run wait_for_completion_or_failure + wait_thread = threading.Thread(target=run_with_exception_capture, + daemon=True) + wait_thread.start() + + # Terminate the external process to trigger a failure + time.sleep(0.2) + external_proc.terminate() + + # Wait for the thread to detect the failure + wait_thread.join(timeout=1.0) + + # The wait thread should have completed + assert not wait_thread.is_alive( + ), "wait_for_completion_or_failure thread still running" + + # Verify that an exception was raised with appropriate error message + assert result["exception"] is not None, "No exception was raised" + error_message = str(result["exception"]) + assert "died with exit code" in error_message, \ + f"Unexpected error message: {error_message}" + assert "MockExternalProcess" in error_message, \ + f"Error doesn't mention external process: {error_message}" + + # Verify that all API server processes were terminated as a result + for i, proc in enumerate(manager.processes): + assert not proc.is_alive( + ), f"API server process {i} was not terminated" + + finally: + # Clean up + manager.close() + mock_coordinator.close() + time.sleep(0.2) diff --git a/tests/utils.py b/tests/utils.py index bf38d7843853d..d21b18470b1bb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -28,7 +28,7 @@ from tests.models.utils import TextTextLogprobs from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.entrypoints.cli.serve import ServeSubcommand from vllm.model_executor.model_loader import get_model_loader from vllm.platforms import current_platform from vllm.transformers_utils.tokenizer import get_tokenizer @@ -99,7 +99,8 @@ class RemoteOpenAIServer: parser = FlexibleArgumentParser( description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) + subparsers = parser.add_subparsers(required=False, dest="subparser") + parser = ServeSubcommand().subparser_init(subparsers) args = parser.parse_args(["--model", model, *vllm_serve_args]) self.host = str(args.host or 'localhost') self.port = int(args.port) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 43a27da2dbe43..d3d62cf09232d 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -45,7 +45,6 @@ def make_request(request_id, multi_modal_placeholders=mm_positions, sampling_params=SamplingParams(max_tokens=17), eos_token_id=100, - arrival_time=0, lora_request=None, cache_salt=cache_salt, ) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 3da27786b1f2f..ba3c0b3cf3169 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -38,7 +38,6 @@ def make_request(request_id, sampling_params=SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs), eos_token_id=100, - arrival_time=0, lora_request=None, cache_salt=cache_salt, ) diff --git a/tests/v1/core/test_scheduler.py 
b/tests/v1/core/test_scheduler.py index f40d477a00363..f38454b1b2889 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -138,7 +138,6 @@ def create_requests(num_requests: int, multi_modal_placeholders=mm_position, multi_modal_hashes=None, eos_token_id=EOS_TOKEN_ID, - arrival_time=0, ) requests.append(request) return requests @@ -744,7 +743,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i]) # No draft or accepted tokens counted yet - assert engine_core_outputs.scheduler_stats.spec_decoding_stats is None + assert not engine_core_outputs or ( + engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None) # Schedule the speculated tokens for validation output = scheduler.schedule() @@ -772,7 +772,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): engine_core_outputs = scheduler.update_from_output(output, model_runner_output) - scheduler_stats = engine_core_outputs.scheduler_stats + scheduler_stats = engine_core_outputs[0].scheduler_stats \ + if engine_core_outputs else None if expected[0] == 0: assert scheduler_stats.spec_decoding_stats is None else: @@ -843,7 +844,7 @@ def _step_until_done( # We should be in the decode phase now. assert num_scheduled_tokens == 1 assert len(output.kv_connector_metadata.requests) == 0 - ecos = scheduler.update_from_output(output, model_runner_output) + ecos = scheduler.update_from_output(output, model_runner_output)[0] all_done = True for eco in ecos.outputs: if eco.finish_reason is None: diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index ae1d8a762a8e1..e78c7480a837a 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -88,7 +88,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch): assert len(engine_core.scheduler.running) == 4 # Loop through until they are all done. - while len(engine_core.step()[0].outputs) > 0: + while (outs := engine_core.step()[0].get(0)) and outs.outputs: pass assert len(engine_core.scheduler.waiting) == 0 @@ -163,11 +163,11 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch): req0.request_id = req1.request_id = "test" engine_core.add_request(req0) - while len(engine_core.step()[0].outputs) > 0: + while (outs := engine_core.step()[0].get(0)) and outs.outputs: pass engine_core.add_request(req1) - while len(engine_core.step()[0].outputs) > 0: + while (outs := engine_core.step()[0].get(0)) and outs.outputs: pass assert len(engine_core.scheduler.waiting) == 0 @@ -207,7 +207,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch): assert len(engine_core.scheduler.waiting) == 1 assert len(engine_core.scheduler.running) == 0 # Loop through until they are all done. - while len(engine_core.step()[0].outputs) > 0: + while (outs := engine_core.step()[0].get(0)) and outs.outputs: pass assert len(engine_core.scheduler.waiting) == 0 assert len(engine_core.scheduler.running) == 0 @@ -327,7 +327,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): assert scheduler_output.num_scheduled_tokens[1] == 4 # Batch queue is full. Finish Batch 2. Get first token of req0. 
- output = engine_core.step_with_batch_queue()[0] + output = engine_core.step_with_batch_queue()[0].get(0) assert output is not None assert len(output.outputs) == 1 assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13 @@ -339,7 +339,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): assert scheduler_output.num_scheduled_tokens[0] == 1 # Batch queue is full. Finish Batch 3. Get first token of req1. - output = engine_core.step_with_batch_queue()[0] + output = engine_core.step_with_batch_queue()[0].get(0) assert output is not None assert len(output.outputs) == 1 assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13 @@ -362,7 +362,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): if step % 2 == 0: # Even steps consumes an output. assert output is not None - assert len(output.outputs) == 1 + assert len(output[0].outputs) == 1 if req_id in engine_core.scheduler.requests: assert engine_core.scheduler.requests[ req_id].num_tokens == expected_num_tokens[req_id] diff --git a/tests/v1/entrypoints/openai/test_multi_api_servers.py b/tests/v1/entrypoints/openai/test_multi_api_servers.py new file mode 100644 index 0000000000000..7b4583bc3bf37 --- /dev/null +++ b/tests/v1/entrypoints/openai/test_multi_api_servers.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import os + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio + +from tests.utils import RemoteOpenAIServer + +MODEL_NAME = "ibm-research/PowerMoE-3b" + +DP_SIZE = os.getenv("DP_SIZE", "1") + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager", + "--api-server-count", + "4", + "--data_parallel_size", + DP_SIZE, + ] + + +@pytest.fixture(scope="module") +def server(default_server_args): + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_single_completion(client: openai.AsyncOpenAI, + model_name: str) -> None: + + async def make_request(): + completion = await client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=10, + temperature=1.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + # The exact number of tokens can vary slightly with temperature=1.0, + # so we check for a reasonable minimum length. + assert len(choice.text) >= 1 + # Finish reason might not always be 'length' if the model finishes early + # or due to other reasons, especially with high temperature. + # So, we'll accept 'length' or 'stop'. + assert choice.finish_reason in ("length", "stop") + + # Token counts can also vary, so we check they are positive. 
+ assert completion.usage.completion_tokens > 0 + assert completion.usage.prompt_tokens > 0 + assert completion.usage.total_tokens > 0 + return completion + + # Test single request + result = await make_request() + assert result is not None + + await asyncio.sleep(0.5) + + # Send two bursts of requests + num_requests = 100 + tasks = [make_request() for _ in range(num_requests)] + results = await asyncio.gather(*tasks) + assert len(results) == num_requests + assert all(completion is not None for completion in results) + + await asyncio.sleep(0.5) + + tasks = [make_request() for _ in range(num_requests)] + results = await asyncio.gather(*tasks) + assert len(results) == num_requests + assert all(completion is not None for completion in results) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_completion_streaming(client: openai.AsyncOpenAI, + model_name: str) -> None: + prompt = "What is an LLM?" + + async def make_streaming_request(): + # Perform a non-streaming request to get the expected full output + single_completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + ) + single_output = single_completion.choices[0].text + + # Perform the streaming request + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True) + chunks: list[str] = [] + finish_reason_count = 0 + last_chunk = None + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + last_chunk = chunk # Keep track of the last chunk + + # finish reason should only return in the last block for OpenAI API + assert finish_reason_count == 1, ( + "Finish reason should appear exactly once.") + assert last_chunk is not None, ( + "Stream should have yielded at least one chunk.") + assert last_chunk.choices[ + 0].finish_reason == "length", "Finish reason should be 'length'." + # Check that the combined text matches the non-streamed version. + assert "".join( + chunks + ) == single_output, "Streamed output should match non-streamed output." + return True # Indicate success for this request + + # Test single request + result = await make_streaming_request() + assert result is not None + + await asyncio.sleep(0.5) + + # Send two bursts of requests + num_requests = 100 + tasks = [make_streaming_request() for _ in range(num_requests)] + results = await asyncio.gather(*tasks) + + assert len( + results + ) == num_requests, f"Expected {num_requests} results, got {len(results)}" + assert all(results), "Not all streaming requests completed successfully." + + await asyncio.sleep(0.5) + + tasks = [make_streaming_request() for _ in range(num_requests)] + results = await asyncio.gather(*tasks) + + assert len( + results + ) == num_requests, f"Expected {num_requests} results, got {len(results)}" + assert all(results), "Not all streaming requests completed successfully." diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index 77098140343a0..dc963251c962b 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -43,7 +43,7 @@ def test_basic_lifecycle(): # Ensure the request is finished after 1 tokens. 
assert request.is_finished() assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED - output = engine_core_outputs.outputs[0] + output = engine_core_outputs[0].outputs[0] assert output.finish_reason == FinishReason.LENGTH assert output.kv_transfer_params is not None @@ -165,7 +165,7 @@ def test_prefix_cache_lifecycle(): scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output(reqs=[request_remote]) eco = scheduler.update_from_output(scheduler_output, model_runner_output) - kv_transfer_params = eco.outputs[0].kv_transfer_params + kv_transfer_params = eco[0].outputs[0].kv_transfer_params # Ensure we send all block ids, even if there is a cache hit. assert (len( diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index 6fcff0d620452..86eacb693869d 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -61,7 +61,7 @@ def test_basic_lifecycle(): # (1c): update_from_output() engine_core_outputs = scheduler.update_from_output(scheduler_output, model_runner_output) - assert len(engine_core_outputs.outputs) == 0 + assert not engine_core_outputs or not engine_core_outputs[0].outputs # STEP (2): # (2a): schedule(): nothing happens! @@ -112,7 +112,7 @@ def test_basic_lifecycle(): model_runner_output) scheduler.schedule() - outputs = engine_core_outputs.outputs + outputs = engine_core_outputs[0].outputs assert len(outputs) == 1 output = outputs[0] assert output.finish_reason == FinishReason.STOP @@ -335,7 +335,7 @@ def test_full_block_prompt(): model_runner_output) scheduler.schedule() - outputs = engine_core_outputs.outputs + outputs = engine_core_outputs[0].outputs assert len(outputs) == 1 output = outputs[0] assert output.finish_reason == FinishReason.STOP diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 53e2d6fda1aea..3c3190b325636 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -153,7 +153,6 @@ def create_request( multi_modal_placeholders=None, multi_modal_hashes=None, eos_token_id=EOS_TOKEN_ID, - arrival_time=0, ) req.kv_transfer_params = kv_transfer_params return req diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 957fec290bf26..e65c97073218b 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -1,24 +1,35 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import os import signal +import sys import uvloop +import zmq import vllm.envs as envs from vllm import AsyncEngineArgs from vllm.entrypoints.cli.types import CLISubcommand -from vllm.entrypoints.openai.api_server import run_server +from vllm.entrypoints.openai.api_server import (run_server, run_server_worker, + setup_server) from vllm.entrypoints.openai.cli_args import (make_arg_parser, validate_parsed_serve_args) from vllm.entrypoints.utils import (VLLM_SERVE_PARSER_EPILOG, show_filtered_argument_or_group_from_help) +from vllm.executor.multiproc_worker_utils import _add_prefix from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser, get_tcp_uri +from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx +from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.core_client import CoreEngineProcManager from vllm.v1.executor.abstract 
import Executor +from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus +from vllm.v1.utils import (APIServerProcessManager, CoreEngine, + EngineZmqAddresses, get_engine_client_zmq_addr, + wait_for_completion_or_failure, + wait_for_engine_startup) logger = init_logger(__name__) @@ -36,9 +47,12 @@ class ServeSubcommand(CLISubcommand): if hasattr(args, 'model_tag') and args.model_tag is not None: args.model = args.model_tag - if args.headless: + if args.headless or args.api_server_count < 1: run_headless(args) + elif args.api_server_count > 1: + run_multi_api_server(args) else: + # Single API server (this process). uvloop.run(run_server(args)) def validate(self, args: argparse.Namespace) -> None: @@ -69,6 +83,11 @@ class ServeSubcommand(CLISubcommand): type=int, default=0, help='Starting data parallel rank for secondary nodes.') + serve_parser.add_argument('--api-server-count', + '-asc', + type=int, + default=1, + help='How many API server processes to run.') serve_parser.add_argument( "--config", type=str, @@ -91,23 +110,26 @@ def cmd_init() -> list[CLISubcommand]: def run_headless(args: argparse.Namespace): + if args.api_server_count > 1: + raise ValueError("api_server_count can't be set in headless mode") + # Create the EngineConfig. engine_args = AsyncEngineArgs.from_cli_args(args) usage_context = UsageContext.OPENAI_API_SERVER vllm_config = engine_args.create_engine_config(usage_context=usage_context) if not envs.VLLM_USE_V1: - raise RuntimeError("Headless mode is only supported for V1") + raise ValueError("Headless mode is only supported for V1") parallel_config = vllm_config.parallel_config local_engine_count = parallel_config.data_parallel_size_local host = parallel_config.data_parallel_master_ip port = engine_args.data_parallel_rpc_port # add to config too - input_address = get_tcp_uri(host, port) + handshake_address = get_tcp_uri(host, port) if local_engine_count <= 0: - raise RuntimeError("data_parallel_size_local must be > 0 in " - "headless mode") + raise ValueError("data_parallel_size_local must be > 0 in " + "headless mode") # Catch SIGTERM and SIGINT to allow graceful shutdown. def signal_handler(signum, frame): @@ -119,7 +141,7 @@ def run_headless(args: argparse.Namespace): logger.info( "Launching %d data parallel engine(s) in headless mode, " - "with head node address %s.", local_engine_count, input_address) + "with head node address %s.", local_engine_count, handshake_address) # Create the engines. 
engine_manager = CoreEngineProcManager( @@ -129,7 +151,7 @@ def run_headless(args: argparse.Namespace): local_start_index=0, vllm_config=vllm_config, on_head_node=False, - input_address=input_address, + handshake_address=handshake_address, executor_class=Executor.get_class(vllm_config), log_stats=not engine_args.disable_log_stats, ) @@ -139,3 +161,142 @@ def run_headless(args: argparse.Namespace): finally: logger.info("Shutting down.") engine_manager.close() + + +def run_multi_api_server(args: argparse.Namespace): + + assert not args.headless + num_api_servers = args.api_server_count + assert num_api_servers > 0 + + if num_api_servers > 1: + setup_multiprocess_prometheus() + + listen_address, sock = setup_server(args) + + engine_args = AsyncEngineArgs.from_cli_args(args) + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = engine_args.create_engine_config(usage_context=usage_context) + model_config = vllm_config.model_config + + if num_api_servers > 1: + if not envs.VLLM_USE_V1: + raise ValueError("api_server_count > 1 is only supported for V1") + + if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used " + "with api_server_count > 1") + + if model_config.is_multimodal_model and not ( + model_config.disable_mm_preprocessor_cache): + logger.warning( + "Multi-modal preprocessor cache will be disabled for" + " api_server_count > 1") + model_config.disable_mm_preprocessor_cache = True + + parallel_config = vllm_config.parallel_config + + assert parallel_config.data_parallel_rank == 0 + + dp_size = parallel_config.data_parallel_size + local_engine_count = parallel_config.data_parallel_size_local + host = parallel_config.data_parallel_master_ip + local_only = local_engine_count == dp_size + + # Set up input and output addresses. + input_addresses = [ + get_engine_client_zmq_addr(local_only, host) + for _ in range(num_api_servers) + ] + output_addresses = [ + get_engine_client_zmq_addr(local_only, host) + for _ in range(num_api_servers) + ] + + addresses = EngineZmqAddresses( + inputs=input_addresses, + outputs=output_addresses, + ) + + # Set up coordinator for dp > 1. + coordinator = None + stats_update_address = None + if dp_size > 1: + coordinator = DPCoordinator(parallel_config) + addresses.coordinator_input, addresses.coordinator_output = ( + coordinator.get_engine_socket_addresses()) + stats_update_address = coordinator.get_stats_publish_address() + logger.info("Started DP Coordinator process (PID: %d)", + coordinator.proc.pid) + + handshake_address = get_engine_client_zmq_addr( + local_only, host, parallel_config.data_parallel_rpc_port) + + with zmq_socket_ctx(handshake_address, zmq.ROUTER, + bind=True) as handshake_socket: + + # Start local engines.
+ if not local_engine_count: + local_engine_manager = None + else: + local_engine_manager = CoreEngineProcManager( + EngineCoreProc.run_engine_core, + vllm_config=vllm_config, + executor_class=Executor.get_class(vllm_config), + log_stats=not engine_args.disable_log_stats, + handshake_address=handshake_address, + on_head_node=True, + local_engine_count=local_engine_count, + start_index=0, + local_start_index=0) + + # Start API servers using the manager + api_server_manager = APIServerProcessManager( + target_server_fn=run_api_server_worker_proc, + listen_address=listen_address, + sock=sock, + args=args, + num_servers=num_api_servers, + input_addresses=input_addresses, + output_addresses=output_addresses, + stats_update_address=stats_update_address) + + # Wait for engine handshakes to complete. + core_engines = [ + CoreEngine(index=i, local=(i < local_engine_count)) + for i in range(dp_size) + ] + wait_for_engine_startup( + handshake_socket, + addresses, + core_engines, + parallel_config, + vllm_config.cache_config, + local_engine_manager, + coordinator.proc if coordinator else None, + ) + + # Wait for API servers + wait_for_completion_or_failure( + api_server_manager=api_server_manager, + local_engine_manager=local_engine_manager, + coordinator=coordinator) + + +def run_api_server_worker_proc(listen_address, + sock, + args, + client_config=None, + **uvicorn_kwargs) -> None: + """Entrypoint for individual API server worker processes.""" + + # Add process-specific prefix to stdout and stderr. + from multiprocessing import current_process + process_name = current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + uvloop.run( + run_server_worker(listen_address, sock, args, client_config, + **uvicorn_kwargs)) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b991cb3a444bc..1e7f88a6a2796 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -17,7 +17,7 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus from json import JSONDecodeError -from typing import Annotated, Optional +from typing import Annotated, Any, Optional import prometheus_client import regex as re @@ -26,6 +26,8 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse +from prometheus_client import make_asgi_app +from prometheus_fastapi_instrumentator import Instrumentator from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import State from starlette.routing import Mount @@ -97,6 +99,7 @@ from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path, is_valid_ipv6_address, set_ulimit) +from vllm.v1.metrics.prometheus import get_prometheus_registry from vllm.version import __version__ as VLLM_VERSION TIMEOUT_KEEP_ALIVE = 5 # seconds @@ -142,14 +145,17 @@ async def lifespan(app: FastAPI): @asynccontextmanager async def build_async_engine_client( - args: Namespace) -> AsyncIterator[EngineClient]: + args: Namespace, + client_config: Optional[dict[str, Any]] = None, +) -> AsyncIterator[EngineClient]: # Context manager to handle engine_client lifecycle # Ensures 
everything is shutdown and cleaned up on error/exit engine_args = AsyncEngineArgs.from_cli_args(args) async with build_async_engine_client_from_engine_args( - engine_args, args.disable_frontend_multiprocessing) as engine: + engine_args, args.disable_frontend_multiprocessing, + client_config) as engine: yield engine @@ -157,6 +163,7 @@ async def build_async_engine_client( async def build_async_engine_client_from_engine_args( engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, + client_config: Optional[dict[str, Any]] = None, ) -> AsyncIterator[EngineClient]: """ Create EngineClient, either: @@ -179,12 +186,16 @@ async def build_async_engine_client_from_engine_args( from vllm.v1.engine.async_llm import AsyncLLM async_llm: Optional[AsyncLLM] = None + client_index = client_config.pop( + "client_index") if client_config else 0 try: async_llm = AsyncLLM.from_vllm_config( vllm_config=vllm_config, usage_context=usage_context, disable_log_requests=engine_args.disable_log_requests, - disable_log_stats=engine_args.disable_log_stats) + disable_log_stats=engine_args.disable_log_stats, + client_addresses=client_config, + client_index=client_index) # Don't keep the dummy data in memory await async_llm.reset_mm_cache() @@ -318,22 +329,9 @@ class PrometheusResponse(Response): def mount_metrics(app: FastAPI): - # Lazy import for prometheus multiprocessing. - # We need to set PROMETHEUS_MULTIPROC_DIR environment variable - # before prometheus_client is imported. - # See https://prometheus.github.io/client_python/multiprocess/ - from prometheus_client import (REGISTRY, CollectorRegistry, make_asgi_app, - multiprocess) - from prometheus_fastapi_instrumentator import Instrumentator + """Mount prometheus metrics to a FastAPI app.""" - registry = REGISTRY - - prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) - if prometheus_multiproc_dir_path is not None: - logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", - prometheus_multiproc_dir_path) - registry = CollectorRegistry() - multiprocess.MultiProcessCollector(registry) + registry = get_prometheus_registry() # `response_class=PrometheusResponse` is needed to return an HTTP response # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8" @@ -1256,16 +1254,10 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket: return sock -async def run_server(args, **uvicorn_kwargs) -> None: - logger.info("vLLM API server version %s", VLLM_VERSION) - log_non_default_args(args) - - if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: - ToolParserManager.import_tool_parser(args.tool_parser_plugin) - +def validate_api_server_args(args): valid_tool_parses = ToolParserManager.tool_parsers.keys() if args.enable_auto_tool_choice \ - and args.tool_call_parser not in valid_tool_parses: + and args.tool_call_parser not in valid_tool_parses: raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " f"(chose from {{ {','.join(valid_tool_parses)} }})") @@ -1276,6 +1268,19 @@ async def run_server(args, **uvicorn_kwargs) -> None: f"invalid reasoning parser: {args.reasoning_parser} " f"(chose from {{ {','.join(valid_reasoning_parses)} }})") + +def setup_server(args): + """Validate API server args, set up signal handler, create socket + ready to serve.""" + + logger.info("vLLM API server version %s", VLLM_VERSION) + log_non_default_args(args) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + 
validate_api_server_args(args) + # workaround to make sure that we bind the port before the engine is set up. # This avoids race conditions with ray. # see https://github.com/vllm-project/vllm/issues/8204 @@ -1292,22 +1297,41 @@ async def run_server(args, **uvicorn_kwargs) -> None: signal.signal(signal.SIGTERM, signal_handler) - async with build_async_engine_client(args) as engine_client: + addr, port = sock_addr + is_ssl = args.ssl_keyfile and args.ssl_certfile + host_part = f"[{addr}]" if is_valid_ipv6_address( + addr) else addr or "0.0.0.0" + listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" + + return listen_address, sock + + +async def run_server(args, **uvicorn_kwargs) -> None: + """Run a single-worker API server.""" + listen_address, sock = setup_server(args) + await run_server_worker(listen_address, sock, args, **uvicorn_kwargs) + + +async def run_server_worker(listen_address, + sock, + args, + client_config=None, + **uvicorn_kwargs) -> None: + """Run a single API server worker.""" + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + server_index = client_config.get("client_index", 0) if client_config else 0 + + async with build_async_engine_client(args, client_config) as engine_client: app = build_app(args) vllm_config = await engine_client.get_vllm_config() await init_app_state(engine_client, vllm_config, app.state, args) - def _listen_addr(a: str) -> str: - if is_valid_ipv6_address(a): - return '[' + a + ']' - return a or "0.0.0.0" - - is_ssl = args.ssl_keyfile and args.ssl_certfile - logger.info("Starting vLLM API server on http%s://%s:%d", - "s" if is_ssl else "", _listen_addr(sock_addr[0]), - sock_addr[1]) - + logger.info("Starting vLLM API server %d on %s", server_index, + listen_address) shutdown_task = await serve_http( app, sock=sock, diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index afc8a8dc3b260..f1ae030975074 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -229,6 +229,11 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): self.add_adapter(lora) def add_adapter(self, lora_request: LoRARequest) -> bool: + # Note that this method is not thread-safe. It may be invoked multiple + # times for the same adapter when using multiple API servers. + # This is ok because it's currently only called from + # the single-threaded core engine loop. + if lora_request.lora_int_id not in self.list_adapters(): # Load the new adapter first to ensure it is actually valid, before # evicting any existing adapters. 
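The api_server.py changes above split the old run_server() into setup_server() (validate the CLI args, bind the listen socket early to avoid port races, install the SIGTERM handler) and run_server_worker(), which serves on an already-bound socket. A minimal sketch of how the single-server path composes the two pieces, assuming only the functions shown in this diff and no extra uvicorn kwargs:

    import uvloop
    from vllm.entrypoints.openai.api_server import run_server_worker, setup_server

    async def run_single_api_server(args):
        # Bind the port and register the signal handler before the engine starts,
        # then serve requests on the already-bound socket in this process.
        listen_address, sock = setup_server(args)
        await run_server_worker(listen_address, sock, args)

    # uvloop.run(run_single_api_server(args))  # mirrors run_server() above

With --api-server-count N (N > 1), serve.py instead spawns N run_api_server_worker_proc processes, each of which calls run_server_worker with a per-process client_config so engine outputs are routed back to the front-end that issued the request.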
diff --git a/vllm/utils.py b/vllm/utils.py index 25e34446a1cb2..65d3579d5e650 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2420,6 +2420,7 @@ def make_zmq_socket( socket_type: Any, bind: Optional[bool] = None, identity: Optional[bytes] = None, + linger: Optional[int] = None, ) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] """Make a ZMQ socket with the proper bind/connect semantics.""" @@ -2439,7 +2440,7 @@ def make_zmq_socket( buf_size = -1 # Use system default buffer size if bind is None: - bind = socket_type != zmq.PUSH + bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB) if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER): socket.setsockopt(zmq.RCVHWM, 0) @@ -2452,6 +2453,9 @@ def make_zmq_socket( if identity is not None: socket.setsockopt(zmq.IDENTITY, identity) + if linger is not None: + socket.setsockopt(zmq.LINGER, linger) + # Determine if the path is a TCP socket with an IPv6 address. # Enable IPv6 on the zmq socket if so. scheme, host, _ = split_zmq_path(path) diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index c17f80b6ae78a..055ce446051ef 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -45,7 +45,7 @@ class SchedulerInterface(ABC): self, scheduler_output: "SchedulerOutput", model_runner_output: "ModelRunnerOutput", - ) -> "EngineCoreOutputs": + ) -> dict[int, "EngineCoreOutputs"]: """Update the scheduler state based on the model runner output. This method is called after the model runner has processed the scheduled @@ -55,7 +55,8 @@ class SchedulerInterface(ABC): for each request. Returns: - A EngineCoreOutputs object containing the outputs for each request. + A dict of client index to EngineCoreOutputs object containing the + outputs for each request originating from that client. """ raise NotImplementedError @@ -126,6 +127,11 @@ class SchedulerInterface(ABC): """ raise NotImplementedError + @abstractmethod + def get_request_counts(self) -> tuple[int, int]: + """Returns (num_running_reqs, num_waiting_reqs).""" + raise NotImplementedError + @abstractmethod def make_stats(self) -> Optional["SchedulerStats"]: """Make a SchedulerStats object for logging. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4c6b3eea0cb75..ce16a1ed5a096 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -58,7 +58,8 @@ class Scheduler(SchedulerInterface): # request ids should be included in the EngineCoreOutputs returned # by update_from_outputs(). This is currently used in the multi-engine # case to track request lifetimes efficiently. - self.include_finished_set = include_finished_set + self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( + defaultdict(set) if include_finished_set else None) # Scheduling constraints. 
self.max_num_running_reqs = self.scheduler_config.max_num_seqs @@ -693,7 +694,7 @@ class Scheduler(SchedulerInterface): self, scheduler_output: SchedulerOutput, model_runner_output: ModelRunnerOutput, - ) -> EngineCoreOutputs: + ) -> dict[int, EngineCoreOutputs]: sampled_token_ids = model_runner_output.sampled_token_ids spec_token_ids = model_runner_output.spec_token_ids logprobs = model_runner_output.logprobs @@ -701,7 +702,7 @@ class Scheduler(SchedulerInterface): num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: list[Request] = [] - outputs: list[EngineCoreOutput] = [] + outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below @@ -797,7 +798,7 @@ class Scheduler(SchedulerInterface): if new_token_ids or kv_transfer_params: # Add EngineCoreOutput for this Request. - outputs.append( + outputs[request.client_index].append( EngineCoreOutput( request_id=req_id, new_token_ids=new_token_ids, @@ -828,17 +829,38 @@ class Scheduler(SchedulerInterface): self._cached_reqs_data[req_data.req_id].append(req_data) self.running = new_running - engine_core_outputs = EngineCoreOutputs( - outputs=outputs, - scheduler_stats=self.make_stats(spec_decoding_stats), - ) - if self.include_finished_set: - #TODO currently sending duplicates here, improve this - engine_core_outputs.finished_requests = ( - scheduler_output.finished_req_ids | self.finished_req_ids) + + # Create EngineCoreOutputs for all clients that have requests with + # outputs in this step. + engine_core_outputs = { + client_index: EngineCoreOutputs(outputs=outs) + for client_index, outs in outputs.items() + } + + finished_req_ids = self.finished_req_ids_dict + if finished_req_ids is not None: + # Include ids of requests that finished since last outputs + # were sent. + for client_index, finished_set in finished_req_ids.items(): + # Set finished request set in EngineCoreOutputs for this client. + if (eco := engine_core_outputs.get(client_index)) is not None: + eco.finished_requests = finished_set + else: + engine_core_outputs[client_index] = EngineCoreOutputs( + finished_requests=finished_set) + finished_req_ids.clear() + + if engine_core_outputs: + # Return stats to only one of the front-ends. 
+ next(iter(engine_core_outputs.values())).scheduler_stats = ( + self.make_stats(spec_decoding_stats)) return engine_core_outputs + def get_request_counts(self) -> tuple[int, int]: + """Returns (num_running_reqs, num_waiting_reqs).""" + return len(self.running), len(self.waiting) + def add_request(self, request: Request) -> None: self.waiting.append(request) self.requests[request.request_id] = request @@ -880,8 +902,11 @@ class Scheduler(SchedulerInterface): delay_free_blocks, kv_xfer_params = self._connector_finished(request) self.encoder_cache_manager.free(request) - self._cached_reqs_data.pop(request.request_id, None) - self.finished_req_ids.add(request.request_id) + request_id = request.request_id + self._cached_reqs_data.pop(request_id, None) + self.finished_req_ids.add(request_id) + if self.finished_req_ids_dict is not None: + self.finished_req_ids_dict[request.client_index].add(request_id) if not delay_free_blocks: self._free_blocks(request) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 41db99beaad5e..0c9f61a764279 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -44,10 +44,6 @@ class EngineCoreRequest( omit_defaults=True, # type: ignore[call-arg] gc=False): # type: ignore[call-arg] - # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput, - # but this object is currently not playing well with msgspec - # due to circular imports and typing we have in data.py - request_id: str prompt_token_ids: list[int] mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] @@ -59,6 +55,10 @@ class EngineCoreRequest( lora_request: Optional[LoRARequest] cache_salt: Optional[str] + # Index of the client, used to ensure outputs are sent back to the same + # client for this request when scaling out the front-end. + client_index: int = 0 + # Used in DP case to indicate which wave of requests this is expected to # belong to, to cover a race condition where the request is sent before # a wave finished notification is received. diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 74c2251c75214..86781e7528fa3 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -36,6 +36,7 @@ from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory, setup_default_loggers) +from vllm.v1.metrics.prometheus import shutdown_prometheus from vllm.v1.metrics.stats import IterationStats, SchedulerStats logger = init_logger(__name__) @@ -54,6 +55,8 @@ class AsyncLLM(EngineClient): log_requests: bool = True, start_engine_loop: bool = True, stat_loggers: Optional[list[StatLoggerFactory]] = None, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0, ) -> None: """ Create an AsyncLLM. 
@@ -124,6 +127,8 @@ class AsyncLLM(EngineClient): vllm_config=vllm_config, executor_class=executor_class, log_stats=self.log_stats, + client_addresses=client_addresses, + client_index=client_index, ) if self.stat_loggers: for stat_logger in self.stat_loggers[0]: @@ -145,6 +150,8 @@ class AsyncLLM(EngineClient): stat_loggers: Optional[list[StatLoggerFactory]] = None, disable_log_requests: bool = False, disable_log_stats: bool = False, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0, ) -> "AsyncLLM": if not envs.VLLM_USE_V1: raise ValueError( @@ -162,6 +169,8 @@ class AsyncLLM(EngineClient): log_requests=not disable_log_requests, log_stats=not disable_log_stats, usage_context=usage_context, + client_addresses=client_addresses, + client_index=client_index, ) @classmethod @@ -195,6 +204,8 @@ class AsyncLLM(EngineClient): def shutdown(self): """Shutdown, cleaning up the background proc and IPC.""" + shutdown_prometheus() + if engine_core := getattr(self, "engine_core", None): engine_core.shutdown() @@ -398,7 +409,6 @@ class AsyncLLM(EngineClient): # TODO(rob): make into a coroutine and launch it in # background thread once Prometheus overhead is non-trivial. if stat_loggers: - assert outputs.scheduler_stats is not None AsyncLLM._record_stats( stat_loggers[outputs.engine_index], scheduler_stats=outputs.scheduler_stats, @@ -422,7 +432,7 @@ class AsyncLLM(EngineClient): @staticmethod def _record_stats( stat_loggers: list[StatLoggerBase], - scheduler_stats: SchedulerStats, + scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats], ): """static so that it can be used from the output_handler task diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py new file mode 100644 index 0000000000000..b84d4b144b5f2 --- /dev/null +++ b/vllm/v1/engine/coordinator.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: Apache-2.0 +import multiprocessing +import time +import weakref +from typing import Optional + +import msgspec.msgpack +import zmq + +from vllm.config import ParallelConfig +from vllm.logger import init_logger +from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket +from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType +from vllm.v1.serial_utils import MsgpackDecoder +from vllm.v1.utils import get_engine_client_zmq_addr, shutdown + +logger = init_logger(__name__) + + +class DPCoordinator: + """Coordinator process used for data-parallel deployments (DP>1). + + Intermediates between multiple DP engine rank processes and one or more + front-end API server processes. + + * Collects stats from each DP engine (currently just waiting and running + queue lengths), and publishes these to all front-ends for use in + load-balancing decisions. + + * Keeps track of the current DP "request wave" number and running state + of the engines. This is received from the DP rank 0 engine and published + to the front-end processes along with the current load stats. + + The engines alternate between a global running/paused state. The global + "request wave" number is a count of the number of times that the workers + collectively move from a running state to a paused state. This transition + is synchronized via the all-reduce operation performed in the + DPEngineCoreProc._has_global_unfinished_reqs method. + + * Broadcasts the START_DP_WAVE message to engines to move them from paused + to running state when one engine receives a new request. 
This can happen + in two cases: + 1) A front-end sending a new request while the engines are paused will + concurrently notify the coordinator. + 2) An engine receiving a request for a stale request wave while in paused + state will notify the coordinator. + + Engines will move into running state when receiving a new request or + START_DP_WAVE message. + """ + + def __init__(self, parallel_config: ParallelConfig): + + # Assume coordinator is colocated with front-end procs. + front_publish_address = get_open_zmq_ipc_path() + + dp_size = parallel_config.data_parallel_size + assert dp_size > 1, "Coordinator only used for data parallel" + + local_only = dp_size == parallel_config.data_parallel_size_local + host = parallel_config.data_parallel_master_ip + back_publish_address = get_engine_client_zmq_addr(local_only, host) + back_output_address = get_engine_client_zmq_addr(local_only, host) + + context = get_mp_context() + self.proc: multiprocessing.Process = context.Process( + target=CoordinatorProc.run_coordinator, + name="VLLM_DP_Coordinator", + kwargs={ + "engine_count": parallel_config.data_parallel_size, + "front_publish_address": front_publish_address, + "back_output_address": back_output_address, + "back_publish_address": back_publish_address, + }, + daemon=True) + self.proc.start() + + self.stats_publish_address = front_publish_address + self.coord_in_address = back_publish_address + self.coord_out_address = back_output_address + self._finalizer = weakref.finalize(self, shutdown, [self.proc]) + + def get_stats_publish_address(self) -> str: + return self.stats_publish_address + + def get_engine_socket_addresses(self) -> tuple[str, str]: + """Returns tuple of ZMQ input address, output address.""" + return self.coord_in_address, self.coord_out_address + + def close(self): + self._finalizer() + + +class EngineState: + + def __init__(self): + self.request_counts = [0, 0] # [waiting, running] + + +class CoordinatorProc: + + def __init__(self, engine_count: int): + + self.ctx = zmq.Context() + + self.engines = [EngineState() for _ in range(engine_count)] + + self.current_wave = 0 + self.engines_running = False + self.stats_changed = False + + @staticmethod + def run_coordinator( + engine_count: int, + front_publish_address: str, + back_output_address: str, + back_publish_address: str, + ): + coordinator = CoordinatorProc(engine_count=engine_count) + try: + coordinator.process_input_socket( + front_publish_address, + back_output_address, + back_publish_address, + ) + except KeyboardInterrupt: + logger.info("DP Coordinator process exiting") + + def process_input_socket(self, front_publish_address: str, + back_output_address: str, + back_publish_address: str): + + decoder = MsgpackDecoder(EngineCoreOutputs) + + with make_zmq_socket( + path=front_publish_address, # IPC + ctx=self.ctx, + socket_type=zmq.XPUB, + bind=True, + ) as publish_front, make_zmq_socket( + path=back_output_address, # IPC or TCP + ctx=self.ctx, + socket_type=zmq.PULL, + bind=True, + ) as output_back, make_zmq_socket( + path=back_publish_address, # IPC or TCP + ctx=self.ctx, + socket_type=zmq.XPUB, + bind=True, + ) as publish_back: + + poller = zmq.Poller() + poller.register(publish_front, zmq.POLLIN) + poller.register(output_back, zmq.POLLIN) + last_publish_time = 0 + while True: + elapsed = int(time.time() * 1000) - last_publish_time + # Send at 100 ms interval if the stats have changed, + # or otherwise every 3 seconds. 
+ wait_for = 100 if self.stats_changed else 3000 + events = poller.poll(timeout=max(0, wait_for - elapsed)) + if not events: + # Poller timeout - publish current stats to front-ends. + engine_req_counts_list = self._get_engine_counts() + to_publish = (engine_req_counts_list, self.current_wave, + self.engines_running) + publish_front.send(msgspec.msgpack.encode(to_publish)) + last_publish_time = int(time.time() * 1000) + self.stats_changed = False + continue + + events = dict(events) + + if publish_front in events: + buffer = publish_front.recv() + if buffer == b'\x01': + # Ignore subscription messages. + continue + + # We received a message on the front-end XPUB socket, + # from an API server sending a new request while the + # engines are paused, so that we can wake the other + # engines. + engine_to_exclude, wave = msgspec.msgpack.decode(buffer) + if wave < self.current_wave: + # If the wave number is stale, ensure the message is + # handled by all the engines. + engine_to_exclude = None + if not self.engines_running: + self.engines_running = True + self.stats_changed = True + self._send_start_wave(publish_back, self.current_wave, + engine_to_exclude) + + if output_back in events: + # We received a message from one of the engines. + + buffer = output_back.recv() + outputs: EngineCoreOutputs = decoder.decode(buffer) + + assert not outputs.outputs + assert outputs.utility_output is None + + eng_index = outputs.engine_index + if outputs.scheduler_stats: + # 1. Updated request load stats - update our local + # state with these. + stats = self.engines[eng_index].request_counts + stats[0] = outputs.scheduler_stats.num_waiting_reqs + stats[1] = outputs.scheduler_stats.num_running_reqs + self.stats_changed = True + + if (wave := outputs.wave_complete) is not None: + # 2. Notification from rank 0 engine that we've + # moved into the global paused state + # (engines_running==False) + if self.current_wave <= wave: + logger.debug("Moving DP wave from %d to %d.", + self.current_wave, wave) + self.current_wave = wave + 1 + self.engines_running = False + self.stats_changed = True + elif (wave := outputs.start_wave) is not None and ( + wave > self.current_wave or + (wave == self.current_wave + and not self.engines_running)): + # 3. The engine received request for a non-current wave + # so we must ensure that other engines progress to the + # next wave (race condition handling). + logger.debug( + "Starting wave %d after notification of " + "stale wave request from engine.", wave) + self.current_wave = wave + self.engines_running = True + self.stats_changed = True + self._send_start_wave(publish_back, wave, eng_index) + + @staticmethod + def _send_start_wave(socket: zmq.Socket, wave: int, + exclude_engine_index: Optional[int]): + """Broadcast the START_DP_WAVE message to all the engines. + It includes the current wave number and index of engine which + has already received a request with this wave number and so doesn't + require additional notification. 
+ """ + wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index)) + socket.send_multipart( + (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded)) + + def _get_engine_counts(self) -> list[list[int]]: + """Return list of [waiting, running] count lists for each engine.""" + return [e.request_counts for e in self.engines] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index ed71d9b671096..a02abb62b1f36 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -7,6 +7,7 @@ import threading import time from collections import deque from concurrent.futures import Future +from contextlib import ExitStack from inspect import isclass, signature from logging import DEBUG from typing import Any, Callable, Optional, TypeVar, Union @@ -22,7 +23,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception from vllm.lora.request import LoRARequest from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx +from vllm.utils import make_zmq_socket, resolve_obj_by_qualname from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, unify_kv_cache_configs) from vllm.v1.core.sched.interface import SchedulerInterface @@ -33,10 +34,12 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -211,7 +214,7 @@ class EngineCore: # Re-raise exception raise err - def step(self) -> tuple[EngineCoreOutputs, bool]: + def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: """Schedule, execute, and make output. Returns tuple of outputs and a flag indicating whether the model @@ -221,10 +224,7 @@ class EngineCore: # Check for any requests remaining in the scheduler - unfinished, # or finished and not yet removed from the batch. if not self.scheduler.has_requests(): - return EngineCoreOutputs( - outputs=[], - scheduler_stats=self.scheduler.make_stats(), - ), False + return {}, False scheduler_output = self.scheduler.schedule() model_output = self.execute_model(scheduler_output) engine_core_outputs = self.scheduler.update_from_output( @@ -234,7 +234,7 @@ class EngineCore: scheduler_output.total_num_scheduled_tokens > 0) def step_with_batch_queue( - self) -> tuple[Optional[EngineCoreOutputs], bool]: + self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]: """Schedule and execute batches with the batch queue. Note that if nothing to output in this step, None is returned. @@ -276,8 +276,8 @@ class EngineCore: # Blocking until the first result is available. 
model_output = future.result() self.batch_queue.task_done() - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output) + engine_core_outputs = (self.scheduler.update_from_output( + scheduler_output, model_output)) return engine_core_outputs, scheduled_batch @@ -362,7 +362,7 @@ class EngineCoreProc(EngineCore): self, vllm_config: VllmConfig, on_head_node: bool, - input_address: str, + handshake_address: str, executor_class: type[Executor], log_stats: bool, engine_index: int = 0, @@ -375,65 +375,70 @@ class EngineCoreProc(EngineCore): # Create input socket. input_ctx = zmq.Context() identity = engine_index.to_bytes(length=2, byteorder="little") - input_socket = make_zmq_socket(input_ctx, - input_address, - zmq.DEALER, - identity=identity, - bind=False) - try: + with make_zmq_socket(input_ctx, + handshake_address, + zmq.DEALER, + identity=identity, + linger=5000, + bind=False) as handshake_socket: + # Register engine with front-end. - output_address = self.startup_handshake( - input_socket, on_head_node, vllm_config.parallel_config) + addresses = self.startup_handshake(handshake_socket, on_head_node, + vllm_config.parallel_config) + self.client_count = len(addresses.outputs) # Update config which may have changed from the handshake. vllm_config.__post_init__() # Set up data parallel environment. + self.has_coordinator = addresses.coordinator_output is not None self._init_data_parallel(vllm_config) # Initialize engine core and model. super().__init__(vllm_config, executor_class, log_stats, executor_fail_callback) + self.engine_index = engine_index self.step_fn = (self.step if self.batch_queue is None else self.step_with_batch_queue) self.engines_running = False + self.last_counts = (0, 0) # Send ready message. num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks - input_socket.send( + handshake_socket.send( msgspec.msgpack.encode({ "status": "READY", "local": on_head_node, "num_gpu_blocks": num_gpu_blocks, })) - # Background Threads and Queues for IO. These enable us to - # overlap ZMQ socket IO with GPU since they release the GIL, - # and to overlap some serialization/deserialization with the - # model forward pass. - # Threads handle Socket <-> Queues and core_busy_loop uses Queue. - self.input_queue = input_queue - self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]() - threading.Thread(target=self.process_input_socket, - args=(input_socket, ), - daemon=True).start() - input_socket = None - self.output_thread = threading.Thread( - target=self.process_output_socket, - args=(output_address, engine_index), - daemon=True) - self.output_thread.start() - finally: - if input_socket is not None: - input_socket.close(linger=0) + # Background Threads and Queues for IO. These enable us to + # overlap ZMQ socket IO with GPU since they release the GIL, + # and to overlap some serialization/deserialization with the + # model forward pass. + # Threads handle Socket <-> Queues and core_busy_loop uses Queue. 
+ self.input_queue = input_queue + self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs], + bytes]]() + threading.Thread(target=self.process_input_sockets, + args=(addresses.inputs, addresses.coordinator_input, + identity), + daemon=True).start() + self.output_thread = threading.Thread( + target=self.process_output_sockets, + args=(addresses.outputs, addresses.coordinator_output, + engine_index), + daemon=True) + self.output_thread.start() @staticmethod - def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, - parallel_config: ParallelConfig) -> str: + def startup_handshake( + handshake_socket: zmq.Socket, on_head_node: bool, + parallel_config: ParallelConfig) -> EngineZmqAddresses: # Send registration message. - input_socket.send( + handshake_socket.send( msgspec.msgpack.encode({ "status": "HELLO", "local": on_head_node, @@ -441,22 +446,20 @@ class EngineCoreProc(EngineCore): # Receive initialization message. logger.info("Waiting for init message from front-end.") - if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000): + if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000): raise RuntimeError("Did not receive response from front-end " f"process within {HANDSHAKE_TIMEOUT_MINS} " f"minutes") - init_bytes = input_socket.recv() - init_message = msgspec.msgpack.decode(init_bytes) + init_bytes = handshake_socket.recv() + init_message: EngineHandshakeMetadata = msgspec.msgpack.decode( + init_bytes, type=EngineHandshakeMetadata) logger.debug("Received init message: %s", init_message) - output_socket_address = init_message["output_socket_address"] - #TBD(nick) maybe replace IP with configured head node address - - received_parallel_config = init_message["parallel_config"] + received_parallel_config = init_message.parallel_config for key, value in received_parallel_config.items(): setattr(parallel_config, key, value) - return output_socket_address + return init_message.addresses @staticmethod def run_engine_core(*args, @@ -528,7 +531,7 @@ class EngineCoreProc(EngineCore): """Exits when an engine step needs to be performed.""" waited = False - while not self.engines_running and not (self.scheduler.has_requests()): + while not self.engines_running and not self.scheduler.has_requests(): if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): logger.debug("EngineCore waiting for work.") waited = True @@ -549,8 +552,8 @@ class EngineCoreProc(EngineCore): # Step the engine core. outputs, model_executed = self.step_fn() # Put EngineCoreOutputs into the output queue. 
- if outputs is not None: - self.output_queue.put_nowait(outputs) + for output in (outputs.items() if outputs else ()): + self.output_queue.put_nowait(output) return model_executed @@ -563,7 +566,7 @@ class EngineCoreProc(EngineCore): elif request_type == EngineCoreRequestType.ABORT: self.abort_requests(request) elif request_type == EngineCoreRequestType.UTILITY: - call_id, method_name, args = request + client_idx, call_id, method_name, args = request output = UtilityOutput(call_id) try: method = getattr(self, method_name) @@ -574,7 +577,7 @@ class EngineCoreProc(EngineCore): output.failure_message = (f"Call to {method_name} method" f" failed: {str(e)}") self.output_queue.put_nowait( - EngineCoreOutputs(utility_output=output)) + (client_idx, EngineCoreOutputs(utility_output=output))) elif request_type == EngineCoreRequestType.EXECUTOR_FAILED: raise RuntimeError("Executor failed.") else: @@ -607,27 +610,68 @@ class EngineCoreProc(EngineCore): logger.fatal("vLLM shutdown signal from EngineCore failed " "to send. Please report this issue.") - def process_input_socket(self, input_socket: zmq.Socket): + def process_input_sockets(self, input_addresses: list[str], + coord_input_address: Optional[str], + identity: bytes): """Input socket IO thread.""" # Msgpack serialization decoding. add_request_decoder = MsgpackDecoder(EngineCoreRequest) generic_decoder = MsgpackDecoder() - while True: - # (RequestType, RequestData) - type_frame, *data_frames = input_socket.recv_multipart(copy=False) - request_type = EngineCoreRequestType(bytes(type_frame.buffer)) + with ExitStack() as stack, zmq.Context() as ctx: + input_sockets = [ + stack.enter_context( + make_zmq_socket(ctx, + input_address, + zmq.DEALER, + identity=identity, + bind=False)) + for input_address in input_addresses + ] + if coord_input_address is None: + coord_socket = None + else: + coord_socket = stack.enter_context( + make_zmq_socket(ctx, + coord_input_address, + zmq.XSUB, + identity=identity, + bind=False)) + # Send subscription message to coordinator. + coord_socket.send(b'\x01') - # Deserialize the request data. - decoder = add_request_decoder if ( - request_type == EngineCoreRequestType.ADD) else generic_decoder - request = decoder.decode(data_frames) + # Register sockets with poller. + poller = zmq.Poller() + for input_socket in input_sockets: + # Send initial message to each input socket - this is required + # before the front-end ROUTER socket can send input messages + # back to us. + input_socket.send(b'') + poller.register(input_socket, zmq.POLLIN) + if coord_socket is not None: + poller.register(coord_socket, zmq.POLLIN) - # Push to input queue for core busy loop. - self.input_queue.put_nowait((request_type, request)) + while True: + for input_socket, _ in poller.poll(): + # (RequestType, RequestData) + type_frame, *data_frames = input_socket.recv_multipart( + copy=False) + request_type = EngineCoreRequestType( + bytes(type_frame.buffer)) - def process_output_socket(self, output_path: str, engine_index: int): + # Deserialize the request data. + decoder = add_request_decoder if ( + request_type + == EngineCoreRequestType.ADD) else generic_decoder + request = decoder.decode(data_frames) + + # Push to input queue for core busy loop. + self.input_queue.put_nowait((request_type, request)) + + def process_output_sockets(self, output_paths: list[str], + coord_output_path: Optional[str], + engine_index: int): """Output socket IO thread.""" # Msgpack serialization encoding. 
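The bare b'\x01' frame sent to the coordinator socket in process_input_sockets above is the raw ZMQ subscription message: with XSUB/XPUB sockets the subscriber must send a frame whose first byte is 0x01 (subscribe), optionally followed by a topic prefix, before any published data reaches it. A minimal self-contained sketch of that pattern, using an illustrative inproc address rather than anything from this PR:

    # Minimal XPUB/XSUB sketch; the inproc address and payload are illustrative.
    import zmq

    ctx = zmq.Context()

    pub = ctx.socket(zmq.XPUB)          # coordinator-style publisher
    pub.bind("inproc://stats-demo")

    sub = ctx.socket(zmq.XSUB)          # engine / front-end style subscriber
    sub.connect("inproc://stats-demo")
    sub.send(b"\x01")                   # 0x01 + empty prefix = subscribe to all

    # The XPUB side receives the subscription as a normal message and can
    # simply ignore it (as the coordinator does for b'\x01' frames).
    assert pub.recv() == b"\x01"

    pub.send(b"some published payload")
    print(sub.recv())                   # b'some published payload'

    sub.close()
    pub.close()
    ctx.term()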
@@ -641,30 +685,49 @@ class EngineCoreProc(EngineCore): # We must set linger to ensure the ENGINE_CORE_DEAD # message is sent prior to closing the socket. - with zmq_socket_ctx(output_path, zmq.constants.PUSH, - linger=4000) as socket: + with ExitStack() as stack, zmq.Context() as ctx: + sockets = [ + stack.enter_context( + make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)) + for output_path in output_paths + ] + coord_socket = stack.enter_context( + make_zmq_socket( + ctx, coord_output_path, zmq.PUSH, bind=False, + linger=4000)) if coord_output_path is not None else None + max_reuse_bufs = len(sockets) + 1 + while True: - outputs = self.output_queue.get() - if outputs == EngineCoreProc.ENGINE_CORE_DEAD: - socket.send(outputs, copy=False) + output = self.output_queue.get() + if output == EngineCoreProc.ENGINE_CORE_DEAD: + for socket in sockets: + socket.send(output) break - assert not isinstance(outputs, bytes) + assert not isinstance(output, bytes) + client_index, outputs = output outputs.engine_index = engine_index + if client_index == -1: + # Don't reuse buffer for coordinator message + # which will be very small. + assert coord_socket is not None + coord_socket.send_multipart(encoder.encode(outputs)) + continue + # Reclaim buffers that zmq is finished with. while pending and pending[-1][0].done: reuse_buffers.append(pending.pop()[2]) buffer = reuse_buffers.pop() if reuse_buffers else bytearray() buffers = encoder.encode_into(outputs, buffer) - tracker = socket.send_multipart(buffers, - copy=False, - track=True) + tracker = sockets[client_index].send_multipart(buffers, + copy=False, + track=True) if not tracker.done: ref = outputs if len(buffers) > 1 else None pending.appendleft((tracker, ref, buffer)) - elif len(reuse_buffers) < 2: - # Keep at most 2 buffers to reuse. + elif len(reuse_buffers) < max_reuse_bufs: + # Limit the number of buffers to reuse. reuse_buffers.append(buffer) @@ -676,7 +739,7 @@ class DPEngineCoreProc(EngineCoreProc): self, vllm_config: VllmConfig, on_head_node: bool, - input_address: str, + handshake_address: str, executor_class: type[Executor], log_stats: bool, ): @@ -691,10 +754,11 @@ class DPEngineCoreProc(EngineCoreProc): # Counts forward-passes of the model so that we can synchronize # finished with DP peers every N steps. self.counter = 0 + self.current_wave = 0 # Initialize the engine. dp_rank = vllm_config.parallel_config.data_parallel_rank - super().__init__(vllm_config, on_head_node, input_address, + super().__init__(vllm_config, on_head_node, handshake_address, executor_class, log_stats, dp_rank) def _init_data_parallel(self, vllm_config: VllmConfig): @@ -726,7 +790,6 @@ class DPEngineCoreProc(EngineCoreProc): self.dp_rank = dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() - self.current_wave = 0 def shutdown(self): super().shutdown() @@ -734,22 +797,23 @@ class DPEngineCoreProc(EngineCoreProc): stateless_destroy_torch_distributed_process_group(dp_group) def add_request(self, request: EngineCoreRequest): - if request.current_wave != self.current_wave: + if self.has_coordinator and request.current_wave != self.current_wave: if request.current_wave > self.current_wave: self.current_wave = request.current_wave elif not self.engines_running: # Request received for an already-completed wave, notify # front-end that we need to start the next one. 
self.output_queue.put_nowait( - EngineCoreOutputs(start_wave=self.current_wave)) + (-1, EngineCoreOutputs(start_wave=self.current_wave))) super().add_request(request) def _handle_client_request(self, request_type: EngineCoreRequestType, request: Any) -> None: if request_type == EngineCoreRequestType.START_DP_WAVE: - new_wave: int = request - if new_wave >= self.current_wave: + new_wave, exclude_eng_index = request + if exclude_eng_index != self.engine_index and ( + new_wave >= self.current_wave): self.current_wave = new_wave if not self.engines_running: logger.debug("EngineCore starting idle loop for wave %d.", @@ -758,6 +822,18 @@ class DPEngineCoreProc(EngineCoreProc): else: super()._handle_client_request(request_type, request) + def _maybe_publish_request_counts(self): + if not self.has_coordinator: + return + + # Publish our request counts (if they've changed). + counts = self.scheduler.get_request_counts() + if counts != self.last_counts: + self.last_counts = counts + stats = SchedulerStats(*counts) + self.output_queue.put_nowait( + (-1, EngineCoreOutputs(scheduler_stats=stats))) + def run_busy_loop(self): """Core busy loop of the EngineCore for data parallel case.""" @@ -768,6 +844,8 @@ class DPEngineCoreProc(EngineCoreProc): # 2) Step the engine core. executed = self._process_engine_step() + self._maybe_publish_request_counts() + local_unfinished_reqs = self.scheduler.has_unfinished_requests() if not executed: if not local_unfinished_reqs and not self.engines_running: @@ -788,7 +866,8 @@ class DPEngineCoreProc(EngineCoreProc): logger.debug("Wave %d finished, pausing engine loop.", self.current_wave) self.output_queue.put_nowait( - EngineCoreOutputs(wave_complete=self.current_wave)) + (-1, + EngineCoreOutputs(wave_complete=self.current_wave))) self.current_wave += 1 def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 9f8a9b6922200..e9e2d2d8d1e98 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -2,6 +2,7 @@ import asyncio import contextlib import queue +import sys import uuid import weakref from abc import ABC, abstractmethod @@ -9,26 +10,28 @@ from collections import deque from collections.abc import Awaitable, Sequence from concurrent.futures import Future from dataclasses import dataclass -from enum import Enum, auto from threading import Thread from typing import Any, Callable, Optional, TypeVar, Union -import msgspec +import msgspec.msgpack import zmq import zmq.asyncio -from vllm.config import ParallelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.utils import (get_open_port, get_open_zmq_inproc_path, - get_open_zmq_ipc_path, get_tcp_uri, make_zmq_socket) +from vllm.utils import (get_open_zmq_inproc_path, make_zmq_socket, + zmq_socket_ctx) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) +from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr -from vllm.v1.utils import CoreEngineProcManager +from vllm.v1.utils import (CoreEngine, CoreEngineProcManager, + EngineZmqAddresses, get_engine_client_zmq_addr, + wait_for_engine_startup) logger = init_logger(__name__) @@ -36,8 +39,6 
@@ AnyFuture = Union[asyncio.Future[Any], Future[Any]] _R = TypeVar('_R') # Return type for collective_rpc -STARTUP_POLL_PERIOD_MS = 10000 - class EngineCoreClient(ABC): """ @@ -207,7 +208,7 @@ class InprocClient(EngineCoreClient): def get_output(self) -> EngineCoreOutputs: outputs, _ = self.engine_core.step() - return outputs + return outputs.get(0) or EngineCoreOutputs() def add_request(self, request: EngineCoreRequest) -> None: self.engine_core.add_request(request) @@ -266,24 +267,6 @@ class InprocClient(EngineCoreClient): return self.engine_core.collective_rpc(method, timeout, args, kwargs) -class CoreEngineState(Enum): - NEW = auto() - CONNECTED = auto() - READY = auto() - - -class CoreEngine: - """One per data parallel rank.""" - - def __init__(self, index: int = 0, local: bool = True): - self.local = local - self.index = index - self.identity = index.to_bytes(length=2, byteorder="little") - - self.state = CoreEngineState.NEW - self.num_reqs_in_flight = 0 - - @dataclass class BackgroundResources: """Used as a finalizer for clean shutdown, avoiding @@ -291,9 +274,12 @@ class BackgroundResources: ctx: Union[zmq.Context] local_engine_manager: Optional[CoreEngineProcManager] = None + coordinator: Optional[DPCoordinator] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None + first_req_send_socket: Optional[zmq.asyncio.Socket] = None output_queue_task: Optional[asyncio.Task] = None + stats_update_task: Optional[asyncio.Task] = None shutdown_path: Optional[str] = None # Set if any of the engines are dead. Here so that the output @@ -306,16 +292,21 @@ class BackgroundResources: self.engine_dead = True if self.local_engine_manager is not None: self.local_engine_manager.close() + if self.coordinator is not None: + self.coordinator.close() if self.output_queue_task is not None: self.output_queue_task.cancel() + if self.stats_update_task is not None: + self.stats_update_task.cancel() # ZMQ context termination can hang if the sockets # aren't explicitly closed first. - if self.output_socket is not None: - self.output_socket.close(linger=0) - if self.input_socket is not None: - self.input_socket.close(linger=0) + for socket in (self.output_socket, self.input_socket, + self.first_req_send_socket): + if socket is not None: + socket.close(linger=0) + if self.shutdown_path is not None: # We must ensure that the sync output socket is # closed cleanly in its own thread. @@ -350,6 +341,7 @@ class MPClient(EngineCoreClient): vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, ): self.vllm_config = vllm_config # Serialization setup. 
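Since EngineCore.step() now returns a dict keyed by client index rather than a single EngineCoreOutputs, the in-process client above collapses it with outputs.get(0) or EngineCoreOutputs(). A rough sketch of the shape involved, using a simplified stand-in type instead of the real EngineCoreOutputs:

    # Rough sketch of per-client output routing; FakeOutputs is a stand-in type.
    from dataclasses import dataclass, field

    @dataclass
    class FakeOutputs:
        outputs: list[str] = field(default_factory=list)

    def step() -> dict[int, FakeOutputs]:
        # Keys are the client_index values the outputs are routed back to;
        # an empty dict means there was nothing to schedule this step.
        return {
            0: FakeOutputs(outputs=["token for client 0"]),
            2: FakeOutputs(outputs=["token for client 2"]),
        }

    per_client = step()
    # An in-process client is always client 0, so it keeps only its own slice
    # and falls back to an empty result when the dict has no entry for it.
    mine = per_client.get(0) or FakeOutputs()
    print(mine.outputs)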
@@ -369,8 +361,8 @@ class MPClient(EngineCoreClient): try: parallel_config = vllm_config.parallel_config local_engine_count = parallel_config.data_parallel_size_local - start_index = parallel_config.data_parallel_rank local_start_index = parallel_config.data_parallel_rank_local + dp_size = parallel_config.data_parallel_size # SPMD mode is where there is an LLM instance per DP rank and # one core engine per LLM, see @@ -382,42 +374,53 @@ class MPClient(EngineCoreClient): CoreEngine(index=local_start_index, local=True) ] else: - assert start_index == 0 + assert parallel_config.data_parallel_rank == 0 local_start_index = 0 self.core_engines = [ CoreEngine(index=i, local=(i < local_engine_count)) - for i in range(parallel_config.data_parallel_size) + for i in range(dp_size) ] - input_address, output_address = self._get_zmq_addresses( - parallel_config, spmd_mode) + local_only = spmd_mode or local_engine_count == dp_size + + self.stats_update_address: Optional[str] = None + if client_addresses is not None: + input_address = client_addresses["input_address"] + output_address = client_addresses["output_address"] + self.stats_update_address = client_addresses.get( + "stats_update_address") + else: + host = parallel_config.data_parallel_master_ip + input_address = get_engine_client_zmq_addr(local_only, host) + output_address = get_engine_client_zmq_addr(local_only, host) # Create input and output sockets. self.input_socket = self.resources.input_socket = make_zmq_socket( self.ctx, input_address, zmq.ROUTER, bind=True) - self.resources.output_socket = make_zmq_socket( - self.ctx, output_address, zmq.constants.PULL) - # Start local engines. - if local_engine_count: - # In server mode, start_index and local_start_index will - # both be 0. - self.resources.local_engine_manager = CoreEngineProcManager( - EngineCoreProc.run_engine_core, - vllm_config=vllm_config, - executor_class=executor_class, - log_stats=log_stats, - input_address=input_address, - on_head_node=True, - local_engine_count=local_engine_count, - start_index=start_index, - local_start_index=local_start_index) + self.ctx, output_address, zmq.PULL) + + if client_addresses is None: + self._init_engines_direct(vllm_config, local_only, + local_start_index, input_address, + output_address, executor_class, + log_stats) + coordinator = self.resources.coordinator + if coordinator: + self.stats_update_address = ( + coordinator.get_stats_publish_address()) + + # Wait for ready messages from each engine on the input socket. + identities = set(e.identity for e in self.core_engines) + sync_input_socket = zmq.Socket.shadow(self.input_socket) + while identities: + if not sync_input_socket.poll(timeout=600_000): + raise TimeoutError("Timed out waiting for engines to send" + "initial message on input socket.") + identity, _ = sync_input_socket.recv_multipart() + identities.remove(identity) self.core_engine = self.core_engines[0] - - # Wait for engine core process(es) to start. 
- self._wait_for_engine_startup(output_address, parallel_config) - self.utility_results: dict[int, AnyFuture] = {} # Request objects which may contain pytorch-allocated tensors @@ -430,116 +433,67 @@ class MPClient(EngineCoreClient): if not success: self._finalizer() - @staticmethod - def _get_zmq_addresses(parallel_config: ParallelConfig, - spmd_mode: bool) -> tuple[str, str]: - """Returns (input_address, output_address).""" - dp_size = parallel_config.data_parallel_size + def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool, + local_start_index: int, input_address: str, + output_address: str, + executor_class: type[Executor], log_stats: bool): + """Self-contained client mode, launch engine and coordinator process + as needed.""" + + parallel_config = vllm_config.parallel_config local_engine_count = parallel_config.data_parallel_size_local + start_index = parallel_config.data_parallel_rank + host = parallel_config.data_parallel_master_ip - if local_engine_count == dp_size or spmd_mode: - input_address = get_open_zmq_ipc_path() - output_address = get_open_zmq_ipc_path() - else: - host = parallel_config.data_parallel_master_ip - input_port = parallel_config.data_parallel_rpc_port - output_port = get_open_port() - input_address = get_tcp_uri(host, input_port) - output_address = get_tcp_uri(host, output_port) + if len(self.core_engines) > 1: + self.resources.coordinator = DPCoordinator(parallel_config) - return input_address, output_address + handshake_address = get_engine_client_zmq_addr( + local_only, host, parallel_config.data_parallel_rpc_port) - def _wait_for_engine_startup(self, output_address: str, - parallel_config: ParallelConfig): - # Get a sync handle to the socket which can be sync or async. - sync_input_socket = zmq.Socket.shadow(self.input_socket) + with zmq_socket_ctx(handshake_address, zmq.ROUTER, + bind=True) as handshake_socket: - # Wait for engine core process(es) to send ready messages. - local_count = parallel_config.data_parallel_size_local - remote_count = len(self.core_engines) - local_count - # [local, remote] counts - conn_pending, start_pending = [local_count, remote_count], [0, 0] + # Start local engines. + if local_engine_count: + # In server mode, start_index and local_start_index will + # both be 0. + self.resources.local_engine_manager = CoreEngineProcManager( + EngineCoreProc.run_engine_core, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=log_stats, + handshake_address=handshake_address, + on_head_node=True, + local_engine_count=local_engine_count, + start_index=start_index, + local_start_index=local_start_index) - poller = zmq.Poller() - poller.register(sync_input_socket, zmq.POLLIN) - proc_manager = self.resources.local_engine_manager - if proc_manager is not None: - for sentinel in proc_manager.sentinels(): - poller.register(sentinel, zmq.POLLIN) - while any(conn_pending) or any(start_pending): - events = poller.poll(STARTUP_POLL_PERIOD_MS) - if not events: - if any(conn_pending): - logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to connect.", *conn_pending) - if any(start_pending): - logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to start.", *start_pending) - continue - if len(events) > 1 or events[0][0] != sync_input_socket: - # One of the local core processes exited. - finished = proc_manager.finished_procs( - ) if proc_manager else {} - raise RuntimeError("Engine core initialization failed. " - "See root cause above. 
" - f"Failed core proc(s): {finished}") + # Wait for engine core process(es) to start. + self._wait_for_engine_startup(handshake_socket, input_address, + output_address) - # Receive HELLO and READY messages from the input socket. - eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart() - eng_index = int.from_bytes(eng_identity, byteorder="little") - engine = next( - (e for e in self.core_engines if e.identity == eng_identity), - None) - if engine is None: - raise RuntimeError(f"Message from engine with unexpected data " - f"parallel rank: {eng_index}") - msg = msgspec.msgpack.decode(ready_msg_bytes) - status, local = msg["status"], msg["local"] - if local != engine.local: - raise RuntimeError(f"{status} message from " - f"{'local' if local else 'remote'} " - f"engine {eng_index}, expected it to be " - f"{'local' if engine.local else 'remote'}") + def _wait_for_engine_startup(self, handshake_socket: zmq.Socket, + input_address: str, output_address: str): + addresses = EngineZmqAddresses( + inputs=[input_address], + outputs=[output_address], + ) - if status == "HELLO" and engine.state == CoreEngineState.NEW: + coordinator = self.resources.coordinator + if coordinator is not None: + addresses.coordinator_input, addresses.coordinator_output = ( + coordinator.get_engine_socket_addresses()) - # Send init message with DP config info. - init_message = self.encoder.encode({ - "output_socket_address": output_address, - "parallel_config": { - "data_parallel_master_ip": - parallel_config.data_parallel_master_ip, - "data_parallel_master_port": - parallel_config.data_parallel_master_port, - "data_parallel_size": - parallel_config.data_parallel_size, - }, - }) - sync_input_socket.send_multipart((eng_identity, *init_message), - copy=False) - conn_pending[0 if local else 1] -= 1 - start_pending[0 if local else 1] += 1 - engine.state = CoreEngineState.CONNECTED - elif status == "READY" and (engine.state - == CoreEngineState.CONNECTED): - # Setup KV cache config with initialization state from - # engine core process. Sum values from all engines in DP case. - cache_config = self.vllm_config.cache_config - num_gpu_blocks = cache_config.num_gpu_blocks or 0 - num_gpu_blocks += msg['num_gpu_blocks'] - cache_config.num_gpu_blocks = num_gpu_blocks - - start_pending[0 if local else 1] -= 1 - engine.state = CoreEngineState.READY - else: - raise RuntimeError(f"Unexpected {status} message for " - f"{'local' if local else 'remote'} engine " - f"{eng_index} in {engine.state} state.") - - logger.debug("%s from %s core engine process %s.", status, - "local" if local else "remote", eng_index) + wait_for_engine_startup( + handshake_socket, + addresses, + self.core_engines, + self.vllm_config.parallel_config, + self.vllm_config.cache_config, + self.resources.local_engine_manager, + coordinator.proc if coordinator else None, + ) def shutdown(self): # Terminate background resources. 
@@ -605,8 +559,8 @@ class SyncMPClient(MPClient): try: shutdown_socket.bind(shutdown_path) poller = zmq.Poller() - poller.register(shutdown_socket) - poller.register(out_socket) + poller.register(shutdown_socket, zmq.POLLIN) + poller.register(out_socket, zmq.POLLIN) while True: socks = poller.poll() if not socks: @@ -668,7 +622,7 @@ class SyncMPClient(MPClient): future: Future[Any] = Future() self.utility_results[call_id] = future self._send_input(EngineCoreRequestType.UTILITY, - (call_id, method, args)) + (0, call_id, method, args)) return future.result() @@ -730,15 +684,21 @@ class SyncMPClient(MPClient): class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" - def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], - log_stats: bool): + def __init__(self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0): super().__init__( asyncio_mode=True, vllm_config=vllm_config, executor_class=executor_class, log_stats=log_stats, + client_addresses=client_addresses, ) + self.client_index = client_index self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, Exception]]() try: @@ -854,12 +814,13 @@ class AsyncMPClient(MPClient): future = asyncio.get_running_loop().create_future() self.utility_results[call_id] = future message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( - (call_id, method, args))) + (self.client_index, call_id, method, args))) await self._send_input_message(message, engine, args) self._ensure_output_queue_task() return await future async def add_request_async(self, request: EngineCoreRequest) -> None: + request.client_index = self.client_index await self._send_input(EngineCoreRequestType.ADD, request) self._ensure_output_queue_task() @@ -921,17 +882,120 @@ class DPAsyncMPClient(AsyncMPClient): """Asyncio-compatible client for multi-proc, multi-engine (data parallel) EngineCore.""" - def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], - log_stats: bool): + def __init__(self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0): self.current_wave = 0 self.engines_running = False + # To route aborts to the correct engine. self.reqs_in_flight: dict[str, CoreEngine] = {} - super().__init__(vllm_config, executor_class, log_stats) + super().__init__(vllm_config, executor_class, log_stats, + client_addresses, client_index) assert len(self.core_engines) > 1 + # List of [waiting, running] pair per engine. + self.lb_engines: list[list[int]] = [] + + self.first_req_sock_addr = get_open_zmq_inproc_path() + self.first_req_send_socket = self.resources.first_req_send_socket = ( + make_zmq_socket(self.ctx, + self.first_req_sock_addr, + zmq.PAIR, + bind=True)) + try: + # If we are running in an asyncio event loop, start the stats task. + # Otherwise, it will be started lazily. 
+ asyncio.get_running_loop() + self._ensure_stats_update_task() + except RuntimeError: + pass + + def _ensure_stats_update_task(self): + resources = self.resources + if resources.stats_update_task is not None: + return + + assert self.stats_update_address is not None + + async def run_engine_stats_update_task(): + with make_zmq_socket(self.ctx, self.stats_update_address, + zmq.XSUB) as socket, make_zmq_socket( + self.ctx, + self.first_req_sock_addr, + zmq.PAIR, + bind=False) as first_req_rcv_socket: + # Send subscription message. + await socket.send(b'\x01') + + poller = zmq.asyncio.Poller() + poller.register(socket, zmq.POLLIN) + poller.register(first_req_rcv_socket, zmq.POLLIN) + + while True: + events = await poller.poll() + if not self.engines_running and len(events) == 2 or ( + events[0][0] == first_req_rcv_socket): + # Send a message to notify the coordinator that + # we're sending a request while the engines are + # paused, so that it can wake the others up + # (to run dummy EP loop). + self.engines_running = True + buf = first_req_rcv_socket.recv( + flags=zmq.NOBLOCK).result() + target_eng_index = int.from_bytes(buf, "little") + msg = msgspec.msgpack.encode( + (target_eng_index, self.current_wave)) + await socket.send(msg) + + buf = None + while True: + # Drain all stats events (we only care about latest). + future: asyncio.Future[bytes] = socket.recv( + flags=zmq.NOBLOCK) + if isinstance(future.exception(), zmq.Again): + break + buf = future.result() + if buf is None: + continue + + # Update local load-balancing state. + counts, wave, running = msgspec.msgpack.decode(buf) + self.current_wave = wave + self.engines_running = running + self.lb_engines = counts + + resources.stats_update_task = asyncio.create_task( + run_engine_stats_update_task()) + + def get_core_engine_for_request(self) -> CoreEngine: + if not self.lb_engines: + return self.core_engines[0] + # TODO use P2C alg for larger DP sizes + num_engines = len(self.lb_engines) + min_counts = [sys.maxsize, sys.maxsize] + eng_index = 0 + for i in range(num_engines): + # Start from client_index to help with balancing when engines + # are empty. + idx = (self.client_index + i) % num_engines + counts = self.lb_engines[idx] + if counts < min_counts: + min_counts = counts + eng_index = idx + # Adjust local counts for better balancing between stats updates + # from the coordinator (which happen every 100ms). + if min_counts[0]: + min_counts[0] += 1 + else: + min_counts[1] += 1 + return self.core_engines[eng_index] + async def call_utility_async(self, method: str, *args) -> Any: # Only the result from the first engine is returned. return (await asyncio.gather(*[ @@ -940,62 +1004,30 @@ class DPAsyncMPClient(AsyncMPClient): ]))[0] async def add_request_async(self, request: EngineCoreRequest) -> None: + self._ensure_stats_update_task() + request.current_wave = self.current_wave + request.client_index = self.client_index chosen_engine = self.get_core_engine_for_request() self.reqs_in_flight[request.request_id] = chosen_engine - chosen_engine.num_reqs_in_flight += 1 to_await = self._send_input(EngineCoreRequestType.ADD, request, chosen_engine) if not self.engines_running: - # Send request to chosen engine and dp start loop - # control message to all other engines. 
- self.engines_running = True - to_await = asyncio.gather( - to_await, # type: ignore[assignment] - *self._start_wave_coros(exclude_index=chosen_engine.index)) + # Notify coordinator that we're sending a request + await self.first_req_send_socket.send(chosen_engine.identity) await to_await self._ensure_output_queue_task() - def get_core_engine_for_request(self) -> CoreEngine: - return min(self.core_engines, key=lambda e: e.num_reqs_in_flight) - @staticmethod async def process_engine_outputs(self: "DPAsyncMPClient", outputs: EngineCoreOutputs): - if self.reqs_in_flight: - for req_id in outputs.finished_requests or (): - if engine := self.reqs_in_flight.pop(req_id, None): - engine.num_reqs_in_flight -= 1 - - if outputs.wave_complete is not None: - # Current wave is complete, move to next wave number - # and mark engines as paused. - if self.current_wave <= outputs.wave_complete: - self.current_wave = outputs.wave_complete + 1 - self.engines_running = False - - elif outputs.start_wave is not None and ( - outputs.start_wave > self.current_wave or - (outputs.start_wave == self.current_wave - and not self.engines_running)): - # Engine received request for a non-current wave so we must ensure - # that other engines progress to the next wave. - self.current_wave = outputs.start_wave - self.engines_running = True - await asyncio.gather(*self._start_wave_coros( - exclude_index=outputs.engine_index)) - - def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]: - logger.debug("Sending start DP wave %d.", self.current_wave) - return [ - self._send_input(EngineCoreRequestType.START_DP_WAVE, - self.current_wave, engine) - for engine in self.core_engines if engine.index != exclude_index - ] + if outputs.finished_requests and self.reqs_in_flight: + for req_id in outputs.finished_requests: + self.reqs_in_flight.pop(req_id, None) async def abort_requests_async(self, request_ids: list[str]) -> None: if not request_ids: diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 3dc2f77444f63..665e5873d5891 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -12,13 +12,12 @@ from vllm.config import SupportsMetricsInfo, VllmConfig from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason +from vllm.v1.metrics.prometheus import unregister_vllm_metrics from vllm.v1.metrics.stats import IterationStats, SchedulerStats from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm logger = init_logger(__name__) -_LOCAL_LOGGING_INTERVAL_SEC = 5.0 - StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] @@ -35,7 +34,7 @@ class StatLoggerBase(ABC): ... @abstractmethod - def record(self, scheduler_stats: SchedulerStats, + def record(self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats]): ... 
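The lb_engines state used by DPAsyncMPClient.get_core_engine_for_request above is refreshed from the coordinator's periodic publications, each a msgpack tuple of (per-engine [waiting, running] counts, current_wave, engines_running). A standalone sketch of the selection step itself, with made-up counts and client index:

    # Standalone sketch of the [waiting, running] comparison used to pick an
    # engine; the counts and client index below are made up.
    import sys

    lb_engines = [[3, 10], [0, 7], [0, 9]]   # [waiting, running] per DP engine
    client_index = 1                         # index of this API server process

    min_counts, eng_index = [sys.maxsize, sys.maxsize], 0
    for i in range(len(lb_engines)):
        # Start the scan at client_index so idle engines are spread across
        # API servers instead of every server picking engine 0.
        idx = (client_index + i) % len(lb_engines)
        if lb_engines[idx] < min_counts:     # lexicographic: waiting, then running
            min_counts, eng_index = lb_engines[idx], idx

    print("route next request to engine", eng_index)   # engine 1 in this example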
@@ -78,20 +77,22 @@ class LoggingStatLogger(StatLoggerBase): # Compute summary metrics for tracked stats return float(np.sum(tracked_stats) / (now - self.last_log_time)) - def record(self, scheduler_stats: SchedulerStats, + def record(self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats]): """Log Stats to standard output.""" if iteration_stats: self._track_iteration_stats(iteration_stats) - self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) + if scheduler_stats is not None: + self.prefix_caching_metrics.observe( + scheduler_stats.prefix_cache_stats) - if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_logging.observe( - scheduler_stats.spec_decoding_stats) + if scheduler_stats.spec_decoding_stats is not None: + self.spec_decoding_logging.observe( + scheduler_stats.spec_decoding_stats) - self.last_scheduler_stats = scheduler_stats + self.last_scheduler_stats = scheduler_stats def log(self): now = time.monotonic() @@ -131,10 +132,11 @@ class LoggingStatLogger(StatLoggerBase): self.spec_decoding_logging.log(log_fn=log_fn) def log_engine_initialized(self): - logger.info( - "vllm cache_config_info with initialization " \ - "after num_gpu_blocks is: %d", - self.vllm_config.cache_config.num_gpu_blocks) + if self.vllm_config.cache_config.num_gpu_blocks: + logger.info( + "Engine %03d: vllm cache_config_info with initialization " + "after num_gpu_blocks is: %d", self.engine_index, + self.vllm_config.cache_config.num_gpu_blocks) class PrometheusStatLogger(StatLoggerBase): @@ -144,7 +146,8 @@ class PrometheusStatLogger(StatLoggerBase): _spec_decoding_cls = SpecDecodingProm def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): - self._unregister_vllm_metrics() + + unregister_vllm_metrics() self.vllm_config = vllm_config self.engine_index = engine_index # Use this flag to hide metrics that were deprecated in @@ -169,11 +172,13 @@ class PrometheusStatLogger(StatLoggerBase): self.gauge_scheduler_running = self._gauge_cls( name="vllm:num_requests_running", documentation="Number of requests in model execution batches.", + multiprocess_mode="mostrecent", labelnames=labelnames).labels(*labelvalues) self.gauge_scheduler_waiting = self._gauge_cls( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", + multiprocess_mode="mostrecent", labelnames=labelnames).labels(*labelvalues) # @@ -182,6 +187,7 @@ class PrometheusStatLogger(StatLoggerBase): self.gauge_gpu_cache_usage = self._gauge_cls( name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 1 means 100 percent usage.", + multiprocess_mode="mostrecent", labelnames=labelnames).labels(*labelvalues) self.counter_gpu_prefix_cache_queries = self._counter_cls( @@ -242,6 +248,9 @@ class PrometheusStatLogger(StatLoggerBase): buckets=build_1_2_5_buckets(max_model_len), labelnames=labelnames).labels(*labelvalues) + # TODO: This metric might be incorrect in case of using multiple + # api_server counts which uses prometheus mp. + # See: https://github.com/vllm-project/vllm/pull/18053 self.histogram_iteration_tokens = \ self._histogram_cls( name="vllm:iteration_tokens_total", @@ -340,6 +349,9 @@ class PrometheusStatLogger(StatLoggerBase): # # LoRA metrics # + + # TODO: This metric might be incorrect in case of using multiple + # api_server counts which uses prometheus mp. 
self.gauge_lora_info: Optional[prometheus_client.Gauge] = None if vllm_config.lora_config is not None: self.labelname_max_lora = "max_lora" @@ -350,13 +362,16 @@ class PrometheusStatLogger(StatLoggerBase): self._gauge_cls( name="vllm:lora_requests_info", documentation="Running stats on lora requests.", + multiprocess_mode="sum", labelnames=[ self.labelname_max_lora, self.labelname_waiting_lora_adapters, self.labelname_running_lora_adapters, - ]) + ], + ) def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): + metrics_info = config_obj.metrics_info() metrics_info["engine"] = self.engine_index @@ -372,25 +387,28 @@ class PrometheusStatLogger(StatLoggerBase): info_gauge = self._gauge_cls( name=name, documentation=documentation, - labelnames=metrics_info.keys()).labels(**metrics_info) + multiprocess_mode="mostrecent", + labelnames=metrics_info.keys(), + ).labels(**metrics_info) info_gauge.set(1) - def record(self, scheduler_stats: SchedulerStats, + def record(self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats]): """Log to prometheus.""" - self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) - self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) + if scheduler_stats is not None: + self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) + self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) - self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage) + self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage) - self.counter_gpu_prefix_cache_queries.inc( - scheduler_stats.prefix_cache_stats.queries) - self.counter_gpu_prefix_cache_hits.inc( - scheduler_stats.prefix_cache_stats.hits) + self.counter_gpu_prefix_cache_queries.inc( + scheduler_stats.prefix_cache_stats.queries) + self.counter_gpu_prefix_cache_hits.inc( + scheduler_stats.prefix_cache_stats.hits) - if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_prom.observe( - scheduler_stats.spec_decoding_stats) + if scheduler_stats.spec_decoding_stats is not None: + self.spec_decoding_prom.observe( + scheduler_stats.spec_decoding_stats) if iteration_stats is None: return @@ -445,13 +463,6 @@ class PrometheusStatLogger(StatLoggerBase): self.gauge_lora_info.labels(**lora_info_labels)\ .set_to_current_time() - @staticmethod - def _unregister_vllm_metrics(): - # Unregister any existing vLLM collectors (for CI/CD - for collector in list(prometheus_client.REGISTRY._collector_to_names): - if hasattr(collector, "_name") and "vllm" in collector._name: - prometheus_client.REGISTRY.unregister(collector) - def log_engine_initialized(self): self.log_metrics_info("cache_config", self.vllm_config.cache_config) diff --git a/vllm/v1/metrics/prometheus.py b/vllm/v1/metrics/prometheus.py new file mode 100644 index 0000000000000..f125685353614 --- /dev/null +++ b/vllm/v1/metrics/prometheus.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import tempfile +from typing import Optional + +from prometheus_client import REGISTRY, CollectorRegistry, multiprocess + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# Global temporary directory for prometheus multiprocessing +_prometheus_multiproc_dir: Optional[tempfile.TemporaryDirectory] = None + + +def setup_multiprocess_prometheus(): + """Set up prometheus multiprocessing directory if not already configured. 
+ + """ + global _prometheus_multiproc_dir + + if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: + # Make TemporaryDirectory for prometheus multiprocessing + # Note: global TemporaryDirectory will be automatically + # cleaned up upon exit. + _prometheus_multiproc_dir = tempfile.TemporaryDirectory() + os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name + logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s", + _prometheus_multiproc_dir.name) + else: + logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. " + "This directory must be wiped between vLLM runs or " + "you will find inaccurate metrics. Unset the variable " + "and vLLM will properly handle cleanup.") + + +def get_prometheus_registry(): + """Get the appropriate prometheus registry based on multiprocessing + configuration. + + Returns: + Registry: A prometheus registry + """ + if os.getenv("PROMETHEUS_MULTIPROC_DIR") is not None: + logger.debug("Using multiprocess registry for prometheus metrics") + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + return registry + + return REGISTRY + + +def unregister_vllm_metrics(): + """Unregister any existing vLLM collectors from the prometheus registry. + + This is useful for testing and CI/CD where metrics may be registered + multiple times across test runs. + + Also, in case of multiprocess, we need to unregister the metrics from the + global registry. + """ + registry = REGISTRY + # Unregister any existing vLLM collectors + for collector in list(registry._collector_to_names): + if hasattr(collector, "_name") and "vllm" in collector._name: + registry.unregister(collector) + + +def shutdown_prometheus(): + """Shutdown prometheus metrics.""" + try: + pid = os.getpid() + multiprocess.mark_process_dead(pid) + logger.debug("Marked Prometheus metrics for process %d as dead", pid) + except Exception as e: + logger.error("Error during metrics cleanup: %s", str(e)) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index b4c84507532a1..42c75ef964016 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -26,12 +26,13 @@ class Request: multi_modal_placeholders: Optional[list[PlaceholderRange]], sampling_params: SamplingParams, eos_token_id: Optional[int], - arrival_time: float, + client_index: int = 0, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, cache_salt: Optional[str] = None, ) -> None: self.request_id = request_id + self.client_index = client_index self.sampling_params = sampling_params # Because of LoRA, the eos token id can be different for each request. 
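As context for the new prometheus.py helpers above: once PROMETHEUS_MULTIPROC_DIR is set, each API server process writes its samples to files in that directory, and a scrape merges them through MultiProcessCollector. A minimal sketch of how such a registry can be built and rendered; the metric name and endpoint wiring here are illustrative, not vLLM's actual server code:

    # Minimal prometheus multiprocess sketch; the metric name is illustrative.
    import os
    import tempfile

    # Must be set before prometheus_client is imported, so metric values are
    # written to per-process files rather than in-process memory.
    os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", tempfile.mkdtemp())

    from prometheus_client import (CollectorRegistry, Counter, generate_latest,
                                   multiprocess)

    # A worker process records into a metric as usual.
    requests_total = Counter("demo_requests_total", "Requests handled")
    requests_total.inc()

    # At scrape time, build a fresh registry that merges all per-process files.
    registry = CollectorRegistry()
    multiprocess.MultiProcessCollector(registry)
    print(generate_latest(registry).decode())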
self.eos_token_id = eos_token_id @@ -90,13 +91,13 @@ class Request: return cls( request_id=request.request_id, + client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, multi_modal_inputs=request.mm_inputs, multi_modal_hashes=request.mm_hashes, multi_modal_placeholders=request.mm_placeholders, sampling_params=request.sampling_params, eos_token_id=request.eos_token_id, - arrival_time=request.arrival_time, lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( sampling_params=request.sampling_params), diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 0758747a83cc6..a26794561a526 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,31 +1,41 @@ # SPDX-License-Identifier: Apache-2.0 -import os +import argparse +import multiprocessing import time import weakref from collections import defaultdict from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum, auto from multiprocessing import Process, connection -from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union, - overload) +from multiprocessing.process import BaseProcess +from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, + Union, overload) +import msgspec import torch +import zmq -from vllm.config import VllmConfig +from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.model_executor.models.utils import extract_layer_index from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import get_mp_context, kill_process_tree +from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path, + get_tcp_uri, kill_process_tree) from vllm.v1.executor.abstract import Executor if TYPE_CHECKING: from vllm.attention.layer import Attention + from vllm.v1.engine.coordinator import DPCoordinator logger = init_logger(__name__) T = TypeVar("T") +STARTUP_POLL_PERIOD_MS = 10000 + class ConstantList(Generic[T], Sequence): @@ -95,6 +105,78 @@ class ConstantList(Generic[T], Sequence): return f"ConstantList({self._x})" +def get_engine_client_zmq_addr(local_only: bool, + host: str, + port: int = 0) -> str: + return get_open_zmq_ipc_path() if local_only else (get_tcp_uri( + host, port or get_open_port())) + + +class APIServerProcessManager: + """Manages a group of API server processes. + + Handles creation, monitoring, and termination of API server worker + processes. Also monitors extra processes to check if they are healthy. + """ + + def __init__( + self, + target_server_fn: Callable, + listen_address: str, + sock: Any, + args: argparse.Namespace, + num_servers: int, + input_addresses: list[str], + output_addresses: list[str], + stats_update_address: Optional[str] = None, + ): + """Initialize and start API server worker processes. 
+ + Args: + target_server_fn: Function to call for each API server process + listen_address: Address to listen for client connections + sock: Socket for client connections + args: Command line arguments + num_servers: Number of API server processes to start + input_addresses: Input addresses for each API server + output_addresses: Output addresses for each API server + stats_update_address: Optional stats update address + """ + self.listen_address = listen_address + self.sock = sock + self.args = args + + # Start API servers + spawn_context = multiprocessing.get_context("spawn") + self.processes: list[BaseProcess] = [] + + for i, in_addr, out_addr in zip(range(num_servers), input_addresses, + output_addresses): + client_config = { + "input_address": in_addr, + "output_address": out_addr, + "client_index": i + } + if stats_update_address is not None: + client_config["stats_update_address"] = stats_update_address + + proc = spawn_context.Process(target=target_server_fn, + name=f"ApiServer_{i}", + args=(listen_address, sock, args, + client_config)) + self.processes.append(proc) + proc.start() + + logger.info("Started %d API server processes", len(self.processes)) + + # Shutdown only the API server processes on garbage collection + # The extra processes are managed by their owners + self._finalizer = weakref.finalize(self, shutdown, self.processes) + + def close(self) -> None: + self._finalizer() + + class CoreEngineProcManager: """ Utility class to handle creation, readiness, and shutdown @@ -109,7 +191,7 @@ class CoreEngineProcManager: local_start_index: int, vllm_config: VllmConfig, on_head_node: bool, - input_address: str, + handshake_address: str, executor_class: type[Executor], log_stats: bool, ): @@ -117,12 +199,12 @@ class CoreEngineProcManager: common_kwargs = { "vllm_config": vllm_config, "on_head_node": on_head_node, - "input_address": input_address, + "handshake_address": handshake_address, "executor_class": executor_class, "log_stats": log_stats, } - self.processes: list[Process] = [] + self.processes: list[BaseProcess] = [] for index in range(local_engine_count): local_index = local_start_index + index global_index = start_index + index @@ -135,8 +217,7 @@ class CoreEngineProcManager: "local_dp_rank": local_index, })) - self._finalizer = weakref.finalize(self, shutdown, self.processes, - input_address) + self._finalizer = weakref.finalize(self, shutdown, self.processes) try: for proc in self.processes: proc.start() @@ -164,9 +245,199 @@ class CoreEngineProcManager: } +class CoreEngineState(Enum): + NEW = auto() + CONNECTED = auto() + READY = auto() + + +class CoreEngine: + """One per data parallel rank.""" + + def __init__(self, index: int = 0, local: bool = True): + self.local = local + self.index = index + self.identity = index.to_bytes(2, "little") + + self.state = CoreEngineState.NEW + + +@dataclass +class EngineZmqAddresses: + # ZMQ input socket addresses for each front-end client (requests) + inputs: list[str] + # ZMQ output socket addresses for each front-end client (responses) + outputs: list[str] + # ZMQ input socket address of DP coordinator if applicable + coordinator_input: Optional[str] = None + # ZMQ output socket address of DP coordinator if applicable + coordinator_output: Optional[str] = None + + +@dataclass +class EngineHandshakeMetadata: + """Metadata sent to each engine process during startup handshake, + including addresses of the front-end ZMQ queues that they should + connect to. 
+ """ + addresses: EngineZmqAddresses + parallel_config: dict[str, Union[int, str]] + + +def wait_for_engine_startup( + handshake_socket: zmq.Socket, + addresses: EngineZmqAddresses, + core_engines: list[CoreEngine], + parallel_config: ParallelConfig, + cache_config: CacheConfig, + proc_manager: Optional[CoreEngineProcManager], + coord_process: Optional[Process], +): + + # Wait for engine core process(es) to send ready messages. + local_count = parallel_config.data_parallel_size_local + remote_count = len(core_engines) - local_count + # [local, remote] counts + conn_pending, start_pending = [local_count, remote_count], [0, 0] + poller = zmq.Poller() + poller.register(handshake_socket, zmq.POLLIN) + + if proc_manager is not None: + for sentinel in proc_manager.sentinels(): + poller.register(sentinel, zmq.POLLIN) + if coord_process is not None: + poller.register(coord_process.sentinel, zmq.POLLIN) + while any(conn_pending) or any(start_pending): + events = poller.poll(STARTUP_POLL_PERIOD_MS) + if not events: + if any(conn_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to connect.", *conn_pending) + if any(start_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to start.", *start_pending) + continue + if len(events) > 1 or events[0][0] != handshake_socket: + # One of the local core processes exited. + finished = proc_manager.finished_procs() if proc_manager else {} + if coord_process is not None and coord_process.exitcode is not None: + finished[coord_process.name] = coord_process.exitcode + raise RuntimeError("Engine core initialization failed. " + "See root cause above. " + f"Failed core proc(s): {finished}") + + # Receive HELLO and READY messages from the input socket. + eng_identity, ready_msg_bytes = handshake_socket.recv_multipart() + eng_index = int.from_bytes(eng_identity, "little") + engine = next((e for e in core_engines if e.identity == eng_identity), + None) + if engine is None: + raise RuntimeError(f"Message from engine with unexpected data " + f"parallel rank: {eng_index}") + msg = msgspec.msgpack.decode(ready_msg_bytes) + status, local = msg["status"], msg["local"] + if local != engine.local: + raise RuntimeError(f"{status} message from " + f"{'local' if local else 'remote'} " + f"engine {eng_index}, expected it to be " + f"{'local' if engine.local else 'remote'}") + + if status == "HELLO" and engine.state == CoreEngineState.NEW: + + # Send init message with DP config info. + init_message = msgspec.msgpack.encode( + EngineHandshakeMetadata( + addresses=addresses, + parallel_config={ + "data_parallel_master_ip": + parallel_config.data_parallel_master_ip, + "data_parallel_master_port": + parallel_config.data_parallel_master_port, + "data_parallel_size": + parallel_config.data_parallel_size, + })) + handshake_socket.send_multipart((eng_identity, init_message), + copy=False) + conn_pending[0 if local else 1] -= 1 + start_pending[0 if local else 1] += 1 + engine.state = CoreEngineState.CONNECTED + elif status == "READY" and (engine.state == CoreEngineState.CONNECTED): + # Setup KV cache config with initialization state from + # engine core process. Sum values from all engines in DP case. 
+ num_gpu_blocks = cache_config.num_gpu_blocks or 0 + num_gpu_blocks += msg["num_gpu_blocks"] + cache_config.num_gpu_blocks = num_gpu_blocks + + start_pending[0 if local else 1] -= 1 + engine.state = CoreEngineState.READY + else: + raise RuntimeError(f"Unexpected {status} message for " + f"{'local' if local else 'remote'} engine " + f"{eng_index} in {engine.state} state.") + + logger.debug("%s from %s core engine process %s.", status, + "local" if local else "remote", eng_index) + + +def wait_for_completion_or_failure( + api_server_manager: APIServerProcessManager, + local_engine_manager: Optional[CoreEngineProcManager] = None, + coordinator: Optional["DPCoordinator"] = None) -> None: + """Wait for all processes to complete or detect if any fail. + + Raises an exception if any process exits with a non-zero status. + """ + + try: + logger.info("Waiting for API servers to complete ...") + # Create a mapping of sentinels to their corresponding processes + # for efficient lookup + sentinel_to_proc: dict[Any, BaseProcess] = { + proc.sentinel: proc + for proc in api_server_manager.processes + } + + if coordinator: + sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc + + if local_engine_manager: + for proc in local_engine_manager.processes: + sentinel_to_proc[proc.sentinel] = proc + + # Check if any process terminates + while sentinel_to_proc: + # Wait for any process to terminate + ready_sentinels: list[Any] = connection.wait(sentinel_to_proc) + + # Process any terminated processes + for sentinel in ready_sentinels: + proc = sentinel_to_proc.pop(sentinel) + + # Check if process exited with error + if proc.exitcode != 0: + raise RuntimeError( + f"Process {proc.name} (PID: {proc.pid}) " + f"died with exit code {proc.exitcode}") + except KeyboardInterrupt: + logger.info("Received KeyboardInterrupt, shutting down API servers...") + except Exception as e: + logger.exception("Exception occurred while running API servers: %s", + str(e)) + raise + finally: + logger.info("Terminating remaining processes ...") + api_server_manager.close() + if coordinator: + coordinator.close() + if local_engine_manager: + local_engine_manager.close() + + # Note(rob): shutdown function cannot be a bound method, -# else the gc cannot collect the objedecoupct. -def shutdown(procs: list[Process], input_address: str): +# else the gc cannot collect the object. +def shutdown(procs: list[BaseProcess]): # Shutdown the process. for proc in procs: if proc.is_alive(): @@ -185,12 +456,6 @@ def shutdown(procs: list[Process], input_address: str): if proc.is_alive() and (pid := proc.pid) is not None: kill_process_tree(pid) - # Remove zmq ipc socket files. - if input_address.startswith("ipc://"): - socket_file = input_address[len("ipc://"):] - if os and os.path.exists(socket_file): - os.remove(socket_file) - def bind_kv_cache( kv_caches: dict[str, torch.Tensor],