mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 15:44:57 +08:00)

[Perf] API-server scaleout with many-to-many server-engine comms (#17546)

This commit is contained in:
parent 84ec470fca
commit 2dbe8c0774
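For orientation, a minimal launch sketch of the scaled-out topology this commit enables, using the new --api-server-count flag together with data parallelism (model name and option values are borrowed from the tests added below; this is an illustrative example, not an authoritative invocation):

    # 4 API server processes fronting 2 data-parallel engine ranks on one head node
    vllm serve ibm-research/PowerMoE-3b --api-server-count 4 --data_parallel_size 2 --enforce-eager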
@ -618,9 +618,11 @@ steps:
- vllm/worker/model_runner.py
- entrypoints/llm/test_collective_rpc.py
- tests/v1/test_async_llm_dp.py
- tests/v1/entrypoints/openai/test_multi_api_servers.py
- vllm/v1/engine/
commands:
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
268 tests/entrypoints/test_api_server_process_manager.py Normal file
@ -0,0 +1,268 @@
# SPDX-License-Identifier: Apache-2.0

import multiprocessing
import socket
import threading
import time
from typing import Optional
from unittest.mock import patch

import pytest

from vllm.v1.utils import (APIServerProcessManager,
wait_for_completion_or_failure)

# Global variables to control worker behavior
WORKER_RUNTIME_SECONDS = 0.5


# Mock implementation of run_api_server_worker
def mock_run_api_server_worker(listen_address, sock, args, client_config=None):
"""Mock run_api_server_worker that runs for a specific time."""
print(f"Mock worker started with client_config: {client_config}")
time.sleep(WORKER_RUNTIME_SECONDS)
print("Mock worker completed successfully")


@pytest.fixture
def api_server_args():
"""Fixture to provide arguments for APIServerProcessManager."""
sock = socket.socket()
return {
"target_server_fn":
mock_run_api_server_worker,
"listen_address":
"localhost:8000",
"sock":
sock,
"args":
"test_args", # Simple string to avoid pickling issues
"num_servers":
3,
"input_addresses": [
"tcp://127.0.0.1:5001", "tcp://127.0.0.1:5002",
"tcp://127.0.0.1:5003"
],
"output_addresses": [
"tcp://127.0.0.1:6001", "tcp://127.0.0.1:6002",
"tcp://127.0.0.1:6003"
],
"stats_update_address":
"tcp://127.0.0.1:7000",
}


@pytest.mark.parametrize("with_stats_update", [True, False])
def test_api_server_process_manager_init(api_server_args, with_stats_update):
"""Test initializing the APIServerProcessManager."""
# Set the worker runtime to ensure tests complete in reasonable time
global WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS = 0.5

# Copy the args to avoid mutating the fixture
args = api_server_args.copy()

if not with_stats_update:
args.pop("stats_update_address")
manager = APIServerProcessManager(**args)

try:
# Verify the manager was initialized correctly
assert len(manager.processes) == 3

# Verify all processes are running
for proc in manager.processes:
assert proc.is_alive()

print("Waiting for processes to run...")
time.sleep(WORKER_RUNTIME_SECONDS / 2)

# They should still be alive at this point
for proc in manager.processes:
assert proc.is_alive()

finally:
# Always clean up the processes
print("Cleaning up processes...")
manager.close()

# Give processes time to terminate
time.sleep(0.2)

# Verify all processes were terminated
for proc in manager.processes:
assert not proc.is_alive()


@patch("vllm.entrypoints.cli.serve.run_api_server_worker",
mock_run_api_server_worker)
def test_wait_for_completion_or_failure(api_server_args):
"""Test that wait_for_completion_or_failure works with failures."""
global WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS = 1.0

# Create the manager
manager = APIServerProcessManager(**api_server_args)

try:
assert len(manager.processes) == 3

# Create a result capture for the thread
result: dict[str, Optional[Exception]] = {"exception": None}

def run_with_exception_capture():
try:
wait_for_completion_or_failure(api_server_manager=manager)
except Exception as e:
result["exception"] = e

# Start a thread to run wait_for_completion_or_failure
wait_thread = threading.Thread(target=run_with_exception_capture,
daemon=True)
wait_thread.start()

# Let all processes run for a short time
time.sleep(0.2)

# All processes should still be running
assert all(proc.is_alive() for proc in manager.processes)

# Now simulate a process failure
print("Simulating process failure...")
manager.processes[0].terminate()

# Wait for the wait_for_completion_or_failure
# to detect and handle the failure
# This should trigger it to terminate all other processes
wait_thread.join(timeout=1.0)

# The wait thread should have exited
assert not wait_thread.is_alive()

# Verify that an exception was raised with appropriate error message
assert result["exception"] is not None
assert "died with exit code" in str(result["exception"])

# All processes should now be terminated
for i, proc in enumerate(manager.processes):
assert not proc.is_alive(), f"Process {i} should not be alive"

finally:
manager.close()
time.sleep(0.2)

@pytest.mark.timeout(30)
def test_normal_completion(api_server_args):
"""Test that wait_for_completion_or_failure works in normal completion."""
global WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS = 0.1

# Create the manager
manager = APIServerProcessManager(**api_server_args)

try:
# Give processes time to terminate
# wait for processes to complete
remaining_processes = manager.processes.copy()
while remaining_processes:
for proc in remaining_processes:
if not proc.is_alive():
remaining_processes.remove(proc)
time.sleep(0.1)

# Verify all processes have terminated
for i, proc in enumerate(manager.processes):
assert not proc.is_alive(
), f"Process {i} still alive after terminate()"

# Now call wait_for_completion_or_failure
# since all processes have already
# terminated, it should return immediately
# with no error
wait_for_completion_or_failure(api_server_manager=manager)

finally:
# Clean up just in case
manager.close()
time.sleep(0.2)


@pytest.mark.timeout(30)
def test_external_process_monitoring(api_server_args):
"""Test that wait_for_completion_or_failure handles additional processes."""
global WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS = 100

# Create and start the external process
# (simulates local_engine_manager or coordinator)
spawn_context = multiprocessing.get_context("spawn")
external_proc = spawn_context.Process(target=mock_run_api_server_worker,
name="MockExternalProcess")
external_proc.start()

# Create the class to simulate a coordinator
class MockCoordinator:

def __init__(self, proc):
self.proc = proc

def close(self):
if self.proc.is_alive():
self.proc.terminate()
self.proc.join(timeout=0.5)

# Create a mock coordinator with the external process
mock_coordinator = MockCoordinator(external_proc)

# Create the API server manager
manager = APIServerProcessManager(**api_server_args)

try:
# Verify manager initialization
assert len(manager.processes) == 3

# Create a result capture for the thread
result: dict[str, Optional[Exception]] = {"exception": None}

def run_with_exception_capture():
try:
wait_for_completion_or_failure(api_server_manager=manager,
coordinator=mock_coordinator)
except Exception as e:
result["exception"] = e

# Start a thread to run wait_for_completion_or_failure
wait_thread = threading.Thread(target=run_with_exception_capture,
daemon=True)
wait_thread.start()

# Terminate the external process to trigger a failure
time.sleep(0.2)
external_proc.terminate()

# Wait for the thread to detect the failure
wait_thread.join(timeout=1.0)

# The wait thread should have completed
assert not wait_thread.is_alive(
), "wait_for_completion_or_failure thread still running"

# Verify that an exception was raised with appropriate error message
assert result["exception"] is not None, "No exception was raised"
error_message = str(result["exception"])
assert "died with exit code" in error_message, \
f"Unexpected error message: {error_message}"
assert "MockExternalProcess" in error_message, \
f"Error doesn't mention external process: {error_message}"

# Verify that all API server processes were terminated as a result
for i, proc in enumerate(manager.processes):
assert not proc.is_alive(
), f"API server process {i} was not terminated"

finally:
# Clean up
manager.close()
mock_coordinator.close()
time.sleep(0.2)
@ -28,7 +28,7 @@ from tests.models.utils import TextTextLogprobs
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.cli.serve import ServeSubcommand
from vllm.model_executor.model_loader import get_model_loader
from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer

@ -99,7 +99,8 @@ class RemoteOpenAIServer:

parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
subparsers = parser.add_subparsers(required=False, dest="subparser")
parser = ServeSubcommand().subparser_init(subparsers)
args = parser.parse_args(["--model", model, *vllm_serve_args])
self.host = str(args.host or 'localhost')
self.port = int(args.port)

@ -45,7 +45,6 @@ def make_request(request_id,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100,
arrival_time=0,
lora_request=None,
cache_salt=cache_salt,
)

@ -38,7 +38,6 @@ def make_request(request_id,
sampling_params=SamplingParams(max_tokens=17,
prompt_logprobs=prompt_logprobs),
eos_token_id=100,
arrival_time=0,
lora_request=None,
cache_salt=cache_salt,
)

@ -138,7 +138,6 @@ def create_requests(num_requests: int,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=0,
)
requests.append(request)
return requests

@ -744,7 +743,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])

# No draft or accepted tokens counted yet
assert engine_core_outputs.scheduler_stats.spec_decoding_stats is None
assert not engine_core_outputs or (
engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None)

# Schedule the speculated tokens for validation
output = scheduler.schedule()

@ -772,7 +772,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)

scheduler_stats = engine_core_outputs.scheduler_stats
scheduler_stats = engine_core_outputs[0].scheduler_stats \
if engine_core_outputs else None
if expected[0] == 0:
assert scheduler_stats.spec_decoding_stats is None
else:

@ -843,7 +844,7 @@ def _step_until_done(
# We should be in the decode phase now.
assert num_scheduled_tokens == 1
assert len(output.kv_connector_metadata.requests) == 0
ecos = scheduler.update_from_output(output, model_runner_output)
ecos = scheduler.update_from_output(output, model_runner_output)[0]
all_done = True
for eco in ecos.outputs:
if eco.finish_reason is None:

@ -88,7 +88,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
assert len(engine_core.scheduler.running) == 4

# Loop through until they are all done.
while len(engine_core.step()[0].outputs) > 0:
while (outs := engine_core.step()[0].get(0)) and outs.outputs:
pass

assert len(engine_core.scheduler.waiting) == 0

@ -163,11 +163,11 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
req0.request_id = req1.request_id = "test"
engine_core.add_request(req0)

while len(engine_core.step()[0].outputs) > 0:
while (outs := engine_core.step()[0].get(0)) and outs.outputs:
pass

engine_core.add_request(req1)
while len(engine_core.step()[0].outputs) > 0:
while (outs := engine_core.step()[0].get(0)) and outs.outputs:
pass

assert len(engine_core.scheduler.waiting) == 0

@ -207,7 +207,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 0
# Loop through until they are all done.
while len(engine_core.step()[0].outputs) > 0:
while (outs := engine_core.step()[0].get(0)) and outs.outputs:
pass
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 0

@ -327,7 +327,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert scheduler_output.num_scheduled_tokens[1] == 4

# Batch queue is full. Finish Batch 2. Get first token of req0.
output = engine_core.step_with_batch_queue()[0]
output = engine_core.step_with_batch_queue()[0].get(0)
assert output is not None
assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13

@ -339,7 +339,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert scheduler_output.num_scheduled_tokens[0] == 1

# Batch queue is full. Finish Batch 3. Get first token of req1.
output = engine_core.step_with_batch_queue()[0]
output = engine_core.step_with_batch_queue()[0].get(0)
assert output is not None
assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13

@ -362,7 +362,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
if step % 2 == 0:
# Even steps consumes an output.
assert output is not None
assert len(output.outputs) == 1
assert len(output[0].outputs) == 1
if req_id in engine_core.scheduler.requests:
assert engine_core.scheduler.requests[
req_id].num_tokens == expected_num_tokens[req_id]
171 tests/v1/entrypoints/openai/test_multi_api_servers.py Normal file
@ -0,0 +1,171 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio

from tests.utils import RemoteOpenAIServer

MODEL_NAME = "ibm-research/PowerMoE-3b"

DP_SIZE = os.getenv("DP_SIZE", "1")


@pytest.fixture(scope="module")
def default_server_args():
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",
"128",
"--enforce-eager",
"--api-server-count",
"4",
"--data_parallel_size",
DP_SIZE,
]


@pytest.fixture(scope="module")
def server(default_server_args):
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server


@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_single_completion(client: openai.AsyncOpenAI,
model_name: str) -> None:

async def make_request():
completion = await client.completions.create(
model=model_name,
prompt="Hello, my name is",
max_tokens=10,
temperature=1.0)

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1

choice = completion.choices[0]
# The exact number of tokens can vary slightly with temperature=1.0,
# so we check for a reasonable minimum length.
assert len(choice.text) >= 1
# Finish reason might not always be 'length' if the model finishes early
# or due to other reasons, especially with high temperature.
# So, we'll accept 'length' or 'stop'.
assert choice.finish_reason in ("length", "stop")

# Token counts can also vary, so we check they are positive.
assert completion.usage.completion_tokens > 0
assert completion.usage.prompt_tokens > 0
assert completion.usage.total_tokens > 0
return completion

# Test single request
result = await make_request()
assert result is not None

await asyncio.sleep(0.5)

# Send two bursts of requests
num_requests = 100
tasks = [make_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)
assert len(results) == num_requests
assert all(completion is not None for completion in results)

await asyncio.sleep(0.5)

tasks = [make_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)
assert len(results) == num_requests
assert all(completion is not None for completion in results)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
model_name: str) -> None:
prompt = "What is an LLM?"

async def make_streaming_request():
# Perform a non-streaming request to get the expected full output
single_completion = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
)
single_output = single_completion.choices[0].text

# Perform the streaming request
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True)
chunks: list[str] = []
finish_reason_count = 0
last_chunk = None
async for chunk in stream:
chunks.append(chunk.choices[0].text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
last_chunk = chunk  # Keep track of the last chunk

# finish reason should only return in the last block for OpenAI API
assert finish_reason_count == 1, (
"Finish reason should appear exactly once.")
assert last_chunk is not None, (
"Stream should have yielded at least one chunk.")
assert last_chunk.choices[
0].finish_reason == "length", "Finish reason should be 'length'."
# Check that the combined text matches the non-streamed version.
assert "".join(
chunks
) == single_output, "Streamed output should match non-streamed output."
return True  # Indicate success for this request

# Test single request
result = await make_streaming_request()
assert result is not None

await asyncio.sleep(0.5)

# Send two bursts of requests
num_requests = 100
tasks = [make_streaming_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)

assert len(
results
) == num_requests, f"Expected {num_requests} results, got {len(results)}"
assert all(results), "Not all streaming requests completed successfully."

await asyncio.sleep(0.5)

tasks = [make_streaming_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)

assert len(
results
) == num_requests, f"Expected {num_requests} results, got {len(results)}"
assert all(results), "Not all streaming requests completed successfully."
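For reference, the CI entry added at the top of this commit appears to drive this new test with two data-parallel engine ranks behind the multi-server front-end, roughly:

    DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py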
@ -43,7 +43,7 @@ def test_basic_lifecycle():
# Ensure the request is finished after 1 tokens.
assert request.is_finished()
assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
output = engine_core_outputs.outputs[0]
output = engine_core_outputs[0].outputs[0]
assert output.finish_reason == FinishReason.LENGTH
assert output.kv_transfer_params is not None

@ -165,7 +165,7 @@ def test_prefix_cache_lifecycle():
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote])
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
kv_transfer_params = eco.outputs[0].kv_transfer_params
kv_transfer_params = eco[0].outputs[0].kv_transfer_params

# Ensure we send all block ids, even if there is a cache hit.
assert (len(

@ -61,7 +61,7 @@ def test_basic_lifecycle():
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
assert len(engine_core_outputs.outputs) == 0
assert not engine_core_outputs or not engine_core_outputs[0].outputs

# STEP (2):
# (2a): schedule(): nothing happens!

@ -112,7 +112,7 @@ def test_basic_lifecycle():
model_runner_output)
scheduler.schedule()

outputs = engine_core_outputs.outputs
outputs = engine_core_outputs[0].outputs
assert len(outputs) == 1
output = outputs[0]
assert output.finish_reason == FinishReason.STOP

@ -335,7 +335,7 @@ def test_full_block_prompt():
model_runner_output)
scheduler.schedule()

outputs = engine_core_outputs.outputs
outputs = engine_core_outputs[0].outputs
assert len(outputs) == 1
output = outputs[0]
assert output.finish_reason == FinishReason.STOP

@ -153,7 +153,6 @@ def create_request(
multi_modal_placeholders=None,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=0,
)
req.kv_transfer_params = kv_transfer_params
return req
@ -1,24 +1,35 @@
# SPDX-License-Identifier: Apache-2.0

import argparse
import os
import signal
import sys

import uvloop
import zmq

import vllm.envs as envs
from vllm import AsyncEngineArgs
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.api_server import (run_server, run_server_worker,
setup_server)
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
from vllm.entrypoints.utils import (VLLM_SERVE_PARSER_EPILOG,
show_filtered_argument_or_group_from_help)
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_tcp_uri
from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.core_client import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
EngineZmqAddresses, get_engine_client_zmq_addr,
wait_for_completion_or_failure,
wait_for_engine_startup)

logger = init_logger(__name__)

@ -36,9 +47,12 @@ class ServeSubcommand(CLISubcommand):
if hasattr(args, 'model_tag') and args.model_tag is not None:
args.model = args.model_tag

if args.headless:
if args.headless or args.api_server_count < 1:
run_headless(args)
elif args.api_server_count > 1:
run_multi_api_server(args)
else:
# Single API server (this process).
uvloop.run(run_server(args))

def validate(self, args: argparse.Namespace) -> None:

@ -69,6 +83,11 @@ class ServeSubcommand(CLISubcommand):
type=int,
default=0,
help='Starting data parallel rank for secondary nodes.')
serve_parser.add_argument('--api-server-count',
'-asc',
type=int,
default=1,
help='How many API server processes to run.')
serve_parser.add_argument(
"--config",
type=str,

@ -91,23 +110,26 @@ def cmd_init() -> list[CLISubcommand]:

def run_headless(args: argparse.Namespace):

if args.api_server_count > 1:
raise ValueError("api_server_count can't be set in headless mode")

# Create the EngineConfig.
engine_args = AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)

if not envs.VLLM_USE_V1:
raise RuntimeError("Headless mode is only supported for V1")
raise ValueError("Headless mode is only supported for V1")

parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip
port = engine_args.data_parallel_rpc_port # add to config too
input_address = get_tcp_uri(host, port)
handshake_address = get_tcp_uri(host, port)

if local_engine_count <= 0:
raise RuntimeError("data_parallel_size_local must be > 0 in "
"headless mode")
raise ValueError("data_parallel_size_local must be > 0 in "
"headless mode")

# Catch SIGTERM and SIGINT to allow graceful shutdown.
def signal_handler(signum, frame):

@ -119,7 +141,7 @@ def run_headless(args: argparse.Namespace):

logger.info(
"Launching %d data parallel engine(s) in headless mode, "
"with head node address %s.", local_engine_count, input_address)
"with head node address %s.", local_engine_count, handshake_address)

# Create the engines.
engine_manager = CoreEngineProcManager(

@ -129,7 +151,7 @@ def run_headless(args: argparse.Namespace):
local_start_index=0,
vllm_config=vllm_config,
on_head_node=False,
input_address=input_address,
handshake_address=handshake_address,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
)

@ -139,3 +161,142 @@ def run_headless(args: argparse.Namespace):
finally:
logger.info("Shutting down.")
engine_manager.close()


def run_multi_api_server(args: argparse.Namespace):

assert not args.headless
num_api_servers = args.api_server_count
assert num_api_servers > 0

if num_api_servers > 1:
setup_multiprocess_prometheus()

listen_address, sock = setup_server(args)

engine_args = AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
model_config = vllm_config.model_config

if num_api_servers > 1:
if not envs.VLLM_USE_V1:
raise ValueError("api_server_count > 1 is only supported for V1")

if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
"with api_server_count > 1")

if model_config.is_multimodal_model and not (
model_config.disable_mm_preprocessor_cache):
logger.warning(
"Multi-model preprocessor cache will be disabled for"
|
||||
" api_server_count > 1")
|
||||
model_config.disable_mm_preprocessor_cache = True
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
assert parallel_config.data_parallel_rank == 0
|
||||
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
local_engine_count = parallel_config.data_parallel_size_local
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
local_only = local_engine_count == dp_size
|
||||
|
||||
# Set up input and output addresses.
|
||||
input_addresses = [
|
||||
get_engine_client_zmq_addr(local_only, host)
|
||||
for _ in range(num_api_servers)
|
||||
]
|
||||
output_addresses = [
|
||||
get_engine_client_zmq_addr(local_only, host)
|
||||
for _ in range(num_api_servers)
|
||||
]
|
||||
|
||||
addresses = EngineZmqAddresses(
|
||||
inputs=input_addresses,
|
||||
outputs=output_addresses,
|
||||
)
|
||||
|
||||
# Set up coordinator for dp > 1.
|
||||
coordinator = None
|
||||
stats_update_address = None
|
||||
if dp_size > 1:
|
||||
coordinator = DPCoordinator(parallel_config)
|
||||
addresses.coordinator_input, addresses.coordinator_output = (
|
||||
coordinator.get_engine_socket_addresses())
|
||||
stats_update_address = coordinator.get_stats_publish_address()
|
||||
logger.info("Started DP Coordinator process (PID: %d)",
|
||||
coordinator.proc.pid)
|
||||
|
||||
handshake_address = get_engine_client_zmq_addr(
|
||||
local_only, host, parallel_config.data_parallel_rpc_port)
|
||||
|
||||
with zmq_socket_ctx(handshake_address, zmq.ROUTER,
|
||||
bind=True) as handshake_socket:
|
||||
|
||||
# Start local engines.
|
||||
if not local_engine_count:
|
||||
local_engine_manager = None
|
||||
else:
|
||||
local_engine_manager = CoreEngineProcManager(
|
||||
EngineCoreProc.run_engine_core,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
handshake_address=handshake_address,
|
||||
on_head_node=True,
|
||||
local_engine_count=local_engine_count,
|
||||
start_index=0,
|
||||
local_start_index=0)
|
||||
|
||||
# Start API servers using the manager
|
||||
api_server_manager = APIServerProcessManager(
|
||||
target_server_fn=run_api_server_worker_proc,
|
||||
listen_address=listen_address,
|
||||
sock=sock,
|
||||
args=args,
|
||||
num_servers=num_api_servers,
|
||||
input_addresses=input_addresses,
|
||||
output_addresses=output_addresses,
|
||||
stats_update_address=stats_update_address)
|
||||
|
||||
# Wait for engine handshakes to complete.
|
||||
core_engines = [
|
||||
CoreEngine(index=i, local=(i < local_engine_count))
|
||||
for i in range(dp_size)
|
||||
]
|
||||
wait_for_engine_startup(
|
||||
handshake_socket,
|
||||
addresses,
|
||||
core_engines,
|
||||
parallel_config,
|
||||
vllm_config.cache_config,
|
||||
local_engine_manager,
|
||||
coordinator.proc if coordinator else None,
|
||||
)
|
||||
|
||||
# Wait for API servers
|
||||
wait_for_completion_or_failure(
|
||||
api_server_manager=api_server_manager,
|
||||
local_engine_manager=local_engine_manager,
|
||||
coordinator=coordinator)
|
||||
|
||||
|
||||
def run_api_server_worker_proc(listen_address,
|
||||
sock,
|
||||
args,
|
||||
client_config=None,
|
||||
**uvicorn_kwargs) -> None:
|
||||
"""Entrypoint for individual API server worker processes."""
|
||||
|
||||
# Add process-specific prefix to stdout and stderr.
|
||||
from multiprocessing import current_process
|
||||
process_name = current_process().name
|
||||
pid = os.getpid()
|
||||
_add_prefix(sys.stdout, process_name, pid)
|
||||
_add_prefix(sys.stderr, process_name, pid)
|
||||
|
||||
uvloop.run(
|
||||
run_server_worker(listen_address, sock, args, client_config,
|
||||
**uvicorn_kwargs))
|
||||
|
||||
@ -17,7 +17,7 @@ from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from json import JSONDecodeError
from typing import Annotated, Optional
from typing import Annotated, Any, Optional

import prometheus_client
import regex as re

@ -26,6 +26,8 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import State
from starlette.routing import Mount

@ -97,6 +99,7 @@ from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address, set_ulimit)
from vllm.v1.metrics.prometheus import get_prometheus_registry
from vllm.version import __version__ as VLLM_VERSION

TIMEOUT_KEEP_ALIVE = 5  # seconds

@ -142,14 +145,17 @@ async def lifespan(app: FastAPI):

@asynccontextmanager
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[EngineClient]:
args: Namespace,
client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]:

# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args)

async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine:
engine_args, args.disable_frontend_multiprocessing,
client_config) as engine:
yield engine


@ -157,6 +163,7 @@ async def build_async_engine_client(
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]:
"""
Create EngineClient, either:

@ -179,12 +186,16 @@ async def build_async_engine_client_from_engine_args(

from vllm.v1.engine.async_llm import AsyncLLM
async_llm: Optional[AsyncLLM] = None
client_index = client_config.pop(
"client_index") if client_config else 0
try:
async_llm = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
disable_log_requests=engine_args.disable_log_requests,
disable_log_stats=engine_args.disable_log_stats)
disable_log_stats=engine_args.disable_log_stats,
client_addresses=client_config,
client_index=client_index)

# Don't keep the dummy data in memory
await async_llm.reset_mm_cache()

@ -318,22 +329,9 @@ class PrometheusResponse(Response):


def mount_metrics(app: FastAPI):
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (REGISTRY, CollectorRegistry, make_asgi_app,
multiprocess)
from prometheus_fastapi_instrumentator import Instrumentator
"""Mount prometheus metrics to a FastAPI app."""

registry = REGISTRY

prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
if prometheus_multiproc_dir_path is not None:
logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
prometheus_multiproc_dir_path)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
registry = get_prometheus_registry()

# `response_class=PrometheusResponse` is needed to return an HTTP response
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"

@ -1256,16 +1254,10 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket:
return sock


async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
log_non_default_args(args)

if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)

def validate_api_server_args(args):
valid_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
and args.tool_call_parser not in valid_tool_parses:
and args.tool_call_parser not in valid_tool_parses:
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
f"(chose from {{ {','.join(valid_tool_parses)} }})")

@ -1276,6 +1268,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
f"invalid reasoning parser: {args.reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parses)} }})")


def setup_server(args):
"""Validate API server args, set up signal handler, create socket
ready to serve."""

logger.info("vLLM API server version %s", VLLM_VERSION)
log_non_default_args(args)

if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)

validate_api_server_args(args)

# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204

@ -1292,22 +1297,41 @@ async def run_server(args, **uvicorn_kwargs) -> None:

signal.signal(signal.SIGTERM, signal_handler)

async with build_async_engine_client(args) as engine_client:
addr, port = sock_addr
is_ssl = args.ssl_keyfile and args.ssl_certfile
host_part = f"[{addr}]" if is_valid_ipv6_address(
addr) else addr or "0.0.0.0"
listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"

return listen_address, sock


async def run_server(args, **uvicorn_kwargs) -> None:
"""Run a single-worker API server."""
listen_address, sock = setup_server(args)
await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)


async def run_server_worker(listen_address,
sock,
args,
client_config=None,
**uvicorn_kwargs) -> None:
"""Run a single API server worker."""

if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)

server_index = client_config.get("client_index", 0) if client_config else 0

async with build_async_engine_client(args, client_config) as engine_client:
app = build_app(args)

vllm_config = await engine_client.get_vllm_config()
await init_app_state(engine_client, vllm_config, app.state, args)

def _listen_addr(a: str) -> str:
if is_valid_ipv6_address(a):
return '[' + a + ']'
return a or "0.0.0.0"

is_ssl = args.ssl_keyfile and args.ssl_certfile
logger.info("Starting vLLM API server on http%s://%s:%d",
"s" if is_ssl else "", _listen_addr(sock_addr[0]),
sock_addr[1])

logger.info("Starting vLLM API server %d on %s", server_index,
listen_address)
shutdown_task = await serve_http(
app,
sock=sock,
@ -229,6 +229,11 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
self.add_adapter(lora)

def add_adapter(self, lora_request: LoRARequest) -> bool:
# Note that this method is not thread-safe. It may be invoked multiple
# times for the same adapter when using multiple API servers.
# This is ok because it's currently only called from
# the single-threaded core engine loop.

if lora_request.lora_int_id not in self.list_adapters():
# Load the new adapter first to ensure it is actually valid, before
# evicting any existing adapters.

@ -2420,6 +2420,7 @@ def make_zmq_socket(
socket_type: Any,
bind: Optional[bool] = None,
identity: Optional[bytes] = None,
linger: Optional[int] = None,
) -> Union[zmq.Socket, zmq.asyncio.Socket]:  # type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics."""

@ -2439,7 +2440,7 @@ def make_zmq_socket(
buf_size = -1  # Use system default buffer size

if bind is None:
bind = socket_type != zmq.PUSH
bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB)

if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
socket.setsockopt(zmq.RCVHWM, 0)

@ -2452,6 +2453,9 @@ def make_zmq_socket(
if identity is not None:
socket.setsockopt(zmq.IDENTITY, identity)

if linger is not None:
socket.setsockopt(zmq.LINGER, linger)

# Determine if the path is a TCP socket with an IPv6 address.
# Enable IPv6 on the zmq socket if so.
scheme, host, _ = split_zmq_path(path)
@ -45,7 +45,7 @@ class SchedulerInterface(ABC):
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
) -> "EngineCoreOutputs":
) -> dict[int, "EngineCoreOutputs"]:
"""Update the scheduler state based on the model runner output.

This method is called after the model runner has processed the scheduled

@ -55,7 +55,8 @@ class SchedulerInterface(ABC):
for each request.

Returns:
A EngineCoreOutputs object containing the outputs for each request.
A dict of client index to EngineCoreOutputs object containing the
outputs for each request originating from that client.
"""
raise NotImplementedError

@ -126,6 +127,11 @@ class SchedulerInterface(ABC):
"""
raise NotImplementedError

@abstractmethod
def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
raise NotImplementedError

@abstractmethod
def make_stats(self) -> Optional["SchedulerStats"]:
"""Make a SchedulerStats object for logging.

@ -58,7 +58,8 @@ class Scheduler(SchedulerInterface):
# request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine
# case to track request lifetimes efficiently.
self.include_finished_set = include_finished_set
self.finished_req_ids_dict: Optional[dict[int, set[str]]] = (
defaultdict(set) if include_finished_set else None)

# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs

@ -693,7 +694,7 @@ class Scheduler(SchedulerInterface):
self,
scheduler_output: SchedulerOutput,
model_runner_output: ModelRunnerOutput,
) -> EngineCoreOutputs:
) -> dict[int, EngineCoreOutputs]:
sampled_token_ids = model_runner_output.sampled_token_ids
spec_token_ids = model_runner_output.spec_token_ids
logprobs = model_runner_output.logprobs

@ -701,7 +702,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens

new_running: list[Request] = []
outputs: list[EngineCoreOutput] = []
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: Optional[SpecDecodingStats] = None

# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below

@ -797,7 +798,7 @@ class Scheduler(SchedulerInterface):
if new_token_ids or kv_transfer_params:

# Add EngineCoreOutput for this Request.
outputs.append(
outputs[request.client_index].append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,

@ -828,17 +829,38 @@ class Scheduler(SchedulerInterface):
self._cached_reqs_data[req_data.req_id].append(req_data)

self.running = new_running
engine_core_outputs = EngineCoreOutputs(
outputs=outputs,
scheduler_stats=self.make_stats(spec_decoding_stats),
)
if self.include_finished_set:
#TODO currently sending duplicates here, improve this
engine_core_outputs.finished_requests = (
scheduler_output.finished_req_ids | self.finished_req_ids)

# Create EngineCoreOutputs for all clients that have requests with
# outputs in this step.
engine_core_outputs = {
client_index: EngineCoreOutputs(outputs=outs)
for client_index, outs in outputs.items()
}

finished_req_ids = self.finished_req_ids_dict
if finished_req_ids is not None:
# Include ids of requests that finished since last outputs
# were sent.
for client_index, finished_set in finished_req_ids.items():
# Set finished request set in EngineCoreOutputs for this client.
if (eco := engine_core_outputs.get(client_index)) is not None:
eco.finished_requests = finished_set
else:
engine_core_outputs[client_index] = EngineCoreOutputs(
finished_requests=finished_set)
finished_req_ids.clear()

if engine_core_outputs:
# Return stats to only one of the front-ends.
next(iter(engine_core_outputs.values())).scheduler_stats = (
self.make_stats(spec_decoding_stats))

return engine_core_outputs

def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
return len(self.running), len(self.waiting)

def add_request(self, request: Request) -> None:
self.waiting.append(request)
self.requests[request.request_id] = request

@ -880,8 +902,11 @@ class Scheduler(SchedulerInterface):

delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request)
self._cached_reqs_data.pop(request.request_id, None)
self.finished_req_ids.add(request.request_id)
request_id = request.request_id
self._cached_reqs_data.pop(request_id, None)
self.finished_req_ids.add(request_id)
if self.finished_req_ids_dict is not None:
self.finished_req_ids_dict[request.client_index].add(request_id)

if not delay_free_blocks:
self._free_blocks(request)
@ -44,10 +44,6 @@ class EngineCoreRequest(
omit_defaults=True,  # type: ignore[call-arg]
gc=False):  # type: ignore[call-arg]

# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
# but this object is currently not playing well with msgspec
# due to circular imports and typing we have in data.py

request_id: str
prompt_token_ids: list[int]
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]

@ -59,6 +55,10 @@ class EngineCoreRequest(
lora_request: Optional[LoRARequest]
cache_salt: Optional[str]

# Index of the client, used to ensure outputs are sent back to the same
# client for this request when scaling out the front-end.
client_index: int = 0

# Used in DP case to indicate which wave of requests this is expected to
# belong to, to cover a race condition where the request is sent before
# a wave finished notification is received.
@ -36,6 +36,7 @@ from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
setup_default_loggers)
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats, SchedulerStats

logger = init_logger(__name__)

@ -54,6 +55,8 @@ class AsyncLLM(EngineClient):
log_requests: bool = True,
start_engine_loop: bool = True,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
) -> None:
"""
Create an AsyncLLM.

@ -124,6 +127,8 @@ class AsyncLLM(EngineClient):
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
client_addresses=client_addresses,
client_index=client_index,
)
if self.stat_loggers:
for stat_logger in self.stat_loggers[0]:

@ -145,6 +150,8 @@ class AsyncLLM(EngineClient):
stat_loggers: Optional[list[StatLoggerFactory]] = None,
disable_log_requests: bool = False,
disable_log_stats: bool = False,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
) -> "AsyncLLM":
if not envs.VLLM_USE_V1:
raise ValueError(

@ -162,6 +169,8 @@ class AsyncLLM(EngineClient):
log_requests=not disable_log_requests,
log_stats=not disable_log_stats,
usage_context=usage_context,
client_addresses=client_addresses,
client_index=client_index,
)

@classmethod

@ -195,6 +204,8 @@ class AsyncLLM(EngineClient):
def shutdown(self):
"""Shutdown, cleaning up the background proc and IPC."""

shutdown_prometheus()

if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()

@ -398,7 +409,6 @@ class AsyncLLM(EngineClient):
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
if stat_loggers:
assert outputs.scheduler_stats is not None
AsyncLLM._record_stats(
stat_loggers[outputs.engine_index],
scheduler_stats=outputs.scheduler_stats,

@ -422,7 +432,7 @@ class AsyncLLM(EngineClient):
@staticmethod
def _record_stats(
stat_loggers: list[StatLoggerBase],
scheduler_stats: SchedulerStats,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
):
"""static so that it can be used from the output_handler task
252 vllm/v1/engine/coordinator.py Normal file
@ -0,0 +1,252 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import multiprocessing
|
||||
import time
|
||||
import weakref
|
||||
from typing import Optional
|
||||
|
||||
import msgspec.msgpack
|
||||
import zmq
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket
|
||||
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
|
||||
from vllm.v1.serial_utils import MsgpackDecoder
|
||||
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DPCoordinator:
|
||||
"""Coordinator process used for data-parallel deployments (DP>1).
|
||||
|
||||
Intermediates between multiple DP engine rank processes and one or more
|
||||
front-end API server processes.
|
||||
|
||||
* Collects stats from each DP engine (currently just waiting and running
|
||||
queue lengths), and publishes these to all front-ends for use in
|
||||
load-balancing decisions.
|
||||
|
||||
* Keeps track of the current DP "request wave" number and running state
|
||||
of the engines. This is received from the DP rank 0 engine and published
|
||||
to the front-end processes along with the current load stats.
|
||||
|
||||
The engines alternate between a global running/paused state. The global
|
||||
"request wave" number is a count of the number of times that the workers
|
||||
collectively move from a running state to a paused state. This transition
|
||||
is synchronized via the all-reduce operation performed in the
|
||||
DPEngineCoreProc._has_global_unfinished_reqs method.
|
||||
|
||||
* Broadcasts the START_DP_WAVE message to engines to move them from paused
|
||||
to running state when one engine receives a new request. This can happen
|
||||
in two cases:
|
||||
1) A front-end sending a new request while the engines are paused will
|
||||
concurrently notify the coordinator.
|
||||
2) An engine receiving a request for a stale request wave while in paused
|
||||
state will notify the coordinator.
|
||||
|
||||
Engines will move into running state when receiving a new request or
|
||||
START_DP_WAVE message.
|
||||
"""
|
||||
|
||||
def __init__(self, parallel_config: ParallelConfig):
|
||||
|
||||
# Assume coordinator is colocated with front-end procs.
|
||||
front_publish_address = get_open_zmq_ipc_path()
|
||||
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
assert dp_size > 1, "Coordinator only used for data parallel"
|
||||
|
||||
local_only = dp_size == parallel_config.data_parallel_size_local
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
back_publish_address = get_engine_client_zmq_addr(local_only, host)
|
||||
back_output_address = get_engine_client_zmq_addr(local_only, host)
|
||||
|
||||
context = get_mp_context()
|
||||
self.proc: multiprocessing.Process = context.Process(
|
||||
target=CoordinatorProc.run_coordinator,
|
||||
name="VLLM_DP_Coordinator",
|
||||
kwargs={
|
||||
"engine_count": parallel_config.data_parallel_size,
|
||||
"front_publish_address": front_publish_address,
|
||||
"back_output_address": back_output_address,
|
||||
"back_publish_address": back_publish_address,
|
||||
},
|
||||
daemon=True)
|
||||
self.proc.start()
|
||||
|
||||
self.stats_publish_address = front_publish_address
|
||||
self.coord_in_address = back_publish_address
|
||||
self.coord_out_address = back_output_address
|
||||
self._finalizer = weakref.finalize(self, shutdown, [self.proc])
|
||||
|
||||
def get_stats_publish_address(self) -> str:
|
||||
return self.stats_publish_address
|
||||
|
||||
def get_engine_socket_addresses(self) -> tuple[str, str]:
|
||||
"""Returns tuple of ZMQ input address, output address."""
|
||||
return self.coord_in_address, self.coord_out_address
|
||||
|
||||
def close(self):
|
||||
self._finalizer()
|
||||
|
||||
|
||||
class EngineState:

    def __init__(self):
        self.request_counts = [0, 0]  # [waiting, running]


class CoordinatorProc:
|
||||
|
||||
def __init__(self, engine_count: int):
|
||||
|
||||
self.ctx = zmq.Context()
|
||||
|
||||
self.engines = [EngineState() for _ in range(engine_count)]
|
||||
|
||||
self.current_wave = 0
|
||||
self.engines_running = False
|
||||
self.stats_changed = False
|
||||
|
||||
@staticmethod
|
||||
def run_coordinator(
|
||||
engine_count: int,
|
||||
front_publish_address: str,
|
||||
back_output_address: str,
|
||||
back_publish_address: str,
|
||||
):
|
||||
coordinator = CoordinatorProc(engine_count=engine_count)
|
||||
try:
|
||||
coordinator.process_input_socket(
|
||||
front_publish_address,
|
||||
back_output_address,
|
||||
back_publish_address,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("DP Coordinator process exiting")
|
||||
|
||||
def process_input_socket(self, front_publish_address: str,
|
||||
back_output_address: str,
|
||||
back_publish_address: str):
|
||||
|
||||
decoder = MsgpackDecoder(EngineCoreOutputs)
|
||||
|
||||
with make_zmq_socket(
|
||||
path=front_publish_address, # IPC
|
||||
ctx=self.ctx,
|
||||
socket_type=zmq.XPUB,
|
||||
bind=True,
|
||||
) as publish_front, make_zmq_socket(
|
||||
path=back_output_address, # IPC or TCP
|
||||
ctx=self.ctx,
|
||||
socket_type=zmq.PULL,
|
||||
bind=True,
|
||||
) as output_back, make_zmq_socket(
|
||||
path=back_publish_address, # IPC or TCP
|
||||
ctx=self.ctx,
|
||||
socket_type=zmq.XPUB,
|
||||
bind=True,
|
||||
) as publish_back:
|
||||
|
||||
poller = zmq.Poller()
|
||||
poller.register(publish_front, zmq.POLLIN)
|
||||
poller.register(output_back, zmq.POLLIN)
|
||||
last_publish_time = 0
|
||||
while True:
|
||||
elapsed = int(time.time() * 1000) - last_publish_time
|
||||
# Send at 100 ms interval if the stats have changed,
|
||||
# or otherwise every 3 seconds.
|
||||
wait_for = 100 if self.stats_changed else 3000
|
||||
events = poller.poll(timeout=max(0, wait_for - elapsed))
|
||||
if not events:
|
||||
# Poller timeout - publish current stats to front-ends.
|
||||
engine_req_counts_list = self._get_engine_counts()
|
||||
to_publish = (engine_req_counts_list, self.current_wave,
|
||||
self.engines_running)
|
||||
publish_front.send(msgspec.msgpack.encode(to_publish))
|
||||
last_publish_time = int(time.time() * 1000)
|
||||
self.stats_changed = False
|
||||
continue
|
||||
|
||||
events = dict(events)
|
||||
|
||||
if publish_front in events:
|
||||
buffer = publish_front.recv()
|
||||
if buffer == b'\x01':
|
||||
# Ignore subscription messages.
|
||||
continue
|
||||
|
||||
# We received a message on the front-end XPUB socket,
|
||||
# from an API server sending a new request while the
|
||||
# engines are paused, so that we can wake the other
|
||||
# engines.
|
||||
engine_to_exclude, wave = msgspec.msgpack.decode(buffer)
|
||||
if wave < self.current_wave:
|
||||
# If the wave number is stale, ensure the message is
|
||||
# handled by all the engines.
|
||||
engine_to_exclude = None
|
||||
if not self.engines_running:
|
||||
self.engines_running = True
|
||||
self.stats_changed = True
|
||||
self._send_start_wave(publish_back, self.current_wave,
|
||||
engine_to_exclude)
|
||||
|
||||
if output_back in events:
|
||||
# We received a message from one of the engines.
|
||||
|
||||
buffer = output_back.recv()
|
||||
outputs: EngineCoreOutputs = decoder.decode(buffer)
|
||||
|
||||
assert not outputs.outputs
|
||||
assert outputs.utility_output is None
|
||||
|
||||
eng_index = outputs.engine_index
|
||||
if outputs.scheduler_stats:
|
||||
# 1. Updated request load stats - update our local
|
||||
# state with these.
|
||||
stats = self.engines[eng_index].request_counts
|
||||
stats[0] = outputs.scheduler_stats.num_waiting_reqs
|
||||
stats[1] = outputs.scheduler_stats.num_running_reqs
|
||||
self.stats_changed = True
|
||||
|
||||
if (wave := outputs.wave_complete) is not None:
|
||||
# 2. Notification from rank 0 engine that we've
|
||||
# moved into the global paused state
|
||||
# (engines_running==False)
|
||||
if self.current_wave <= wave:
|
||||
logger.debug("Moving DP wave from %d to %d.",
|
||||
self.current_wave, wave)
|
||||
self.current_wave = wave + 1
|
||||
self.engines_running = False
|
||||
self.stats_changed = True
|
||||
elif (wave := outputs.start_wave) is not None and (
|
||||
wave > self.current_wave or
|
||||
(wave == self.current_wave
|
||||
and not self.engines_running)):
|
||||
# 3. The engine received request for a non-current wave
|
||||
# so we must ensure that other engines progress to the
|
||||
# next wave (race condition handling).
|
||||
logger.debug(
|
||||
"Starting wave %d after notification of "
|
||||
"stale wave request from engine.", wave)
|
||||
self.current_wave = wave
|
||||
self.engines_running = True
|
||||
self.stats_changed = True
|
||||
self._send_start_wave(publish_back, wave, eng_index)
|
||||
|
||||
@staticmethod
|
||||
def _send_start_wave(socket: zmq.Socket, wave: int,
|
||||
exclude_engine_index: Optional[int]):
|
||||
"""Broadcast the START_DP_WAVE message to all the engines.
|
||||
It includes the current wave number and index of engine which
|
||||
has already received a request with this wave number and so doesn't
|
||||
require additional notification.
|
||||
"""
|
||||
wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index))
|
||||
socket.send_multipart(
|
||||
(EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
|
||||
|
||||
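For reference, the START_DP_WAVE broadcast built by _send_start_wave is just two frames: a request-type byte string followed by a msgpack-encoded (wave, exclude_engine_index) pair. The snippet below is an illustration of that wire format only, not part of this diff; the request-type byte used here is a stand-in for EngineCoreRequestType.START_DP_WAVE.value.

import msgspec.msgpack

# Hypothetical stand-in for EngineCoreRequestType.START_DP_WAVE.value.
START_DP_WAVE = b"\x03"

def encode_start_wave(wave, exclude_engine_index):
    # Frame layout used by _send_start_wave: (type frame, payload frame).
    return (START_DP_WAVE,
            msgspec.msgpack.encode((wave, exclude_engine_index)))

def decode_start_wave(frames):
    req_type, payload = frames
    assert req_type == START_DP_WAVE
    wave, exclude_engine_index = msgspec.msgpack.decode(payload)
    return wave, exclude_engine_index

frames = encode_start_wave(7, 2)
wave, exclude = decode_start_wave(frames)
assert (wave, exclude) == (7, 2)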
def _get_engine_counts(self) -> list[list[int]]:
|
||||
"""Return list of [waiting, running] count lists for each engine."""
|
||||
return [e.request_counts for e in self.engines]
|
||||
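The coordinator's front-end XPUB socket publishes a msgpack-encoded tuple of (per-engine [waiting, running] counts, current_wave, engines_running), at roughly 100 ms intervals when stats have changed and every 3 seconds otherwise. A minimal synchronous subscriber is sketched here purely for illustration (the address is a placeholder and error handling is omitted); it mirrors what the async front-end client in core_client.py does.

import msgspec.msgpack
import zmq

def watch_coordinator_stats(stats_address: str) -> None:
    ctx = zmq.Context()
    # XSUB pairs with the coordinator's XPUB; a b'\x01' frame subscribes
    # to everything, the same convention the async client uses.
    sock = ctx.socket(zmq.XSUB)
    sock.connect(stats_address)
    sock.send(b"\x01")
    try:
        while True:
            counts, wave, running = msgspec.msgpack.decode(sock.recv())
            print(f"wave={wave} running={running} per-engine counts={counts}")
    finally:
        sock.close(linger=0)
        ctx.term()

# Example call (address is an assumption, e.g. the value returned by
# DPCoordinator.get_stats_publish_address()):
# watch_coordinator_stats("ipc:///tmp/vllm-dp-stats.ipc")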
@ -7,6 +7,7 @@ import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from concurrent.futures import Future
|
||||
from contextlib import ExitStack
|
||||
from inspect import isclass, signature
|
||||
from logging import DEBUG
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
@ -22,7 +23,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx
|
||||
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname
|
||||
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
|
||||
unify_kv_cache_configs)
|
||||
from vllm.v1.core.sched.interface import SchedulerInterface
|
||||
@ -33,10 +34,12 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -211,7 +214,7 @@ class EngineCore:
|
||||
# Re-raise exception
|
||||
raise err
|
||||
|
||||
def step(self) -> tuple[EngineCoreOutputs, bool]:
|
||||
def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
|
||||
"""Schedule, execute, and make output.
|
||||
|
||||
Returns tuple of outputs and a flag indicating whether the model
|
||||
@ -221,10 +224,7 @@ class EngineCore:
|
||||
# Check for any requests remaining in the scheduler - unfinished,
|
||||
# or finished and not yet removed from the batch.
|
||||
if not self.scheduler.has_requests():
|
||||
return EngineCoreOutputs(
|
||||
outputs=[],
|
||||
scheduler_stats=self.scheduler.make_stats(),
|
||||
), False
|
||||
return {}, False
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
model_output = self.execute_model(scheduler_output)
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
@ -234,7 +234,7 @@ class EngineCore:
|
||||
scheduler_output.total_num_scheduled_tokens > 0)
|
||||
|
||||
def step_with_batch_queue(
|
||||
self) -> tuple[Optional[EngineCoreOutputs], bool]:
|
||||
self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
|
||||
"""Schedule and execute batches with the batch queue.
|
||||
Note that if nothing to output in this step, None is returned.
|
||||
|
||||
@ -276,8 +276,8 @@ class EngineCore:
|
||||
# Blocking until the first result is available.
|
||||
model_output = future.result()
|
||||
self.batch_queue.task_done()
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, model_output)
|
||||
engine_core_outputs = (self.scheduler.update_from_output(
|
||||
scheduler_output, model_output))
|
||||
|
||||
return engine_core_outputs, scheduled_batch
|
||||
|
||||
@ -362,7 +362,7 @@ class EngineCoreProc(EngineCore):
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
on_head_node: bool,
|
||||
input_address: str,
|
||||
handshake_address: str,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
engine_index: int = 0,
|
||||
@ -375,65 +375,70 @@ class EngineCoreProc(EngineCore):
|
||||
# Create input socket.
|
||||
input_ctx = zmq.Context()
|
||||
identity = engine_index.to_bytes(length=2, byteorder="little")
|
||||
input_socket = make_zmq_socket(input_ctx,
|
||||
input_address,
|
||||
zmq.DEALER,
|
||||
identity=identity,
|
||||
bind=False)
|
||||
try:
|
||||
with make_zmq_socket(input_ctx,
|
||||
handshake_address,
|
||||
zmq.DEALER,
|
||||
identity=identity,
|
||||
linger=5000,
|
||||
bind=False) as handshake_socket:
|
||||
|
||||
# Register engine with front-end.
|
||||
output_address = self.startup_handshake(
|
||||
input_socket, on_head_node, vllm_config.parallel_config)
|
||||
addresses = self.startup_handshake(handshake_socket, on_head_node,
|
||||
vllm_config.parallel_config)
|
||||
self.client_count = len(addresses.outputs)
|
||||
|
||||
# Update config which may have changed from the handshake.
|
||||
vllm_config.__post_init__()
|
||||
|
||||
# Set up data parallel environment.
|
||||
self.has_coordinator = addresses.coordinator_output is not None
|
||||
self._init_data_parallel(vllm_config)
|
||||
|
||||
# Initialize engine core and model.
|
||||
super().__init__(vllm_config, executor_class, log_stats,
|
||||
executor_fail_callback)
|
||||
|
||||
self.engine_index = engine_index
|
||||
self.step_fn = (self.step if self.batch_queue is None else
|
||||
self.step_with_batch_queue)
|
||||
self.engines_running = False
|
||||
self.last_counts = (0, 0)
|
||||
|
||||
# Send ready message.
|
||||
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
|
||||
input_socket.send(
|
||||
handshake_socket.send(
|
||||
msgspec.msgpack.encode({
|
||||
"status": "READY",
|
||||
"local": on_head_node,
|
||||
"num_gpu_blocks": num_gpu_blocks,
|
||||
}))
|
||||
|
||||
# Background Threads and Queues for IO. These enable us to
|
||||
# overlap ZMQ socket IO with GPU since they release the GIL,
|
||||
# and to overlap some serialization/deserialization with the
|
||||
# model forward pass.
|
||||
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
|
||||
self.input_queue = input_queue
|
||||
self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
|
||||
threading.Thread(target=self.process_input_socket,
|
||||
args=(input_socket, ),
|
||||
daemon=True).start()
|
||||
input_socket = None
|
||||
self.output_thread = threading.Thread(
|
||||
target=self.process_output_socket,
|
||||
args=(output_address, engine_index),
|
||||
daemon=True)
|
||||
self.output_thread.start()
|
||||
finally:
|
||||
if input_socket is not None:
|
||||
input_socket.close(linger=0)
|
||||
# Background Threads and Queues for IO. These enable us to
|
||||
# overlap ZMQ socket IO with GPU since they release the GIL,
|
||||
# and to overlap some serialization/deserialization with the
|
||||
# model forward pass.
|
||||
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
|
||||
self.input_queue = input_queue
|
||||
self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
|
||||
bytes]]()
|
||||
threading.Thread(target=self.process_input_sockets,
|
||||
args=(addresses.inputs, addresses.coordinator_input,
|
||||
identity),
|
||||
daemon=True).start()
|
||||
self.output_thread = threading.Thread(
|
||||
target=self.process_output_sockets,
|
||||
args=(addresses.outputs, addresses.coordinator_output,
|
||||
engine_index),
|
||||
daemon=True)
|
||||
self.output_thread.start()
|
||||
|
||||
@staticmethod
|
||||
def startup_handshake(input_socket: zmq.Socket, on_head_node: bool,
|
||||
parallel_config: ParallelConfig) -> str:
|
||||
def startup_handshake(
|
||||
handshake_socket: zmq.Socket, on_head_node: bool,
|
||||
parallel_config: ParallelConfig) -> EngineZmqAddresses:
|
||||
|
||||
# Send registration message.
|
||||
input_socket.send(
|
||||
handshake_socket.send(
|
||||
msgspec.msgpack.encode({
|
||||
"status": "HELLO",
|
||||
"local": on_head_node,
|
||||
@ -441,22 +446,20 @@ class EngineCoreProc(EngineCore):
|
||||
|
||||
# Receive initialization message.
|
||||
logger.info("Waiting for init message from front-end.")
|
||||
if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000):
|
||||
if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
|
||||
raise RuntimeError("Did not receive response from front-end "
|
||||
f"process within {HANDSHAKE_TIMEOUT_MINS} "
|
||||
f"minutes")
|
||||
init_bytes = input_socket.recv()
|
||||
init_message = msgspec.msgpack.decode(init_bytes)
|
||||
init_bytes = handshake_socket.recv()
|
||||
init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
|
||||
init_bytes, type=EngineHandshakeMetadata)
|
||||
logger.debug("Received init message: %s", init_message)
|
||||
|
||||
output_socket_address = init_message["output_socket_address"]
|
||||
#TBD(nick) maybe replace IP with configured head node address
|
||||
|
||||
received_parallel_config = init_message["parallel_config"]
|
||||
received_parallel_config = init_message.parallel_config
|
||||
for key, value in received_parallel_config.items():
|
||||
setattr(parallel_config, key, value)
|
||||
|
||||
return output_socket_address
|
||||
return init_message.addresses
|
||||
|
||||
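The handshake above is a plain DEALER/ROUTER exchange of msgpack payloads: the engine sends HELLO, waits for the front-end's init message carrying the per-client ZMQ addresses, and later reports READY with its GPU block count. A stripped-down illustration of the engine side follows; it is a sketch only, the function name is hypothetical, and real code decodes the init message into the typed EngineHandshakeMetadata shown above rather than a raw dict.

import msgspec.msgpack
import zmq

def toy_engine_handshake(handshake_address: str, engine_index: int,
                         on_head_node: bool, timeout_ms: int) -> dict:
    ctx = zmq.Context()
    sock = ctx.socket(zmq.DEALER)
    # The front-end ROUTER keys its reply off this two-byte identity.
    sock.setsockopt(zmq.IDENTITY,
                    engine_index.to_bytes(length=2, byteorder="little"))
    sock.connect(handshake_address)
    try:
        sock.send(msgspec.msgpack.encode({
            "status": "HELLO",
            "local": on_head_node,
        }))
        if not sock.poll(timeout=timeout_ms):
            raise RuntimeError("No init message from front-end")
        init_message = msgspec.msgpack.decode(sock.recv())
        # A real engine would now initialize, then send READY including
        # num_gpu_blocks, as in the code above.
        return init_message
    finally:
        sock.close(linger=0)
        ctx.term()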
@staticmethod
|
||||
def run_engine_core(*args,
|
||||
@ -528,7 +531,7 @@ class EngineCoreProc(EngineCore):
|
||||
"""Exits when an engine step needs to be performed."""
|
||||
|
||||
waited = False
|
||||
while not self.engines_running and not (self.scheduler.has_requests()):
|
||||
while not self.engines_running and not self.scheduler.has_requests():
|
||||
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
|
||||
logger.debug("EngineCore waiting for work.")
|
||||
waited = True
|
||||
@ -549,8 +552,8 @@ class EngineCoreProc(EngineCore):
|
||||
# Step the engine core.
|
||||
outputs, model_executed = self.step_fn()
|
||||
# Put EngineCoreOutputs into the output queue.
|
||||
if outputs is not None:
|
||||
self.output_queue.put_nowait(outputs)
|
||||
for output in (outputs.items() if outputs else ()):
|
||||
self.output_queue.put_nowait(output)
|
||||
|
||||
return model_executed
|
||||
|
||||
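With many API servers per engine, step() now returns a dict keyed by client index rather than a single EngineCoreOutputs, so the busy loop enqueues one (client_index, outputs) tuple per front-end and the output thread routes each to the matching per-client socket, with index -1 reserved for the coordinator. A toy illustration of that routing rule, using plain lists in place of the PUSH sockets and strings in place of EngineCoreOutputs:

from queue import Queue

# Stand-ins: real code uses per-client zmq.PUSH sockets and
# EngineCoreOutputs objects.
client_sockets = {0: [], 1: []}
coordinator_socket = []

output_queue: Queue = Queue()
output_queue.put((0, "outputs for API server 0"))
output_queue.put((-1, "scheduler stats for coordinator"))
output_queue.put((1, "outputs for API server 1"))

while not output_queue.empty():
    client_index, outputs = output_queue.get()
    if client_index == -1:
        # Small control/stats messages go to the DP coordinator.
        coordinator_socket.append(outputs)
    else:
        client_sockets[client_index].append(outputs)

assert coordinator_socket == ["scheduler stats for coordinator"]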
@ -563,7 +566,7 @@ class EngineCoreProc(EngineCore):
|
||||
elif request_type == EngineCoreRequestType.ABORT:
|
||||
self.abort_requests(request)
|
||||
elif request_type == EngineCoreRequestType.UTILITY:
|
||||
call_id, method_name, args = request
|
||||
client_idx, call_id, method_name, args = request
|
||||
output = UtilityOutput(call_id)
|
||||
try:
|
||||
method = getattr(self, method_name)
|
||||
@ -574,7 +577,7 @@ class EngineCoreProc(EngineCore):
|
||||
output.failure_message = (f"Call to {method_name} method"
|
||||
f" failed: {str(e)}")
|
||||
self.output_queue.put_nowait(
|
||||
EngineCoreOutputs(utility_output=output))
|
||||
(client_idx, EngineCoreOutputs(utility_output=output)))
|
||||
elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
|
||||
raise RuntimeError("Executor failed.")
|
||||
else:
|
||||
@ -607,27 +610,68 @@ class EngineCoreProc(EngineCore):
|
||||
logger.fatal("vLLM shutdown signal from EngineCore failed "
|
||||
"to send. Please report this issue.")
|
||||
|
||||
def process_input_socket(self, input_socket: zmq.Socket):
|
||||
def process_input_sockets(self, input_addresses: list[str],
|
||||
coord_input_address: Optional[str],
|
||||
identity: bytes):
|
||||
"""Input socket IO thread."""
|
||||
|
||||
# Msgpack serialization decoding.
|
||||
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
|
||||
generic_decoder = MsgpackDecoder()
|
||||
|
||||
while True:
|
||||
# (RequestType, RequestData)
|
||||
type_frame, *data_frames = input_socket.recv_multipart(copy=False)
|
||||
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
|
||||
with ExitStack() as stack, zmq.Context() as ctx:
|
||||
input_sockets = [
|
||||
stack.enter_context(
|
||||
make_zmq_socket(ctx,
|
||||
input_address,
|
||||
zmq.DEALER,
|
||||
identity=identity,
|
||||
bind=False))
|
||||
for input_address in input_addresses
|
||||
]
|
||||
if coord_input_address is None:
|
||||
coord_socket = None
|
||||
else:
|
||||
coord_socket = stack.enter_context(
|
||||
make_zmq_socket(ctx,
|
||||
coord_input_address,
|
||||
zmq.XSUB,
|
||||
identity=identity,
|
||||
bind=False))
|
||||
# Send subscription message to coordinator.
|
||||
coord_socket.send(b'\x01')
|
||||
|
||||
# Deserialize the request data.
|
||||
decoder = add_request_decoder if (
|
||||
request_type == EngineCoreRequestType.ADD) else generic_decoder
|
||||
request = decoder.decode(data_frames)
|
||||
# Register sockets with poller.
|
||||
poller = zmq.Poller()
|
||||
for input_socket in input_sockets:
|
||||
# Send initial message to each input socket - this is required
|
||||
# before the front-end ROUTER socket can send input messages
|
||||
# back to us.
|
||||
input_socket.send(b'')
|
||||
poller.register(input_socket, zmq.POLLIN)
|
||||
if coord_socket is not None:
|
||||
poller.register(coord_socket, zmq.POLLIN)
|
||||
|
||||
# Push to input queue for core busy loop.
|
||||
self.input_queue.put_nowait((request_type, request))
|
||||
while True:
|
||||
for input_socket, _ in poller.poll():
|
||||
# (RequestType, RequestData)
|
||||
type_frame, *data_frames = input_socket.recv_multipart(
|
||||
copy=False)
|
||||
request_type = EngineCoreRequestType(
|
||||
bytes(type_frame.buffer))
|
||||
|
||||
def process_output_socket(self, output_path: str, engine_index: int):
|
||||
# Deserialize the request data.
|
||||
decoder = add_request_decoder if (
|
||||
request_type
|
||||
== EngineCoreRequestType.ADD) else generic_decoder
|
||||
request = decoder.decode(data_frames)
|
||||
|
||||
# Push to input queue for core busy loop.
|
||||
self.input_queue.put_nowait((request_type, request))
|
||||
|
||||
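Each engine now connects one DEALER input socket per API server, all sharing the engine's two-byte identity, and sends one empty frame on each so that the front-end ROUTER learns where to route requests (the comment above calls this out). A self-contained sketch of that ROUTER/DEALER pattern, with a placeholder inproc address and raw bytes standing in for the msgpack request frames:

import zmq

ctx = zmq.Context()
addr = "inproc://toy-engine-input"   # placeholder address

router = ctx.socket(zmq.ROUTER)      # front-end API server side
router.bind(addr)

dealer = ctx.socket(zmq.DEALER)      # engine side
dealer.setsockopt(zmq.IDENTITY, (7).to_bytes(2, "little"))
dealer.connect(addr)
dealer.send(b"")                     # initial frame: tells ROUTER our identity

identity, _ = router.recv_multipart()
# Now the ROUTER can address this engine directly.
router.send_multipart((identity, b"ADD request payload"))
print(dealer.recv())                 # b'ADD request payload'

for sock in (router, dealer):
    sock.close(linger=0)
ctx.term()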
def process_output_sockets(self, output_paths: list[str],
|
||||
coord_output_path: Optional[str],
|
||||
engine_index: int):
|
||||
"""Output socket IO thread."""
|
||||
|
||||
# Msgpack serialization encoding.
|
||||
@ -641,30 +685,49 @@ class EngineCoreProc(EngineCore):
|
||||
|
||||
# We must set linger to ensure the ENGINE_CORE_DEAD
|
||||
# message is sent prior to closing the socket.
|
||||
with zmq_socket_ctx(output_path, zmq.constants.PUSH,
|
||||
linger=4000) as socket:
|
||||
with ExitStack() as stack, zmq.Context() as ctx:
|
||||
sockets = [
|
||||
stack.enter_context(
|
||||
make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000))
|
||||
for output_path in output_paths
|
||||
]
|
||||
coord_socket = stack.enter_context(
|
||||
make_zmq_socket(
|
||||
ctx, coord_output_path, zmq.PUSH, bind=False,
|
||||
linger=4000)) if coord_output_path is not None else None
|
||||
max_reuse_bufs = len(sockets) + 1
|
||||
|
||||
while True:
|
||||
outputs = self.output_queue.get()
|
||||
if outputs == EngineCoreProc.ENGINE_CORE_DEAD:
|
||||
socket.send(outputs, copy=False)
|
||||
output = self.output_queue.get()
|
||||
if output == EngineCoreProc.ENGINE_CORE_DEAD:
|
||||
for socket in sockets:
|
||||
socket.send(output)
|
||||
break
|
||||
assert not isinstance(outputs, bytes)
|
||||
assert not isinstance(output, bytes)
|
||||
client_index, outputs = output
|
||||
outputs.engine_index = engine_index
|
||||
|
||||
if client_index == -1:
|
||||
# Don't reuse buffer for coordinator message
|
||||
# which will be very small.
|
||||
assert coord_socket is not None
|
||||
coord_socket.send_multipart(encoder.encode(outputs))
|
||||
continue
|
||||
|
||||
# Reclaim buffers that zmq is finished with.
|
||||
while pending and pending[-1][0].done:
|
||||
reuse_buffers.append(pending.pop()[2])
|
||||
|
||||
buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
|
||||
buffers = encoder.encode_into(outputs, buffer)
|
||||
tracker = socket.send_multipart(buffers,
|
||||
copy=False,
|
||||
track=True)
|
||||
tracker = sockets[client_index].send_multipart(buffers,
|
||||
copy=False,
|
||||
track=True)
|
||||
if not tracker.done:
|
||||
ref = outputs if len(buffers) > 1 else None
|
||||
pending.appendleft((tracker, ref, buffer))
|
||||
elif len(reuse_buffers) < 2:
|
||||
# Keep at most 2 buffers to reuse.
|
||||
elif len(reuse_buffers) < max_reuse_bufs:
|
||||
# Limit the number of buffers to reuse.
|
||||
reuse_buffers.append(buffer)
|
||||
|
||||
|
||||
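The encode_into plus send_multipart(copy=False, track=True) pattern above lets the output thread hand serialization buffers to ZMQ without copying and reclaim them once the returned tracker reports the send is finished, instead of allocating a fresh bytearray per message. A reduced sketch of that bookkeeping with a toy payload (the real code also keeps a reference to the outputs object while multi-frame sends are in flight):

from collections import deque

import zmq

ctx = zmq.Context()
addr = "inproc://toy-outputs"
pull = ctx.socket(zmq.PULL)
pull.bind(addr)
push = ctx.socket(zmq.PUSH)
push.connect(addr)

reuse_buffers: list = []
pending: deque = deque()  # (tracker, buffer) pairs still owned by zmq

for i in range(4):
    # Reclaim buffers zmq has finished sending.
    while pending and pending[-1][0].done:
        reuse_buffers.append(pending.pop()[1])

    buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
    buffer[:] = f"outputs {i}".encode()
    tracker = push.send(buffer, copy=False, track=True)
    if not tracker.done:
        pending.appendleft((tracker, buffer))
    else:
        reuse_buffers.append(buffer)

for _ in range(4):
    print(pull.recv())

for sock in (push, pull):
    sock.close(linger=0)
ctx.term()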
@ -676,7 +739,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
on_head_node: bool,
|
||||
input_address: str,
|
||||
handshake_address: str,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
):
|
||||
@ -691,10 +754,11 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
# Counts forward-passes of the model so that we can synchronize
|
||||
# finished with DP peers every N steps.
|
||||
self.counter = 0
|
||||
self.current_wave = 0
|
||||
|
||||
# Initialize the engine.
|
||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
super().__init__(vllm_config, on_head_node, input_address,
|
||||
super().__init__(vllm_config, on_head_node, handshake_address,
|
||||
executor_class, log_stats, dp_rank)
|
||||
|
||||
def _init_data_parallel(self, vllm_config: VllmConfig):
|
||||
@ -726,7 +790,6 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
|
||||
self.dp_rank = dp_rank
|
||||
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
|
||||
self.current_wave = 0
|
||||
|
||||
def shutdown(self):
|
||||
super().shutdown()
|
||||
@ -734,22 +797,23 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
stateless_destroy_torch_distributed_process_group(dp_group)
|
||||
|
||||
def add_request(self, request: EngineCoreRequest):
|
||||
if request.current_wave != self.current_wave:
|
||||
if self.has_coordinator and request.current_wave != self.current_wave:
|
||||
if request.current_wave > self.current_wave:
|
||||
self.current_wave = request.current_wave
|
||||
elif not self.engines_running:
|
||||
# Request received for an already-completed wave, notify
|
||||
# front-end that we need to start the next one.
|
||||
self.output_queue.put_nowait(
|
||||
EngineCoreOutputs(start_wave=self.current_wave))
|
||||
(-1, EngineCoreOutputs(start_wave=self.current_wave)))
|
||||
|
||||
super().add_request(request)
|
||||
|
||||
def _handle_client_request(self, request_type: EngineCoreRequestType,
|
||||
request: Any) -> None:
|
||||
if request_type == EngineCoreRequestType.START_DP_WAVE:
|
||||
new_wave: int = request
|
||||
if new_wave >= self.current_wave:
|
||||
new_wave, exclude_eng_index = request
|
||||
if exclude_eng_index != self.engine_index and (
|
||||
new_wave >= self.current_wave):
|
||||
self.current_wave = new_wave
|
||||
if not self.engines_running:
|
||||
logger.debug("EngineCore starting idle loop for wave %d.",
|
||||
@ -758,6 +822,18 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
else:
|
||||
super()._handle_client_request(request_type, request)
|
||||
|
||||
def _maybe_publish_request_counts(self):
|
||||
if not self.has_coordinator:
|
||||
return
|
||||
|
||||
# Publish our request counts (if they've changed).
|
||||
counts = self.scheduler.get_request_counts()
|
||||
if counts != self.last_counts:
|
||||
self.last_counts = counts
|
||||
stats = SchedulerStats(*counts)
|
||||
self.output_queue.put_nowait(
|
||||
(-1, EngineCoreOutputs(scheduler_stats=stats)))
|
||||
|
||||
def run_busy_loop(self):
|
||||
"""Core busy loop of the EngineCore for data parallel case."""
|
||||
|
||||
@ -768,6 +844,8 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
|
||||
# 2) Step the engine core.
|
||||
executed = self._process_engine_step()
|
||||
self._maybe_publish_request_counts()
|
||||
|
||||
local_unfinished_reqs = self.scheduler.has_unfinished_requests()
|
||||
if not executed:
|
||||
if not local_unfinished_reqs and not self.engines_running:
|
||||
@ -788,7 +866,8 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
logger.debug("Wave %d finished, pausing engine loop.",
|
||||
self.current_wave)
|
||||
self.output_queue.put_nowait(
|
||||
EngineCoreOutputs(wave_complete=self.current_wave))
|
||||
(-1,
|
||||
EngineCoreOutputs(wave_complete=self.current_wave)))
|
||||
self.current_wave += 1
|
||||
|
||||
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import queue
|
||||
import sys
|
||||
import uuid
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
@ -9,26 +10,28 @@ from collections import deque
|
||||
from collections.abc import Awaitable, Sequence
|
||||
from concurrent.futures import Future
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
|
||||
import msgspec
|
||||
import msgspec.msgpack
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.utils import (get_open_port, get_open_zmq_inproc_path,
|
||||
get_open_zmq_ipc_path, get_tcp_uri, make_zmq_socket)
|
||||
from vllm.utils import (get_open_zmq_inproc_path, make_zmq_socket,
|
||||
zmq_socket_ctx)
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
EngineCoreRequestType, UtilityOutput)
|
||||
from vllm.v1.engine.coordinator import DPCoordinator
|
||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||
from vllm.v1.engine.exceptions import EngineDeadError
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
|
||||
from vllm.v1.utils import CoreEngineProcManager
|
||||
from vllm.v1.utils import (CoreEngine, CoreEngineProcManager,
|
||||
EngineZmqAddresses, get_engine_client_zmq_addr,
|
||||
wait_for_engine_startup)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -36,8 +39,6 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]]
|
||||
|
||||
_R = TypeVar('_R') # Return type for collective_rpc
|
||||
|
||||
STARTUP_POLL_PERIOD_MS = 10000
|
||||
|
||||
|
||||
class EngineCoreClient(ABC):
|
||||
"""
|
||||
@ -207,7 +208,7 @@ class InprocClient(EngineCoreClient):
|
||||
|
||||
def get_output(self) -> EngineCoreOutputs:
|
||||
outputs, _ = self.engine_core.step()
|
||||
return outputs
|
||||
return outputs.get(0) or EngineCoreOutputs()
|
||||
|
||||
def add_request(self, request: EngineCoreRequest) -> None:
|
||||
self.engine_core.add_request(request)
|
||||
@ -266,24 +267,6 @@ class InprocClient(EngineCoreClient):
|
||||
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
|
||||
|
||||
|
||||
class CoreEngineState(Enum):
|
||||
NEW = auto()
|
||||
CONNECTED = auto()
|
||||
READY = auto()
|
||||
|
||||
|
||||
class CoreEngine:
|
||||
"""One per data parallel rank."""
|
||||
|
||||
def __init__(self, index: int = 0, local: bool = True):
|
||||
self.local = local
|
||||
self.index = index
|
||||
self.identity = index.to_bytes(length=2, byteorder="little")
|
||||
|
||||
self.state = CoreEngineState.NEW
|
||||
self.num_reqs_in_flight = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackgroundResources:
|
||||
"""Used as a finalizer for clean shutdown, avoiding
|
||||
@ -291,9 +274,12 @@ class BackgroundResources:
|
||||
|
||||
ctx: Union[zmq.Context]
|
||||
local_engine_manager: Optional[CoreEngineProcManager] = None
|
||||
coordinator: Optional[DPCoordinator] = None
|
||||
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
|
||||
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
|
||||
first_req_send_socket: Optional[zmq.asyncio.Socket] = None
|
||||
output_queue_task: Optional[asyncio.Task] = None
|
||||
stats_update_task: Optional[asyncio.Task] = None
|
||||
shutdown_path: Optional[str] = None
|
||||
|
||||
# Set if any of the engines are dead. Here so that the output
|
||||
@ -306,16 +292,21 @@ class BackgroundResources:
|
||||
self.engine_dead = True
|
||||
if self.local_engine_manager is not None:
|
||||
self.local_engine_manager.close()
|
||||
if self.coordinator is not None:
|
||||
self.coordinator.close()
|
||||
|
||||
if self.output_queue_task is not None:
|
||||
self.output_queue_task.cancel()
|
||||
if self.stats_update_task is not None:
|
||||
self.stats_update_task.cancel()
|
||||
|
||||
# ZMQ context termination can hang if the sockets
|
||||
# aren't explicitly closed first.
|
||||
if self.output_socket is not None:
|
||||
self.output_socket.close(linger=0)
|
||||
if self.input_socket is not None:
|
||||
self.input_socket.close(linger=0)
|
||||
for socket in (self.output_socket, self.input_socket,
|
||||
self.first_req_send_socket):
|
||||
if socket is not None:
|
||||
socket.close(linger=0)
|
||||
|
||||
if self.shutdown_path is not None:
|
||||
# We must ensure that the sync output socket is
|
||||
# closed cleanly in its own thread.
|
||||
@ -350,6 +341,7 @@ class MPClient(EngineCoreClient):
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
client_addresses: Optional[dict[str, str]] = None,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
# Serialization setup.
|
||||
@ -369,8 +361,8 @@ class MPClient(EngineCoreClient):
|
||||
try:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
local_engine_count = parallel_config.data_parallel_size_local
|
||||
start_index = parallel_config.data_parallel_rank
|
||||
local_start_index = parallel_config.data_parallel_rank_local
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
|
||||
# SPMD mode is where there is an LLM instance per DP rank and
|
||||
# one core engine per LLM, see
|
||||
@ -382,42 +374,53 @@ class MPClient(EngineCoreClient):
|
||||
CoreEngine(index=local_start_index, local=True)
|
||||
]
|
||||
else:
|
||||
assert start_index == 0
|
||||
assert parallel_config.data_parallel_rank == 0
|
||||
local_start_index = 0
|
||||
self.core_engines = [
|
||||
CoreEngine(index=i, local=(i < local_engine_count))
|
||||
for i in range(parallel_config.data_parallel_size)
|
||||
for i in range(dp_size)
|
||||
]
|
||||
|
||||
input_address, output_address = self._get_zmq_addresses(
|
||||
parallel_config, spmd_mode)
|
||||
local_only = spmd_mode or local_engine_count == dp_size
|
||||
|
||||
self.stats_update_address: Optional[str] = None
|
||||
if client_addresses is not None:
|
||||
input_address = client_addresses["input_address"]
|
||||
output_address = client_addresses["output_address"]
|
||||
self.stats_update_address = client_addresses.get(
|
||||
"stats_update_address")
|
||||
else:
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
input_address = get_engine_client_zmq_addr(local_only, host)
|
||||
output_address = get_engine_client_zmq_addr(local_only, host)
|
||||
|
||||
# Create input and output sockets.
|
||||
self.input_socket = self.resources.input_socket = make_zmq_socket(
|
||||
self.ctx, input_address, zmq.ROUTER, bind=True)
|
||||
|
||||
self.resources.output_socket = make_zmq_socket(
|
||||
self.ctx, output_address, zmq.constants.PULL)
|
||||
# Start local engines.
|
||||
if local_engine_count:
|
||||
# In server mode, start_index and local_start_index will
|
||||
# both be 0.
|
||||
self.resources.local_engine_manager = CoreEngineProcManager(
|
||||
EngineCoreProc.run_engine_core,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=log_stats,
|
||||
input_address=input_address,
|
||||
on_head_node=True,
|
||||
local_engine_count=local_engine_count,
|
||||
start_index=start_index,
|
||||
local_start_index=local_start_index)
|
||||
self.ctx, output_address, zmq.PULL)
|
||||
|
||||
if client_addresses is None:
|
||||
self._init_engines_direct(vllm_config, local_only,
|
||||
local_start_index, input_address,
|
||||
output_address, executor_class,
|
||||
log_stats)
|
||||
coordinator = self.resources.coordinator
|
||||
if coordinator:
|
||||
self.stats_update_address = (
|
||||
coordinator.get_stats_publish_address())
|
||||
|
||||
# Wait for ready messages from each engine on the input socket.
|
||||
identities = set(e.identity for e in self.core_engines)
|
||||
sync_input_socket = zmq.Socket.shadow(self.input_socket)
|
||||
while identities:
|
||||
if not sync_input_socket.poll(timeout=600_000):
|
||||
raise TimeoutError("Timed out waiting for engines to send"
|
||||
"initial message on input socket.")
|
||||
identity, _ = sync_input_socket.recv_multipart()
|
||||
identities.remove(identity)
|
||||
|
||||
self.core_engine = self.core_engines[0]
|
||||
|
||||
# Wait for engine core process(es) to start.
|
||||
self._wait_for_engine_startup(output_address, parallel_config)
|
||||
|
||||
self.utility_results: dict[int, AnyFuture] = {}
|
||||
|
||||
# Request objects which may contain pytorch-allocated tensors
|
||||
@ -430,116 +433,67 @@ class MPClient(EngineCoreClient):
|
||||
if not success:
|
||||
self._finalizer()
|
||||
|
||||
@staticmethod
|
||||
def _get_zmq_addresses(parallel_config: ParallelConfig,
|
||||
spmd_mode: bool) -> tuple[str, str]:
|
||||
"""Returns (input_address, output_address)."""
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
|
||||
local_start_index: int, input_address: str,
|
||||
output_address: str,
|
||||
executor_class: type[Executor], log_stats: bool):
|
||||
"""Self-contained client mode, launch engine and coordinator process
|
||||
as needed."""
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
local_engine_count = parallel_config.data_parallel_size_local
|
||||
start_index = parallel_config.data_parallel_rank
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
|
||||
if local_engine_count == dp_size or spmd_mode:
|
||||
input_address = get_open_zmq_ipc_path()
|
||||
output_address = get_open_zmq_ipc_path()
|
||||
else:
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
input_port = parallel_config.data_parallel_rpc_port
|
||||
output_port = get_open_port()
|
||||
input_address = get_tcp_uri(host, input_port)
|
||||
output_address = get_tcp_uri(host, output_port)
|
||||
if len(self.core_engines) > 1:
|
||||
self.resources.coordinator = DPCoordinator(parallel_config)
|
||||
|
||||
return input_address, output_address
|
||||
handshake_address = get_engine_client_zmq_addr(
|
||||
local_only, host, parallel_config.data_parallel_rpc_port)
|
||||
|
||||
def _wait_for_engine_startup(self, output_address: str,
|
||||
parallel_config: ParallelConfig):
|
||||
# Get a sync handle to the socket which can be sync or async.
|
||||
sync_input_socket = zmq.Socket.shadow(self.input_socket)
|
||||
with zmq_socket_ctx(handshake_address, zmq.ROUTER,
|
||||
bind=True) as handshake_socket:
|
||||
|
||||
# Wait for engine core process(es) to send ready messages.
|
||||
local_count = parallel_config.data_parallel_size_local
|
||||
remote_count = len(self.core_engines) - local_count
|
||||
# [local, remote] counts
|
||||
conn_pending, start_pending = [local_count, remote_count], [0, 0]
|
||||
# Start local engines.
|
||||
if local_engine_count:
|
||||
# In server mode, start_index and local_start_index will
|
||||
# both be 0.
|
||||
self.resources.local_engine_manager = CoreEngineProcManager(
|
||||
EngineCoreProc.run_engine_core,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=log_stats,
|
||||
handshake_address=handshake_address,
|
||||
on_head_node=True,
|
||||
local_engine_count=local_engine_count,
|
||||
start_index=start_index,
|
||||
local_start_index=local_start_index)
|
||||
|
||||
poller = zmq.Poller()
|
||||
poller.register(sync_input_socket, zmq.POLLIN)
|
||||
proc_manager = self.resources.local_engine_manager
|
||||
if proc_manager is not None:
|
||||
for sentinel in proc_manager.sentinels():
|
||||
poller.register(sentinel, zmq.POLLIN)
|
||||
while any(conn_pending) or any(start_pending):
|
||||
events = poller.poll(STARTUP_POLL_PERIOD_MS)
|
||||
if not events:
|
||||
if any(conn_pending):
|
||||
logger.debug(
|
||||
"Waiting for %d local, %d remote core engine proc(s) "
|
||||
"to connect.", *conn_pending)
|
||||
if any(start_pending):
|
||||
logger.debug(
|
||||
"Waiting for %d local, %d remote core engine proc(s) "
|
||||
"to start.", *start_pending)
|
||||
continue
|
||||
if len(events) > 1 or events[0][0] != sync_input_socket:
|
||||
# One of the local core processes exited.
|
||||
finished = proc_manager.finished_procs(
|
||||
) if proc_manager else {}
|
||||
raise RuntimeError("Engine core initialization failed. "
|
||||
"See root cause above. "
|
||||
f"Failed core proc(s): {finished}")
|
||||
# Wait for engine core process(es) to start.
|
||||
self._wait_for_engine_startup(handshake_socket, input_address,
|
||||
output_address)
|
||||
|
||||
# Receive HELLO and READY messages from the input socket.
|
||||
eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart()
|
||||
eng_index = int.from_bytes(eng_identity, byteorder="little")
|
||||
engine = next(
|
||||
(e for e in self.core_engines if e.identity == eng_identity),
|
||||
None)
|
||||
if engine is None:
|
||||
raise RuntimeError(f"Message from engine with unexpected data "
|
||||
f"parallel rank: {eng_index}")
|
||||
msg = msgspec.msgpack.decode(ready_msg_bytes)
|
||||
status, local = msg["status"], msg["local"]
|
||||
if local != engine.local:
|
||||
raise RuntimeError(f"{status} message from "
|
||||
f"{'local' if local else 'remote'} "
|
||||
f"engine {eng_index}, expected it to be "
|
||||
f"{'local' if engine.local else 'remote'}")
|
||||
def _wait_for_engine_startup(self, handshake_socket: zmq.Socket,
|
||||
input_address: str, output_address: str):
|
||||
addresses = EngineZmqAddresses(
|
||||
inputs=[input_address],
|
||||
outputs=[output_address],
|
||||
)
|
||||
|
||||
if status == "HELLO" and engine.state == CoreEngineState.NEW:
|
||||
coordinator = self.resources.coordinator
|
||||
if coordinator is not None:
|
||||
addresses.coordinator_input, addresses.coordinator_output = (
|
||||
coordinator.get_engine_socket_addresses())
|
||||
|
||||
# Send init message with DP config info.
|
||||
init_message = self.encoder.encode({
|
||||
"output_socket_address": output_address,
|
||||
"parallel_config": {
|
||||
"data_parallel_master_ip":
|
||||
parallel_config.data_parallel_master_ip,
|
||||
"data_parallel_master_port":
|
||||
parallel_config.data_parallel_master_port,
|
||||
"data_parallel_size":
|
||||
parallel_config.data_parallel_size,
|
||||
},
|
||||
})
|
||||
sync_input_socket.send_multipart((eng_identity, *init_message),
|
||||
copy=False)
|
||||
conn_pending[0 if local else 1] -= 1
|
||||
start_pending[0 if local else 1] += 1
|
||||
engine.state = CoreEngineState.CONNECTED
|
||||
elif status == "READY" and (engine.state
|
||||
== CoreEngineState.CONNECTED):
|
||||
# Setup KV cache config with initialization state from
|
||||
# engine core process. Sum values from all engines in DP case.
|
||||
cache_config = self.vllm_config.cache_config
|
||||
num_gpu_blocks = cache_config.num_gpu_blocks or 0
|
||||
num_gpu_blocks += msg['num_gpu_blocks']
|
||||
cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
|
||||
start_pending[0 if local else 1] -= 1
|
||||
engine.state = CoreEngineState.READY
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected {status} message for "
|
||||
f"{'local' if local else 'remote'} engine "
|
||||
f"{eng_index} in {engine.state} state.")
|
||||
|
||||
logger.debug("%s from %s core engine process %s.", status,
|
||||
"local" if local else "remote", eng_index)
|
||||
wait_for_engine_startup(
|
||||
handshake_socket,
|
||||
addresses,
|
||||
self.core_engines,
|
||||
self.vllm_config.parallel_config,
|
||||
self.vllm_config.cache_config,
|
||||
self.resources.local_engine_manager,
|
||||
coordinator.proc if coordinator else None,
|
||||
)
|
||||
|
||||
def shutdown(self):
|
||||
# Terminate background resources.
|
||||
@ -605,8 +559,8 @@ class SyncMPClient(MPClient):
|
||||
try:
|
||||
shutdown_socket.bind(shutdown_path)
|
||||
poller = zmq.Poller()
|
||||
poller.register(shutdown_socket)
|
||||
poller.register(out_socket)
|
||||
poller.register(shutdown_socket, zmq.POLLIN)
|
||||
poller.register(out_socket, zmq.POLLIN)
|
||||
while True:
|
||||
socks = poller.poll()
|
||||
if not socks:
|
||||
@ -668,7 +622,7 @@ class SyncMPClient(MPClient):
|
||||
future: Future[Any] = Future()
|
||||
self.utility_results[call_id] = future
|
||||
self._send_input(EngineCoreRequestType.UTILITY,
|
||||
(call_id, method, args))
|
||||
(0, call_id, method, args))
|
||||
|
||||
return future.result()
|
||||
|
||||
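Utility calls now travel as (client_index, call_id, method, args) so that, with several API servers sharing one engine, the result can be routed back on the caller's own output socket. A minimal round trip of that payload is sketched below for illustration; the request-type byte, call id, and method name are all stand-ins rather than values taken from this diff.

import msgspec.msgpack

UTILITY = b"\x02"  # stand-in for EngineCoreRequestType.UTILITY.value

client_index = 1
call_id = 123      # illustrative id; real code generates unique ids
message = (UTILITY,
           msgspec.msgpack.encode((client_index, call_id, "sleep", (5, ))))

# Engine side: unpack and route the eventual result to output socket
# number `client_index`.
req_type, payload = message
assert req_type == UTILITY
decoded_client_idx, decoded_call_id, method, args = \
    msgspec.msgpack.decode(payload)
assert (decoded_client_idx, decoded_call_id, method) == (1, call_id, "sleep")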
@ -730,15 +684,21 @@ class SyncMPClient(MPClient):
|
||||
class AsyncMPClient(MPClient):
|
||||
"""Asyncio-compatible client for multi-proc EngineCore."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
|
||||
log_stats: bool):
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
client_addresses: Optional[dict[str, str]] = None,
|
||||
client_index: int = 0):
|
||||
super().__init__(
|
||||
asyncio_mode=True,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=log_stats,
|
||||
client_addresses=client_addresses,
|
||||
)
|
||||
|
||||
self.client_index = client_index
|
||||
self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs,
|
||||
Exception]]()
|
||||
try:
|
||||
@ -854,12 +814,13 @@ class AsyncMPClient(MPClient):
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
self.utility_results[call_id] = future
|
||||
message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
|
||||
(call_id, method, args)))
|
||||
(self.client_index, call_id, method, args)))
|
||||
await self._send_input_message(message, engine, args)
|
||||
self._ensure_output_queue_task()
|
||||
return await future
|
||||
|
||||
async def add_request_async(self, request: EngineCoreRequest) -> None:
|
||||
request.client_index = self.client_index
|
||||
await self._send_input(EngineCoreRequestType.ADD, request)
|
||||
self._ensure_output_queue_task()
|
||||
|
||||
@ -921,17 +882,120 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
|
||||
EngineCore."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
|
||||
log_stats: bool):
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
client_addresses: Optional[dict[str, str]] = None,
|
||||
client_index: int = 0):
|
||||
|
||||
self.current_wave = 0
|
||||
self.engines_running = False
|
||||
# To route aborts to the correct engine.
|
||||
self.reqs_in_flight: dict[str, CoreEngine] = {}
|
||||
|
||||
super().__init__(vllm_config, executor_class, log_stats)
|
||||
super().__init__(vllm_config, executor_class, log_stats,
|
||||
client_addresses, client_index)
|
||||
|
||||
assert len(self.core_engines) > 1
|
||||
|
||||
# List of [waiting, running] pair per engine.
|
||||
self.lb_engines: list[list[int]] = []
|
||||
|
||||
self.first_req_sock_addr = get_open_zmq_inproc_path()
|
||||
self.first_req_send_socket = self.resources.first_req_send_socket = (
|
||||
make_zmq_socket(self.ctx,
|
||||
self.first_req_sock_addr,
|
||||
zmq.PAIR,
|
||||
bind=True))
|
||||
try:
|
||||
# If we are running in an asyncio event loop, start the stats task.
|
||||
# Otherwise, it will be started lazily.
|
||||
asyncio.get_running_loop()
|
||||
self._ensure_stats_update_task()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def _ensure_stats_update_task(self):
|
||||
resources = self.resources
|
||||
if resources.stats_update_task is not None:
|
||||
return
|
||||
|
||||
assert self.stats_update_address is not None
|
||||
|
||||
async def run_engine_stats_update_task():
|
||||
with make_zmq_socket(self.ctx, self.stats_update_address,
|
||||
zmq.XSUB) as socket, make_zmq_socket(
|
||||
self.ctx,
|
||||
self.first_req_sock_addr,
|
||||
zmq.PAIR,
|
||||
bind=False) as first_req_rcv_socket:
|
||||
# Send subscription message.
|
||||
await socket.send(b'\x01')
|
||||
|
||||
poller = zmq.asyncio.Poller()
|
||||
poller.register(socket, zmq.POLLIN)
|
||||
poller.register(first_req_rcv_socket, zmq.POLLIN)
|
||||
|
||||
while True:
|
||||
events = await poller.poll()
|
||||
if not self.engines_running and len(events) == 2 or (
|
||||
events[0][0] == first_req_rcv_socket):
|
||||
# Send a message to notify the coordinator that
|
||||
# we're sending a request while the engines are
|
||||
# paused, so that it can wake the others up
|
||||
# (to run dummy EP loop).
|
||||
self.engines_running = True
|
||||
buf = first_req_rcv_socket.recv(
|
||||
flags=zmq.NOBLOCK).result()
|
||||
target_eng_index = int.from_bytes(buf, "little")
|
||||
msg = msgspec.msgpack.encode(
|
||||
(target_eng_index, self.current_wave))
|
||||
await socket.send(msg)
|
||||
|
||||
buf = None
|
||||
while True:
|
||||
# Drain all stats events (we only care about latest).
|
||||
future: asyncio.Future[bytes] = socket.recv(
|
||||
flags=zmq.NOBLOCK)
|
||||
if isinstance(future.exception(), zmq.Again):
|
||||
break
|
||||
buf = future.result()
|
||||
if buf is None:
|
||||
continue
|
||||
|
||||
# Update local load-balancing state.
|
||||
counts, wave, running = msgspec.msgpack.decode(buf)
|
||||
self.current_wave = wave
|
||||
self.engines_running = running
|
||||
self.lb_engines = counts
|
||||
|
||||
resources.stats_update_task = asyncio.create_task(
|
||||
run_engine_stats_update_task())
|
||||
|
||||
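The first_req_send_socket / first_req_rcv_socket pair above is an in-process PAIR channel: add_request_async nudges the stats task when it sends a request while the engines are paused, and the stats task then forwards a (target_engine, current_wave) message to the coordinator over the same socket it subscribes on. A tiny sketch of just the PAIR-based nudge, using the chosen engine's two-byte identity as the payload like the code above (addresses are placeholders):

import zmq

ctx = zmq.Context()
addr = "inproc://toy-first-req"

notify_send = ctx.socket(zmq.PAIR)   # used by add_request_async (binds)
notify_send.bind(addr)
notify_rcv = ctx.socket(zmq.PAIR)    # lives in the stats-update task
notify_rcv.connect(addr)

chosen_engine_index = 3
notify_send.send(chosen_engine_index.to_bytes(2, "little"))

buf = notify_rcv.recv()
assert int.from_bytes(buf, "little") == 3
# At this point the real stats task would msgpack-encode
# (engine_index, current_wave) and send it to the coordinator.

for sock in (notify_send, notify_rcv):
    sock.close(linger=0)
ctx.term()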
def get_core_engine_for_request(self) -> CoreEngine:
|
||||
if not self.lb_engines:
|
||||
return self.core_engines[0]
|
||||
# TODO use P2C alg for larger DP sizes
|
||||
num_engines = len(self.lb_engines)
|
||||
min_counts = [sys.maxsize, sys.maxsize]
|
||||
eng_index = 0
|
||||
for i in range(num_engines):
|
||||
# Start from client_index to help with balancing when engines
|
||||
# are empty.
|
||||
idx = (self.client_index + i) % num_engines
|
||||
counts = self.lb_engines[idx]
|
||||
if counts < min_counts:
|
||||
min_counts = counts
|
||||
eng_index = idx
|
||||
# Adjust local counts for better balancing between stats updates
|
||||
# from the coordinator (which happen every 100ms).
|
||||
if min_counts[0]:
|
||||
min_counts[0] += 1
|
||||
else:
|
||||
min_counts[1] += 1
|
||||
return self.core_engines[eng_index]
|
||||
|
||||
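get_core_engine_for_request above picks the engine with the lexicographically smallest [waiting, running] pair, starting its scan at this client's index so multiple API servers don't all pile onto the same empty engine, and then bumps the chosen engine's local count so back-to-back requests spread out between 100 ms stats refreshes from the coordinator. The same selection, reduced to plain lists for illustration:

import sys

def pick_engine(lb_engines: list, client_index: int) -> int:
    """Return index of the engine with the fewest [waiting, running] reqs."""
    num_engines = len(lb_engines)
    min_counts = [sys.maxsize, sys.maxsize]
    eng_index = 0
    for i in range(num_engines):
        idx = (client_index + i) % num_engines
        counts = lb_engines[idx]
        if counts < min_counts:      # lexicographic: waiting first, then running
            min_counts = counts
            eng_index = idx
    # Locally account for the request we are about to send (mutates the
    # cached counts, as the real code does).
    if min_counts[0]:
        min_counts[0] += 1
    else:
        min_counts[1] += 1
    return eng_index

stats = [[0, 4], [0, 2], [1, 0]]     # [waiting, running] per engine
assert pick_engine(stats, client_index=0) == 1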
async def call_utility_async(self, method: str, *args) -> Any:
|
||||
# Only the result from the first engine is returned.
|
||||
return (await asyncio.gather(*[
|
||||
@ -940,62 +1004,30 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
]))[0]
|
||||
|
||||
async def add_request_async(self, request: EngineCoreRequest) -> None:
|
||||
self._ensure_stats_update_task()
|
||||
|
||||
request.current_wave = self.current_wave
|
||||
request.client_index = self.client_index
|
||||
|
||||
chosen_engine = self.get_core_engine_for_request()
|
||||
self.reqs_in_flight[request.request_id] = chosen_engine
|
||||
chosen_engine.num_reqs_in_flight += 1
|
||||
|
||||
to_await = self._send_input(EngineCoreRequestType.ADD, request,
|
||||
chosen_engine)
|
||||
if not self.engines_running:
|
||||
# Send request to chosen engine and dp start loop
|
||||
# control message to all other engines.
|
||||
self.engines_running = True
|
||||
to_await = asyncio.gather(
|
||||
to_await, # type: ignore[assignment]
|
||||
*self._start_wave_coros(exclude_index=chosen_engine.index))
|
||||
# Notify coordinator that we're sending a request
|
||||
await self.first_req_send_socket.send(chosen_engine.identity)
|
||||
|
||||
await to_await
|
||||
|
||||
self._ensure_output_queue_task()
|
||||
|
||||
def get_core_engine_for_request(self) -> CoreEngine:
|
||||
return min(self.core_engines, key=lambda e: e.num_reqs_in_flight)
|
||||
|
||||
@staticmethod
|
||||
async def process_engine_outputs(self: "DPAsyncMPClient",
|
||||
outputs: EngineCoreOutputs):
|
||||
if self.reqs_in_flight:
|
||||
for req_id in outputs.finished_requests or ():
|
||||
if engine := self.reqs_in_flight.pop(req_id, None):
|
||||
engine.num_reqs_in_flight -= 1
|
||||
|
||||
if outputs.wave_complete is not None:
|
||||
# Current wave is complete, move to next wave number
|
||||
# and mark engines as paused.
|
||||
if self.current_wave <= outputs.wave_complete:
|
||||
self.current_wave = outputs.wave_complete + 1
|
||||
self.engines_running = False
|
||||
|
||||
elif outputs.start_wave is not None and (
|
||||
outputs.start_wave > self.current_wave or
|
||||
(outputs.start_wave == self.current_wave
|
||||
and not self.engines_running)):
|
||||
# Engine received request for a non-current wave so we must ensure
|
||||
# that other engines progress to the next wave.
|
||||
self.current_wave = outputs.start_wave
|
||||
self.engines_running = True
|
||||
await asyncio.gather(*self._start_wave_coros(
|
||||
exclude_index=outputs.engine_index))
|
||||
|
||||
def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]:
|
||||
logger.debug("Sending start DP wave %d.", self.current_wave)
|
||||
return [
|
||||
self._send_input(EngineCoreRequestType.START_DP_WAVE,
|
||||
self.current_wave, engine)
|
||||
for engine in self.core_engines if engine.index != exclude_index
|
||||
]
|
||||
if outputs.finished_requests and self.reqs_in_flight:
|
||||
for req_id in outputs.finished_requests:
|
||||
self.reqs_in_flight.pop(req_id, None)
|
||||
|
||||
async def abort_requests_async(self, request_ids: list[str]) -> None:
|
||||
if not request_ids:
|
||||
|
||||
@ -12,13 +12,12 @@ from vllm.config import SupportsMetricsInfo, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
|
||||
from vllm.v1.engine import FinishReason
|
||||
from vllm.v1.metrics.prometheus import unregister_vllm_metrics
|
||||
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_LOCAL_LOGGING_INTERVAL_SEC = 5.0
|
||||
|
||||
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
|
||||
|
||||
|
||||
@ -35,7 +34,7 @@ class StatLoggerBase(ABC):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def record(self, scheduler_stats: SchedulerStats,
|
||||
def record(self, scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats]):
|
||||
...
|
||||
|
||||
@ -78,20 +77,22 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
# Compute summary metrics for tracked stats
|
||||
return float(np.sum(tracked_stats) / (now - self.last_log_time))
|
||||
|
||||
def record(self, scheduler_stats: SchedulerStats,
|
||||
def record(self, scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats]):
|
||||
"""Log Stats to standard output."""
|
||||
|
||||
if iteration_stats:
|
||||
self._track_iteration_stats(iteration_stats)
|
||||
|
||||
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
|
||||
if scheduler_stats is not None:
|
||||
self.prefix_caching_metrics.observe(
|
||||
scheduler_stats.prefix_cache_stats)
|
||||
|
||||
if scheduler_stats.spec_decoding_stats is not None:
|
||||
self.spec_decoding_logging.observe(
|
||||
scheduler_stats.spec_decoding_stats)
|
||||
if scheduler_stats.spec_decoding_stats is not None:
|
||||
self.spec_decoding_logging.observe(
|
||||
scheduler_stats.spec_decoding_stats)
|
||||
|
||||
self.last_scheduler_stats = scheduler_stats
|
||||
self.last_scheduler_stats = scheduler_stats
|
||||
|
||||
def log(self):
|
||||
now = time.monotonic()
|
||||
@ -131,10 +132,11 @@ class LoggingStatLogger(StatLoggerBase):
self.spec_decoding_logging.log(log_fn=log_fn)

def log_engine_initialized(self):
logger.info(
"vllm cache_config_info with initialization " \
"after num_gpu_blocks is: %d",
self.vllm_config.cache_config.num_gpu_blocks)
if self.vllm_config.cache_config.num_gpu_blocks:
logger.info(
"Engine %03d: vllm cache_config_info with initialization "
"after num_gpu_blocks is: %d", self.engine_index,
self.vllm_config.cache_config.num_gpu_blocks)


class PrometheusStatLogger(StatLoggerBase):
@ -144,7 +146,8 @@ class PrometheusStatLogger(StatLoggerBase):
_spec_decoding_cls = SpecDecodingProm

def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self._unregister_vllm_metrics()

unregister_vllm_metrics()
self.vllm_config = vllm_config
self.engine_index = engine_index
# Use this flag to hide metrics that were deprecated in
@ -169,11 +172,13 @@ class PrometheusStatLogger(StatLoggerBase):
self.gauge_scheduler_running = self._gauge_cls(
name="vllm:num_requests_running",
documentation="Number of requests in model execution batches.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues)

self.gauge_scheduler_waiting = self._gauge_cls(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues)

#
@ -182,6 +187,7 @@ class PrometheusStatLogger(StatLoggerBase):
self.gauge_gpu_cache_usage = self._gauge_cls(
name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues)

self.counter_gpu_prefix_cache_queries = self._counter_cls(
@ -242,6 +248,9 @@ class PrometheusStatLogger(StatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames).labels(*labelvalues)

# TODO: This metric might be incorrect in case of using multiple
# api_server counts which uses prometheus mp.
# See: https://github.com/vllm-project/vllm/pull/18053
self.histogram_iteration_tokens = \
self._histogram_cls(
name="vllm:iteration_tokens_total",
@ -340,6 +349,9 @@ class PrometheusStatLogger(StatLoggerBase):
#
# LoRA metrics
#

# TODO: This metric might be incorrect in case of using multiple
# api_server counts which uses prometheus mp.
self.gauge_lora_info: Optional[prometheus_client.Gauge] = None
if vllm_config.lora_config is not None:
self.labelname_max_lora = "max_lora"
@ -350,13 +362,16 @@ class PrometheusStatLogger(StatLoggerBase):
self._gauge_cls(
name="vllm:lora_requests_info",
documentation="Running stats on lora requests.",
multiprocess_mode="sum",
labelnames=[
self.labelname_max_lora,
self.labelname_waiting_lora_adapters,
self.labelname_running_lora_adapters,
])
],
)

def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):

metrics_info = config_obj.metrics_info()
metrics_info["engine"] = self.engine_index

@ -372,25 +387,28 @@ class PrometheusStatLogger(StatLoggerBase):
info_gauge = self._gauge_cls(
name=name,
documentation=documentation,
labelnames=metrics_info.keys()).labels(**metrics_info)
multiprocess_mode="mostrecent",
labelnames=metrics_info.keys(),
).labels(**metrics_info)
info_gauge.set(1)

def record(self, scheduler_stats: SchedulerStats,
def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]):
"""Log to prometheus."""
self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)
if scheduler_stats is not None:
self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)

self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage)
self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage)

self.counter_gpu_prefix_cache_queries.inc(
scheduler_stats.prefix_cache_stats.queries)
self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits)
self.counter_gpu_prefix_cache_queries.inc(
scheduler_stats.prefix_cache_stats.queries)
self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits)

if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_prom.observe(
scheduler_stats.spec_decoding_stats)
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_prom.observe(
scheduler_stats.spec_decoding_stats)

if iteration_stats is None:
return
@ -445,13 +463,6 @@ class PrometheusStatLogger(StatLoggerBase):
self.gauge_lora_info.labels(**lora_info_labels)\
.set_to_current_time()

@staticmethod
def _unregister_vllm_metrics():
# Unregister any existing vLLM collectors (for CI/CD
for collector in list(prometheus_client.REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
prometheus_client.REGISTRY.unregister(collector)

def log_engine_initialized(self):
self.log_metrics_info("cache_config", self.vllm_config.cache_config)
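For readers skimming the diff above: each data-parallel rank now gets its own stat logger labelled by engine_index, and record() tolerates a missing SchedulerStats. A minimal sketch (the import path and make_config() helper are assumptions for illustration, not part of this diff):

# Assumed sketch: one PrometheusStatLogger per DP engine, labelled by engine_index.
from vllm.v1.metrics.loggers import PrometheusStatLogger  # assumed module path

vllm_config = make_config()  # hypothetical helper that builds a VllmConfig
stat_loggers = [
    PrometheusStatLogger(vllm_config, engine_index=i) for i in range(2)
]
# scheduler_stats may be None; the scheduler gauges are simply skipped then.
stat_loggers[0].record(scheduler_stats=None, iteration_stats=None)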
77
vllm/v1/metrics/prometheus.py
Normal file
@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0

import os
import tempfile
from typing import Optional

from prometheus_client import REGISTRY, CollectorRegistry, multiprocess

from vllm.logger import init_logger

logger = init_logger(__name__)

# Global temporary directory for prometheus multiprocessing
_prometheus_multiproc_dir: Optional[tempfile.TemporaryDirectory] = None


def setup_multiprocess_prometheus():
"""Set up prometheus multiprocessing directory if not already configured.

"""
global _prometheus_multiproc_dir

if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing
# Note: global TemporaryDirectory will be automatically
# cleaned up upon exit.
_prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name
logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s",
_prometheus_multiproc_dir.name)
else:
logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. "
"This directory must be wiped between vLLM runs or "
"you will find inaccurate metrics. Unset the variable "
"and vLLM will properly handle cleanup.")


def get_prometheus_registry():
"""Get the appropriate prometheus registry based on multiprocessing
configuration.

Returns:
Registry: A prometheus registry
"""
if os.getenv("PROMETHEUS_MULTIPROC_DIR") is not None:
logger.debug("Using multiprocess registry for prometheus metrics")
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
return registry

return REGISTRY


def unregister_vllm_metrics():
"""Unregister any existing vLLM collectors from the prometheus registry.

This is useful for testing and CI/CD where metrics may be registered
multiple times across test runs.

Also, in case of multiprocess, we need to unregister the metrics from the
global registry.
"""
registry = REGISTRY
# Unregister any existing vLLM collectors
for collector in list(registry._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
registry.unregister(collector)


def shutdown_prometheus():
"""Shutdown prometheus metrics."""
try:
pid = os.getpid()
multiprocess.mark_process_dead(pid)
logger.debug("Marked Prometheus metrics for process %d as dead", pid)
except Exception as e:
logger.error("Error during metrics cleanup: %s", str(e))
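A brief usage sketch of the helpers added in this new file (the import path follows the file location above; where these calls sit in the server lifecycle is illustrative, not prescriptive):

from vllm.v1.metrics.prometheus import (get_prometheus_registry,
                                        setup_multiprocess_prometheus,
                                        shutdown_prometheus)

# Parent process, before spawning API server workers: create (or reuse) the
# PROMETHEUS_MULTIPROC_DIR so all child processes share one metrics directory.
setup_multiprocess_prometheus()

# Wherever /metrics is served: returns a MultiProcessCollector-backed registry
# when PROMETHEUS_MULTIPROC_DIR is set, otherwise the default global REGISTRY.
registry = get_prometheus_registry()

# On worker shutdown: mark this PID's metric files as dead.
shutdown_prometheus()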
@ -26,12 +26,13 @@ class Request:
multi_modal_placeholders: Optional[list[PlaceholderRange]],
sampling_params: SamplingParams,
eos_token_id: Optional[int],
arrival_time: float,
client_index: int = 0,
lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None,
) -> None:
self.request_id = request_id
self.client_index = client_index
self.sampling_params = sampling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
@ -90,13 +91,13 @@ class Request:

return cls(
request_id=request.request_id,
client_index=request.client_index,
prompt_token_ids=request.prompt_token_ids,
multi_modal_inputs=request.mm_inputs,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,
sampling_params=request.sampling_params,
eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,
lora_request=request.lora_request,
structured_output_request=StructuredOutputRequest(
sampling_params=request.sampling_params),
301
vllm/v1/utils.py
@ -1,31 +1,41 @@
# SPDX-License-Identifier: Apache-2.0

import os
import argparse
import multiprocessing
import time
import weakref
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection
from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union,
overload)
from multiprocessing.process import BaseProcess
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
Union, overload)

import msgspec
import torch
import zmq

from vllm.config import VllmConfig
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import get_mp_context, kill_process_tree
from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
get_tcp_uri, kill_process_tree)
from vllm.v1.executor.abstract import Executor

if TYPE_CHECKING:
from vllm.attention.layer import Attention
from vllm.v1.engine.coordinator import DPCoordinator

logger = init_logger(__name__)

T = TypeVar("T")

STARTUP_POLL_PERIOD_MS = 10000


class ConstantList(Generic[T], Sequence):

@ -95,6 +105,78 @@ class ConstantList(Generic[T], Sequence):
return f"ConstantList({self._x})"


def get_engine_client_zmq_addr(local_only: bool,
host: str,
port: int = 0) -> str:
return get_open_zmq_ipc_path() if local_only else (get_tcp_uri(
host, port or get_open_port()))
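For example (addresses illustrative; actual values come from get_open_zmq_ipc_path() and get_open_port()):

# Local-only clients get an ipc:// path; remote-capable clients get a tcp:// URI.
local_addr = get_engine_client_zmq_addr(local_only=True, host="127.0.0.1")
# -> something like "ipc:///tmp/<random>" on the local filesystem
remote_addr = get_engine_client_zmq_addr(local_only=False, host="10.0.0.5", port=5570)
# -> "tcp://10.0.0.5:5570" (a free port is picked when port is 0)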
class APIServerProcessManager:
"""Manages a group of API server processes.

Handles creation, monitoring, and termination of API server worker
processes. Also monitors extra processes to check if they are healthy.
"""

def __init__(
self,
target_server_fn: Callable,
listen_address: str,
sock: Any,
args: argparse.Namespace,
num_servers: int,
input_addresses: list[str],
output_addresses: list[str],
stats_update_address: Optional[str] = None,
):
"""Initialize and start API server worker processes.

Args:
target_server_fn: Function to call for each API server process
listen_address: Address to listen for client connections
sock: Socket for client connections
args: Command line arguments
num_servers: Number of API server processes to start
input_addresses: Input addresses for each API server
output_addresses: Output addresses for each API server
stats_update_address: Optional stats update address
"""
self.listen_address = listen_address
self.sock = sock
self.args = args

# Start API servers
spawn_context = multiprocessing.get_context("spawn")
self.processes: list[BaseProcess] = []

for i, in_addr, out_addr in zip(range(num_servers), input_addresses,
output_addresses):
client_config = {
"input_address": in_addr,
"output_address": out_addr,
"client_index": i
}
if stats_update_address is not None:
client_config["stats_update_address"] = stats_update_address

proc = spawn_context.Process(target=target_server_fn,
name=f"ApiServer_{i}",
args=(listen_address, sock, args,
client_config))
self.processes.append(proc)
proc.start()

logger.info("Started %d API server processes", len(self.processes))

# Shutdown only the API server processes on garbage collection
# The extra processes are managed by their owners
self._finalizer = weakref.finalize(self, shutdown, self.processes)

def close(self) -> None:
self._finalizer()
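A hypothetical construction sketch (the server entrypoint, listening socket, parsed args, and per-server ZMQ addresses are stand-ins supplied by the caller, not defined in this diff):

manager = APIServerProcessManager(
    target_server_fn=serve_api_worker,   # assumption: caller's per-server entrypoint
    listen_address="0.0.0.0:8000",
    sock=listen_sock,                    # assumption: pre-bound listening socket
    args=parsed_args,                    # assumption: argparse.Namespace from the CLI
    num_servers=2,
    input_addresses=[in_addr_0, in_addr_1],
    output_addresses=[out_addr_0, out_addr_1],
    stats_update_address=None,           # only set when a DP coordinator is used
)
# ... later, explicitly or via the weakref finalizer on garbage collection:
manager.close()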
class CoreEngineProcManager:
"""
Utility class to handle creation, readiness, and shutdown
@ -109,7 +191,7 @@ class CoreEngineProcManager:
local_start_index: int,
vllm_config: VllmConfig,
on_head_node: bool,
input_address: str,
handshake_address: str,
executor_class: type[Executor],
log_stats: bool,
):
@ -117,12 +199,12 @@ class CoreEngineProcManager:
common_kwargs = {
"vllm_config": vllm_config,
"on_head_node": on_head_node,
"input_address": input_address,
"handshake_address": handshake_address,
"executor_class": executor_class,
"log_stats": log_stats,
}

self.processes: list[Process] = []
self.processes: list[BaseProcess] = []
for index in range(local_engine_count):
local_index = local_start_index + index
global_index = start_index + index
@ -135,8 +217,7 @@ class CoreEngineProcManager:
"local_dp_rank": local_index,
}))

self._finalizer = weakref.finalize(self, shutdown, self.processes,
input_address)
self._finalizer = weakref.finalize(self, shutdown, self.processes)
try:
for proc in self.processes:
proc.start()
@ -164,9 +245,199 @@ class CoreEngineProcManager:
}


class CoreEngineState(Enum):
NEW = auto()
CONNECTED = auto()
READY = auto()


class CoreEngine:
"""One per data parallel rank."""

def __init__(self, index: int = 0, local: bool = True):
self.local = local
self.index = index
self.identity = index.to_bytes(2, "little")

self.state = CoreEngineState.NEW
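The two-byte little-endian identity set above is what wait_for_engine_startup() decodes back into an engine rank further down, e.g.:

identity = (7).to_bytes(2, "little")        # b"\x07\x00", used as the ZMQ identity
assert int.from_bytes(identity, "little") == 7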
@dataclass
class EngineZmqAddresses:
# ZMQ input socket addresses for each front-end client (requests)
inputs: list[str]
# ZMQ output socket addresses for each front-end client (responses)
outputs: list[str]
# ZMQ input socket address of DP coordinator if applicable
coordinator_input: Optional[str] = None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output: Optional[str] = None


@dataclass
class EngineHandshakeMetadata:
"""Metadata sent to each engine process during startup handshake,
including addresses of the front-end ZMQ queues that they should
connect to.
"""
addresses: EngineZmqAddresses
parallel_config: dict[str, Union[int, str]]
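A small round-trip sketch of this handshake payload, mirroring the msgspec calls used later in this file (addresses and ports are placeholders):

meta = EngineHandshakeMetadata(
    addresses=EngineZmqAddresses(inputs=["tcp://127.0.0.1:5570"],
                                 outputs=["tcp://127.0.0.1:5571"]),
    parallel_config={"data_parallel_master_ip": "127.0.0.1",
                     "data_parallel_master_port": 29500,
                     "data_parallel_size": 2},
)
payload = msgspec.msgpack.encode(meta)          # sent over the handshake socket
decoded = msgspec.msgpack.decode(payload)       # engines read it back as a dict
assert decoded["parallel_config"]["data_parallel_size"] == 2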
def wait_for_engine_startup(
handshake_socket: zmq.Socket,
addresses: EngineZmqAddresses,
core_engines: list[CoreEngine],
parallel_config: ParallelConfig,
cache_config: CacheConfig,
proc_manager: Optional[CoreEngineProcManager],
coord_process: Optional[Process],
):

# Wait for engine core process(es) to send ready messages.
local_count = parallel_config.data_parallel_size_local
remote_count = len(core_engines) - local_count
# [local, remote] counts
conn_pending, start_pending = [local_count, remote_count], [0, 0]
poller = zmq.Poller()
poller.register(handshake_socket, zmq.POLLIN)

if proc_manager is not None:
for sentinel in proc_manager.sentinels():
poller.register(sentinel, zmq.POLLIN)
if coord_process is not None:
poller.register(coord_process.sentinel, zmq.POLLIN)
while any(conn_pending) or any(start_pending):
events = poller.poll(STARTUP_POLL_PERIOD_MS)
if not events:
if any(conn_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to connect.", *conn_pending)
if any(start_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to start.", *start_pending)
continue
if len(events) > 1 or events[0][0] != handshake_socket:
# One of the local core processes exited.
finished = proc_manager.finished_procs() if proc_manager else {}
if coord_process is not None and coord_process.exitcode is not None:
finished[coord_process.name] = coord_process.exitcode
raise RuntimeError("Engine core initialization failed. "
"See root cause above. "
f"Failed core proc(s): {finished}")

# Receive HELLO and READY messages from the input socket.
eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
eng_index = int.from_bytes(eng_identity, "little")
engine = next((e for e in core_engines if e.identity == eng_identity),
None)
if engine is None:
raise RuntimeError(f"Message from engine with unexpected data "
f"parallel rank: {eng_index}")
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local = msg["status"], msg["local"]
if local != engine.local:
raise RuntimeError(f"{status} message from "
f"{'local' if local else 'remote'} "
f"engine {eng_index}, expected it to be "
f"{'local' if engine.local else 'remote'}")

if status == "HELLO" and engine.state == CoreEngineState.NEW:

# Send init message with DP config info.
init_message = msgspec.msgpack.encode(
EngineHandshakeMetadata(
addresses=addresses,
parallel_config={
"data_parallel_master_ip":
parallel_config.data_parallel_master_ip,
"data_parallel_master_port":
parallel_config.data_parallel_master_port,
"data_parallel_size":
parallel_config.data_parallel_size,
}))
handshake_socket.send_multipart((eng_identity, init_message),
copy=False)
conn_pending[0 if local else 1] -= 1
start_pending[0 if local else 1] += 1
engine.state = CoreEngineState.CONNECTED
elif status == "READY" and (engine.state == CoreEngineState.CONNECTED):
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
num_gpu_blocks = cache_config.num_gpu_blocks or 0
num_gpu_blocks += msg["num_gpu_blocks"]
cache_config.num_gpu_blocks = num_gpu_blocks

start_pending[0 if local else 1] -= 1
engine.state = CoreEngineState.READY
else:
raise RuntimeError(f"Unexpected {status} message for "
f"{'local' if local else 'remote'} engine "
f"{eng_index} in {engine.state} state.")

logger.debug("%s from %s core engine process %s.", status,
"local" if local else "remote", eng_index)
def wait_for_completion_or_failure(
api_server_manager: APIServerProcessManager,
local_engine_manager: Optional[CoreEngineProcManager] = None,
coordinator: Optional["DPCoordinator"] = None) -> None:
"""Wait for all processes to complete or detect if any fail.

Raises an exception if any process exits with a non-zero status.
"""

try:
logger.info("Waiting for API servers to complete ...")
# Create a mapping of sentinels to their corresponding processes
# for efficient lookup
sentinel_to_proc: dict[Any, BaseProcess] = {
proc.sentinel: proc
for proc in api_server_manager.processes
}

if coordinator:
sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc

if local_engine_manager:
for proc in local_engine_manager.processes:
sentinel_to_proc[proc.sentinel] = proc

# Check if any process terminates
while sentinel_to_proc:
# Wait for any process to terminate
ready_sentinels: list[Any] = connection.wait(sentinel_to_proc)

# Process any terminated processes
for sentinel in ready_sentinels:
proc = sentinel_to_proc.pop(sentinel)

# Check if process exited with error
if proc.exitcode != 0:
raise RuntimeError(
f"Process {proc.name} (PID: {proc.pid}) "
f"died with exit code {proc.exitcode}")
except KeyboardInterrupt:
logger.info("Received KeyboardInterrupt, shutting down API servers...")
except Exception as e:
logger.exception("Exception occurred while running API servers: %s",
str(e))
raise
finally:
logger.info("Terminating remaining processes ...")
api_server_manager.close()
if coordinator:
coordinator.close()
if local_engine_manager:
local_engine_manager.close()
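Putting the pieces together, a launcher built on these utilities might wait on everything like this (a sketch; engine_manager and coordinator are the optional managers created elsewhere, and manager is an APIServerProcessManager as above):

# Blocks until every tracked process exits; a non-zero exit code raises
# RuntimeError, and the finally block above tears down whatever remains.
wait_for_completion_or_failure(api_server_manager=manager,
                               local_engine_manager=engine_manager,
                               coordinator=coordinator)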
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def shutdown(procs: list[Process], input_address: str):
# else the gc cannot collect the object.
def shutdown(procs: list[BaseProcess]):
# Shutdown the process.
for proc in procs:
if proc.is_alive():
@ -185,12 +456,6 @@ def shutdown(procs: list[Process], input_address: str):
if proc.is_alive() and (pid := proc.pid) is not None:
kill_process_tree(pid)

# Remove zmq ipc socket files.
if input_address.startswith("ipc://"):
socket_file = input_address[len("ipc://"):]
if os and os.path.exists(socket_file):
os.remove(socket_file)


def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],