Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 20:35:26 +08:00)

[Perf] API-server scaleout with many-to-many server-engine comms (#17546)

This commit is contained in:
parent 84ec470fca
commit 2dbe8c0774
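
This commit lets a single vllm serve invocation run several API-server processes in front of several engine processes, with every server owning its own pair of ZMQ input/output addresses (the new --api-server-count / -asc flag added below controls the server count). The following is a rough, non-authoritative sketch of that many-to-many wiring, reusing the helpers this diff adds to vllm/v1/utils; the host value and the single-node assumption are purely illustrative and not part of the commit itself.

    from vllm.v1.utils import EngineZmqAddresses, get_engine_client_zmq_addr

    num_api_servers = 4   # what --api-server-count 4 would request
    host = "127.0.0.1"    # illustrative; the real code uses the DP master IP
    local_only = True     # assume all engines run on this node

    # One input and one output address per API server, mirroring
    # run_multi_api_server further down in this diff.
    addresses = EngineZmqAddresses(
        inputs=[get_engine_client_zmq_addr(local_only, host)
                for _ in range(num_api_servers)],
        outputs=[get_engine_client_zmq_addr(local_only, host)
                 for _ in range(num_api_servers)],
    )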
@@ -618,9 +618,11 @@ steps:
   - vllm/worker/model_runner.py
   - entrypoints/llm/test_collective_rpc.py
   - tests/v1/test_async_llm_dp.py
+  - tests/v1/entrypoints/openai/test_multi_api_servers.py
   - vllm/v1/engine/
   commands:
   - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
+  - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
   - pytest -v -s entrypoints/llm/test_collective_rpc.py
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
tests/entrypoints/test_api_server_process_manager.py (new file, 268 lines)
@@ -0,0 +1,268 @@
# SPDX-License-Identifier: Apache-2.0

import multiprocessing
import socket
import threading
import time
from typing import Optional
from unittest.mock import patch

import pytest

from vllm.v1.utils import (APIServerProcessManager,
                           wait_for_completion_or_failure)

# Global variables to control worker behavior
WORKER_RUNTIME_SECONDS = 0.5


# Mock implementation of run_api_server_worker
def mock_run_api_server_worker(listen_address, sock, args, client_config=None):
    """Mock run_api_server_worker that runs for a specific time."""
    print(f"Mock worker started with client_config: {client_config}")
    time.sleep(WORKER_RUNTIME_SECONDS)
    print("Mock worker completed successfully")


@pytest.fixture
def api_server_args():
    """Fixture to provide arguments for APIServerProcessManager."""
    sock = socket.socket()
    return {
        "target_server_fn":
        mock_run_api_server_worker,
        "listen_address":
        "localhost:8000",
        "sock":
        sock,
        "args":
        "test_args",  # Simple string to avoid pickling issues
        "num_servers":
        3,
        "input_addresses": [
            "tcp://127.0.0.1:5001", "tcp://127.0.0.1:5002",
            "tcp://127.0.0.1:5003"
        ],
        "output_addresses": [
            "tcp://127.0.0.1:6001", "tcp://127.0.0.1:6002",
            "tcp://127.0.0.1:6003"
        ],
        "stats_update_address":
        "tcp://127.0.0.1:7000",
    }


@pytest.mark.parametrize("with_stats_update", [True, False])
def test_api_server_process_manager_init(api_server_args, with_stats_update):
    """Test initializing the APIServerProcessManager."""
    # Set the worker runtime to ensure tests complete in reasonable time
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 0.5

    # Copy the args to avoid mutating the
    args = api_server_args.copy()

    if not with_stats_update:
        args.pop("stats_update_address")
    manager = APIServerProcessManager(**args)

    try:
        # Verify the manager was initialized correctly
        assert len(manager.processes) == 3

        # Verify all processes are running
        for proc in manager.processes:
            assert proc.is_alive()

        print("Waiting for processes to run...")
        time.sleep(WORKER_RUNTIME_SECONDS / 2)

        # They should still be alive at this point
        for proc in manager.processes:
            assert proc.is_alive()

    finally:
        # Always clean up the processes
        print("Cleaning up processes...")
        manager.close()

        # Give processes time to terminate
        time.sleep(0.2)

        # Verify all processes were terminated
        for proc in manager.processes:
            assert not proc.is_alive()


@patch("vllm.entrypoints.cli.serve.run_api_server_worker",
       mock_run_api_server_worker)
def test_wait_for_completion_or_failure(api_server_args):
    """Test that wait_for_completion_or_failure works with failures."""
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 1.0

    # Create the manager
    manager = APIServerProcessManager(**api_server_args)

    try:
        assert len(manager.processes) == 3

        # Create a result capture for the thread
        result: dict[str, Optional[Exception]] = {"exception": None}

        def run_with_exception_capture():
            try:
                wait_for_completion_or_failure(api_server_manager=manager)
            except Exception as e:
                result["exception"] = e

        # Start a thread to run wait_for_completion_or_failure
        wait_thread = threading.Thread(target=run_with_exception_capture,
                                       daemon=True)
        wait_thread.start()

        # Let all processes run for a short time
        time.sleep(0.2)

        # All processes should still be running
        assert all(proc.is_alive() for proc in manager.processes)

        # Now simulate a process failure
        print("Simulating process failure...")
        manager.processes[0].terminate()

        # Wait for the wait_for_completion_or_failure
        # to detect and handle the failure
        # This should trigger it to terminate all other processes
        wait_thread.join(timeout=1.0)

        # The wait thread should have exited
        assert not wait_thread.is_alive()

        # Verify that an exception was raised with appropriate error message
        assert result["exception"] is not None
        assert "died with exit code" in str(result["exception"])

        # All processes should now be terminated
        for i, proc in enumerate(manager.processes):
            assert not proc.is_alive(), f"Process {i} should not be alive"

    finally:
        manager.close()
        time.sleep(0.2)


@pytest.mark.timeout(30)
def test_normal_completion(api_server_args):
    """Test that wait_for_completion_or_failure works in normal completion."""
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 0.1

    # Create the manager
    manager = APIServerProcessManager(**api_server_args)

    try:
        # Give processes time to terminate
        # wait for processes to complete
        remaining_processes = manager.processes.copy()
        while remaining_processes:
            for proc in remaining_processes:
                if not proc.is_alive():
                    remaining_processes.remove(proc)
            time.sleep(0.1)

        # Verify all processes have terminated
        for i, proc in enumerate(manager.processes):
            assert not proc.is_alive(
            ), f"Process {i} still alive after terminate()"

        # Now call wait_for_completion_or_failure
        # since all processes have already
        # terminated, it should return immediately
        # with no error
        wait_for_completion_or_failure(api_server_manager=manager)

    finally:
        # Clean up just in case
        manager.close()
        time.sleep(0.2)


@pytest.mark.timeout(30)
def test_external_process_monitoring(api_server_args):
    """Test that wait_for_completion_or_failure handles additional processes."""
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 100

    # Create and start the external process
    # (simulates local_engine_manager or coordinator)
    spawn_context = multiprocessing.get_context("spawn")
    external_proc = spawn_context.Process(target=mock_run_api_server_worker,
                                          name="MockExternalProcess")
    external_proc.start()

    # Create the class to simulate a coordinator
    class MockCoordinator:

        def __init__(self, proc):
            self.proc = proc

        def close(self):
            if self.proc.is_alive():
                self.proc.terminate()
                self.proc.join(timeout=0.5)

    # Create a mock coordinator with the external process
    mock_coordinator = MockCoordinator(external_proc)

    # Create the API server manager
    manager = APIServerProcessManager(**api_server_args)

    try:
        # Verify manager initialization
        assert len(manager.processes) == 3

        # Create a result capture for the thread
        result: dict[str, Optional[Exception]] = {"exception": None}

        def run_with_exception_capture():
            try:
                wait_for_completion_or_failure(api_server_manager=manager,
                                               coordinator=mock_coordinator)
            except Exception as e:
                result["exception"] = e

        # Start a thread to run wait_for_completion_or_failure
        wait_thread = threading.Thread(target=run_with_exception_capture,
                                       daemon=True)
        wait_thread.start()

        # Terminate the external process to trigger a failure
        time.sleep(0.2)
        external_proc.terminate()

        # Wait for the thread to detect the failure
        wait_thread.join(timeout=1.0)

        # The wait thread should have completed
        assert not wait_thread.is_alive(
        ), "wait_for_completion_or_failure thread still running"

        # Verify that an exception was raised with appropriate error message
        assert result["exception"] is not None, "No exception was raised"
        error_message = str(result["exception"])
        assert "died with exit code" in error_message, \
            f"Unexpected error message: {error_message}"
        assert "MockExternalProcess" in error_message, \
            f"Error doesn't mention external process: {error_message}"

        # Verify that all API server processes were terminated as a result
        for i, proc in enumerate(manager.processes):
            assert not proc.is_alive(
            ), f"API server process {i} was not terminated"

    finally:
        # Clean up
        manager.close()
        mock_coordinator.close()
        time.sleep(0.2)
@@ -28,7 +28,7 @@ from tests.models.utils import TextTextLogprobs
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.entrypoints.cli.serve import ServeSubcommand
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.platforms import current_platform
 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -99,7 +99,8 @@ class RemoteOpenAIServer:

         parser = FlexibleArgumentParser(
             description="vLLM's remote OpenAI server.")
-        parser = make_arg_parser(parser)
+        subparsers = parser.add_subparsers(required=False, dest="subparser")
+        parser = ServeSubcommand().subparser_init(subparsers)
         args = parser.parse_args(["--model", model, *vllm_serve_args])
         self.host = str(args.host or 'localhost')
         self.port = int(args.port)
@@ -45,7 +45,6 @@ def make_request(request_id,
         multi_modal_placeholders=mm_positions,
         sampling_params=SamplingParams(max_tokens=17),
         eos_token_id=100,
-        arrival_time=0,
         lora_request=None,
         cache_salt=cache_salt,
     )
@@ -38,7 +38,6 @@ def make_request(request_id,
         sampling_params=SamplingParams(max_tokens=17,
                                        prompt_logprobs=prompt_logprobs),
         eos_token_id=100,
-        arrival_time=0,
         lora_request=None,
         cache_salt=cache_salt,
     )
@@ -138,7 +138,6 @@ def create_requests(num_requests: int,
             multi_modal_placeholders=mm_position,
             multi_modal_hashes=None,
             eos_token_id=EOS_TOKEN_ID,
-            arrival_time=0,
         )
         requests.append(request)
     return requests
@@ -744,7 +743,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
         assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])

     # No draft or accepted tokens counted yet
-    assert engine_core_outputs.scheduler_stats.spec_decoding_stats is None
+    assert not engine_core_outputs or (
+        engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None)

     # Schedule the speculated tokens for validation
     output = scheduler.schedule()
@@ -772,7 +772,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
     engine_core_outputs = scheduler.update_from_output(output,
                                                        model_runner_output)

-    scheduler_stats = engine_core_outputs.scheduler_stats
+    scheduler_stats = engine_core_outputs[0].scheduler_stats \
+        if engine_core_outputs else None
     if expected[0] == 0:
         assert scheduler_stats.spec_decoding_stats is None
     else:
@@ -843,7 +844,7 @@ def _step_until_done(
         # We should be in the decode phase now.
         assert num_scheduled_tokens == 1
         assert len(output.kv_connector_metadata.requests) == 0
-        ecos = scheduler.update_from_output(output, model_runner_output)
+        ecos = scheduler.update_from_output(output, model_runner_output)[0]
         all_done = True
         for eco in ecos.outputs:
             if eco.finish_reason is None:
@@ -88,7 +88,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
         assert len(engine_core.scheduler.running) == 4

         # Loop through until they are all done.
-        while len(engine_core.step()[0].outputs) > 0:
+        while (outs := engine_core.step()[0].get(0)) and outs.outputs:
            pass

         assert len(engine_core.scheduler.waiting) == 0
@@ -163,11 +163,11 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
         req0.request_id = req1.request_id = "test"
         engine_core.add_request(req0)

-        while len(engine_core.step()[0].outputs) > 0:
+        while (outs := engine_core.step()[0].get(0)) and outs.outputs:
            pass

         engine_core.add_request(req1)
-        while len(engine_core.step()[0].outputs) > 0:
+        while (outs := engine_core.step()[0].get(0)) and outs.outputs:
            pass

         assert len(engine_core.scheduler.waiting) == 0
@@ -207,7 +207,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
         assert len(engine_core.scheduler.waiting) == 1
         assert len(engine_core.scheduler.running) == 0
         # Loop through until they are all done.
-        while len(engine_core.step()[0].outputs) > 0:
+        while (outs := engine_core.step()[0].get(0)) and outs.outputs:
            pass
         assert len(engine_core.scheduler.waiting) == 0
         assert len(engine_core.scheduler.running) == 0
@@ -327,7 +327,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
     assert scheduler_output.num_scheduled_tokens[1] == 4

     # Batch queue is full. Finish Batch 2. Get first token of req0.
-    output = engine_core.step_with_batch_queue()[0]
+    output = engine_core.step_with_batch_queue()[0].get(0)
     assert output is not None
     assert len(output.outputs) == 1
     assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
@@ -339,7 +339,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
     assert scheduler_output.num_scheduled_tokens[0] == 1

     # Batch queue is full. Finish Batch 3. Get first token of req1.
-    output = engine_core.step_with_batch_queue()[0]
+    output = engine_core.step_with_batch_queue()[0].get(0)
     assert output is not None
     assert len(output.outputs) == 1
     assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
@@ -362,7 +362,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
         if step % 2 == 0:
             # Even steps consumes an output.
             assert output is not None
-            assert len(output.outputs) == 1
+            assert len(output[0].outputs) == 1
             if req_id in engine_core.scheduler.requests:
                 assert engine_core.scheduler.requests[
                     req_id].num_tokens == expected_num_tokens[req_id]
tests/v1/entrypoints/openai/test_multi_api_servers.py (new file, 171 lines)
@@ -0,0 +1,171 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio

from tests.utils import RemoteOpenAIServer

MODEL_NAME = "ibm-research/PowerMoE-3b"

DP_SIZE = os.getenv("DP_SIZE", "1")


@pytest.fixture(scope="module")
def default_server_args():
    return [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
        "--api-server-count",
        "4",
        "--data_parallel_size",
        DP_SIZE,
    ]


@pytest.fixture(scope="module")
def server(default_server_args):
    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_single_completion(client: openai.AsyncOpenAI,
                                 model_name: str) -> None:

    async def make_request():
        completion = await client.completions.create(
            model=model_name,
            prompt="Hello, my name is",
            max_tokens=10,
            temperature=1.0)

        assert completion.id is not None
        assert completion.choices is not None and len(completion.choices) == 1

        choice = completion.choices[0]
        # The exact number of tokens can vary slightly with temperature=1.0,
        # so we check for a reasonable minimum length.
        assert len(choice.text) >= 1
        # Finish reason might not always be 'length' if the model finishes early
        # or due to other reasons, especially with high temperature.
        # So, we'll accept 'length' or 'stop'.
        assert choice.finish_reason in ("length", "stop")

        # Token counts can also vary, so we check they are positive.
        assert completion.usage.completion_tokens > 0
        assert completion.usage.prompt_tokens > 0
        assert completion.usage.total_tokens > 0
        return completion

    # Test single request
    result = await make_request()
    assert result is not None

    await asyncio.sleep(0.5)

    # Send two bursts of requests
    num_requests = 100
    tasks = [make_request() for _ in range(num_requests)]
    results = await asyncio.gather(*tasks)
    assert len(results) == num_requests
    assert all(completion is not None for completion in results)

    await asyncio.sleep(0.5)

    tasks = [make_request() for _ in range(num_requests)]
    results = await asyncio.gather(*tasks)
    assert len(results) == num_requests
    assert all(completion is not None for completion in results)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
                                    model_name: str) -> None:
    prompt = "What is an LLM?"

    async def make_streaming_request():
        # Perform a non-streaming request to get the expected full output
        single_completion = await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
        )
        single_output = single_completion.choices[0].text

        # Perform the streaming request
        stream = await client.completions.create(model=model_name,
                                                 prompt=prompt,
                                                 max_tokens=5,
                                                 temperature=0.0,
                                                 stream=True)
        chunks: list[str] = []
        finish_reason_count = 0
        last_chunk = None
        async for chunk in stream:
            chunks.append(chunk.choices[0].text)
            if chunk.choices[0].finish_reason is not None:
                finish_reason_count += 1
            last_chunk = chunk  # Keep track of the last chunk

        # finish reason should only return in the last block for OpenAI API
        assert finish_reason_count == 1, (
            "Finish reason should appear exactly once.")
        assert last_chunk is not None, (
            "Stream should have yielded at least one chunk.")
        assert last_chunk.choices[
            0].finish_reason == "length", "Finish reason should be 'length'."
        # Check that the combined text matches the non-streamed version.
        assert "".join(
            chunks
        ) == single_output, "Streamed output should match non-streamed output."
        return True  # Indicate success for this request

    # Test single request
    result = await make_streaming_request()
    assert result is not None

    await asyncio.sleep(0.5)

    # Send two bursts of requests
    num_requests = 100
    tasks = [make_streaming_request() for _ in range(num_requests)]
    results = await asyncio.gather(*tasks)

    assert len(
        results
    ) == num_requests, f"Expected {num_requests} results, got {len(results)}"
    assert all(results), "Not all streaming requests completed successfully."

    await asyncio.sleep(0.5)

    tasks = [make_streaming_request() for _ in range(num_requests)]
    results = await asyncio.gather(*tasks)

    assert len(
        results
    ) == num_requests, f"Expected {num_requests} results, got {len(results)}"
    assert all(results), "Not all streaming requests completed successfully."
@@ -43,7 +43,7 @@ def test_basic_lifecycle():
     # Ensure the request is finished after 1 tokens.
     assert request.is_finished()
     assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
-    output = engine_core_outputs.outputs[0]
+    output = engine_core_outputs[0].outputs[0]
     assert output.finish_reason == FinishReason.LENGTH
     assert output.kv_transfer_params is not None

@@ -165,7 +165,7 @@ def test_prefix_cache_lifecycle():
     scheduler_output = scheduler.schedule()
     model_runner_output = create_model_runner_output(reqs=[request_remote])
     eco = scheduler.update_from_output(scheduler_output, model_runner_output)
-    kv_transfer_params = eco.outputs[0].kv_transfer_params
+    kv_transfer_params = eco[0].outputs[0].kv_transfer_params

     # Ensure we send all block ids, even if there is a cache hit.
     assert (len(
@@ -61,7 +61,7 @@ def test_basic_lifecycle():
     # (1c): update_from_output()
     engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                        model_runner_output)
-    assert len(engine_core_outputs.outputs) == 0
+    assert not engine_core_outputs or not engine_core_outputs[0].outputs

     # STEP (2):
     # (2a): schedule(): nothing happens!
@@ -112,7 +112,7 @@ def test_basic_lifecycle():
                                                        model_runner_output)
     scheduler.schedule()

-    outputs = engine_core_outputs.outputs
+    outputs = engine_core_outputs[0].outputs
     assert len(outputs) == 1
     output = outputs[0]
     assert output.finish_reason == FinishReason.STOP
@@ -335,7 +335,7 @@ def test_full_block_prompt():
                                                        model_runner_output)
     scheduler.schedule()

-    outputs = engine_core_outputs.outputs
+    outputs = engine_core_outputs[0].outputs
     assert len(outputs) == 1
     output = outputs[0]
     assert output.finish_reason == FinishReason.STOP
@@ -153,7 +153,6 @@ def create_request(
         multi_modal_placeholders=None,
         multi_modal_hashes=None,
         eos_token_id=EOS_TOKEN_ID,
-        arrival_time=0,
     )
     req.kv_transfer_params = kv_transfer_params
     return req
@@ -1,24 +1,35 @@
 # SPDX-License-Identifier: Apache-2.0

 import argparse
+import os
 import signal
+import sys

 import uvloop
+import zmq

 import vllm.envs as envs
 from vllm import AsyncEngineArgs
 from vllm.entrypoints.cli.types import CLISubcommand
-from vllm.entrypoints.openai.api_server import run_server
+from vllm.entrypoints.openai.api_server import (run_server, run_server_worker,
+                                                setup_server)
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                               validate_parsed_serve_args)
 from vllm.entrypoints.utils import (VLLM_SERVE_PARSER_EPILOG,
                                     show_filtered_argument_or_group_from_help)
+from vllm.executor.multiproc_worker_utils import _add_prefix
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser, get_tcp_uri
+from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx
+from vllm.v1.engine.coordinator import DPCoordinator
 from vllm.v1.engine.core import EngineCoreProc
 from vllm.v1.engine.core_client import CoreEngineProcManager
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
+from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
+                           EngineZmqAddresses, get_engine_client_zmq_addr,
+                           wait_for_completion_or_failure,
+                           wait_for_engine_startup)

 logger = init_logger(__name__)

@@ -36,9 +47,12 @@ class ServeSubcommand(CLISubcommand):
         if hasattr(args, 'model_tag') and args.model_tag is not None:
             args.model = args.model_tag

-        if args.headless:
+        if args.headless or args.api_server_count < 1:
             run_headless(args)
+        elif args.api_server_count > 1:
+            run_multi_api_server(args)
         else:
+            # Single API server (this process).
             uvloop.run(run_server(args))

     def validate(self, args: argparse.Namespace) -> None:
@@ -69,6 +83,11 @@ class ServeSubcommand(CLISubcommand):
                                   type=int,
                                   default=0,
                                   help='Starting data parallel rank for secondary nodes.')
+        serve_parser.add_argument('--api-server-count',
+                                  '-asc',
+                                  type=int,
+                                  default=1,
+                                  help='How many API server processes to run.')
         serve_parser.add_argument(
             "--config",
             type=str,
@@ -91,22 +110,25 @@ def cmd_init() -> list[CLISubcommand]:

 def run_headless(args: argparse.Namespace):

+    if args.api_server_count > 1:
+        raise ValueError("api_server_count can't be set in headless mode")
+
     # Create the EngineConfig.
     engine_args = AsyncEngineArgs.from_cli_args(args)
     usage_context = UsageContext.OPENAI_API_SERVER
     vllm_config = engine_args.create_engine_config(usage_context=usage_context)

     if not envs.VLLM_USE_V1:
-        raise RuntimeError("Headless mode is only supported for V1")
+        raise ValueError("Headless mode is only supported for V1")

     parallel_config = vllm_config.parallel_config
     local_engine_count = parallel_config.data_parallel_size_local
     host = parallel_config.data_parallel_master_ip
     port = engine_args.data_parallel_rpc_port  # add to config too
-    input_address = get_tcp_uri(host, port)
+    handshake_address = get_tcp_uri(host, port)

     if local_engine_count <= 0:
-        raise RuntimeError("data_parallel_size_local must be > 0 in "
+        raise ValueError("data_parallel_size_local must be > 0 in "
                          "headless mode")

     # Catch SIGTERM and SIGINT to allow graceful shutdown.
@@ -119,7 +141,7 @@ def run_headless(args: argparse.Namespace):

     logger.info(
         "Launching %d data parallel engine(s) in headless mode, "
-        "with head node address %s.", local_engine_count, input_address)
+        "with head node address %s.", local_engine_count, handshake_address)

     # Create the engines.
     engine_manager = CoreEngineProcManager(
@@ -129,7 +151,7 @@ def run_headless(args: argparse.Namespace):
         local_start_index=0,
         vllm_config=vllm_config,
         on_head_node=False,
-        input_address=input_address,
+        handshake_address=handshake_address,
         executor_class=Executor.get_class(vllm_config),
         log_stats=not engine_args.disable_log_stats,
     )
@@ -139,3 +161,142 @@ def run_headless(args: argparse.Namespace):
     finally:
         logger.info("Shutting down.")
         engine_manager.close()
+
+
+def run_multi_api_server(args: argparse.Namespace):
+
+    assert not args.headless
+    num_api_servers = args.api_server_count
+    assert num_api_servers > 0
+
+    if num_api_servers > 1:
+        setup_multiprocess_prometheus()
+
+    listen_address, sock = setup_server(args)
+
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    usage_context = UsageContext.OPENAI_API_SERVER
+    vllm_config = engine_args.create_engine_config(usage_context=usage_context)
+    model_config = vllm_config.model_config
+
+    if num_api_servers > 1:
+        if not envs.VLLM_USE_V1:
+            raise ValueError("api_server_count > 1 is only supported for V1")
+
+        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
+            raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
+                             "with api_server_count > 1")
+
+        if model_config.is_multimodal_model and not (
+                model_config.disable_mm_preprocessor_cache):
+            logger.warning(
+                "Multi-model preprocessor cache will be disabled for"
+                " api_server_count > 1")
+            model_config.disable_mm_preprocessor_cache = True
+
+    parallel_config = vllm_config.parallel_config
+
+    assert parallel_config.data_parallel_rank == 0
+
+    dp_size = parallel_config.data_parallel_size
+    local_engine_count = parallel_config.data_parallel_size_local
+    host = parallel_config.data_parallel_master_ip
+    local_only = local_engine_count == dp_size
+
+    # Set up input and output addresses.
+    input_addresses = [
+        get_engine_client_zmq_addr(local_only, host)
+        for _ in range(num_api_servers)
+    ]
+    output_addresses = [
+        get_engine_client_zmq_addr(local_only, host)
+        for _ in range(num_api_servers)
+    ]
+
+    addresses = EngineZmqAddresses(
+        inputs=input_addresses,
+        outputs=output_addresses,
+    )
+
+    # Set up coordinator for dp > 1.
+    coordinator = None
+    stats_update_address = None
+    if dp_size > 1:
+        coordinator = DPCoordinator(parallel_config)
+        addresses.coordinator_input, addresses.coordinator_output = (
+            coordinator.get_engine_socket_addresses())
+        stats_update_address = coordinator.get_stats_publish_address()
+        logger.info("Started DP Coordinator process (PID: %d)",
+                    coordinator.proc.pid)
+
+    handshake_address = get_engine_client_zmq_addr(
+        local_only, host, parallel_config.data_parallel_rpc_port)
+
+    with zmq_socket_ctx(handshake_address, zmq.ROUTER,
+                        bind=True) as handshake_socket:
+
+        # Start local engines.
+        if not local_engine_count:
+            local_engine_manager = None
+        else:
+            local_engine_manager = CoreEngineProcManager(
+                EngineCoreProc.run_engine_core,
+                vllm_config=vllm_config,
+                executor_class=Executor.get_class(vllm_config),
+                log_stats=not engine_args.disable_log_stats,
+                handshake_address=handshake_address,
+                on_head_node=True,
+                local_engine_count=local_engine_count,
+                start_index=0,
+                local_start_index=0)
+
+        # Start API servers using the manager
+        api_server_manager = APIServerProcessManager(
+            target_server_fn=run_api_server_worker_proc,
+            listen_address=listen_address,
+            sock=sock,
+            args=args,
+            num_servers=num_api_servers,
+            input_addresses=input_addresses,
+            output_addresses=output_addresses,
+            stats_update_address=stats_update_address)
+
+        # Wait for engine handshakes to complete.
+        core_engines = [
+            CoreEngine(index=i, local=(i < local_engine_count))
+            for i in range(dp_size)
+        ]
+        wait_for_engine_startup(
+            handshake_socket,
+            addresses,
+            core_engines,
+            parallel_config,
+            vllm_config.cache_config,
+            local_engine_manager,
+            coordinator.proc if coordinator else None,
+        )
+
+        # Wait for API servers
+        wait_for_completion_or_failure(
+            api_server_manager=api_server_manager,
+            local_engine_manager=local_engine_manager,
+            coordinator=coordinator)
+
+
+def run_api_server_worker_proc(listen_address,
+                               sock,
+                               args,
+                               client_config=None,
+                               **uvicorn_kwargs) -> None:
+    """Entrypoint for individual API server worker processes."""
+
+    # Add process-specific prefix to stdout and stderr.
+    from multiprocessing import current_process
+    process_name = current_process().name
+    pid = os.getpid()
+    _add_prefix(sys.stdout, process_name, pid)
+    _add_prefix(sys.stderr, process_name, pid)
+
+    uvloop.run(
+        run_server_worker(listen_address, sock, args, client_config,
+                          **uvicorn_kwargs))
@@ -17,7 +17,7 @@ from contextlib import asynccontextmanager
 from functools import partial
 from http import HTTPStatus
 from json import JSONDecodeError
-from typing import Annotated, Optional
+from typing import Annotated, Any, Optional

 import prometheus_client
 import regex as re
@@ -26,6 +26,8 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
+from prometheus_client import make_asgi_app
+from prometheus_fastapi_instrumentator import Instrumentator
 from starlette.concurrency import iterate_in_threadpool
 from starlette.datastructures import State
 from starlette.routing import Mount
@@ -97,6 +99,7 @@ from vllm.transformers_utils.tokenizer import MistralTokenizer
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
                         is_valid_ipv6_address, set_ulimit)
+from vllm.v1.metrics.prometheus import get_prometheus_registry
 from vllm.version import __version__ as VLLM_VERSION

 TIMEOUT_KEEP_ALIVE = 5  # seconds
@@ -142,14 +145,17 @@ async def lifespan(app: FastAPI):

 @asynccontextmanager
 async def build_async_engine_client(
-        args: Namespace) -> AsyncIterator[EngineClient]:
+    args: Namespace,
+    client_config: Optional[dict[str, Any]] = None,
+) -> AsyncIterator[EngineClient]:

     # Context manager to handle engine_client lifecycle
     # Ensures everything is shutdown and cleaned up on error/exit
     engine_args = AsyncEngineArgs.from_cli_args(args)

     async with build_async_engine_client_from_engine_args(
-            engine_args, args.disable_frontend_multiprocessing) as engine:
+            engine_args, args.disable_frontend_multiprocessing,
+            client_config) as engine:
         yield engine


@@ -157,6 +163,7 @@ async def build_async_engine_client(
 async def build_async_engine_client_from_engine_args(
     engine_args: AsyncEngineArgs,
     disable_frontend_multiprocessing: bool = False,
+    client_config: Optional[dict[str, Any]] = None,
 ) -> AsyncIterator[EngineClient]:
     """
     Create EngineClient, either:
@@ -179,12 +186,16 @@ async def build_async_engine_client_from_engine_args(

         from vllm.v1.engine.async_llm import AsyncLLM
         async_llm: Optional[AsyncLLM] = None
+        client_index = client_config.pop(
+            "client_index") if client_config else 0
         try:
             async_llm = AsyncLLM.from_vllm_config(
                 vllm_config=vllm_config,
                 usage_context=usage_context,
                 disable_log_requests=engine_args.disable_log_requests,
-                disable_log_stats=engine_args.disable_log_stats)
+                disable_log_stats=engine_args.disable_log_stats,
+                client_addresses=client_config,
+                client_index=client_index)

             # Don't keep the dummy data in memory
             await async_llm.reset_mm_cache()
@@ -318,22 +329,9 @@ class PrometheusResponse(Response):


 def mount_metrics(app: FastAPI):
-    # Lazy import for prometheus multiprocessing.
-    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
-    # before prometheus_client is imported.
-    # See https://prometheus.github.io/client_python/multiprocess/
-    from prometheus_client import (REGISTRY, CollectorRegistry, make_asgi_app,
-                                   multiprocess)
-    from prometheus_fastapi_instrumentator import Instrumentator
+    """Mount prometheus metrics to a FastAPI app."""

-    registry = REGISTRY
-
-    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
-    if prometheus_multiproc_dir_path is not None:
-        logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
-                     prometheus_multiproc_dir_path)
-        registry = CollectorRegistry()
-        multiprocess.MultiProcessCollector(registry)
+    registry = get_prometheus_registry()

     # `response_class=PrometheusResponse` is needed to return an HTTP response
     # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
@@ -1256,13 +1254,7 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket:
     return sock


-async def run_server(args, **uvicorn_kwargs) -> None:
-    logger.info("vLLM API server version %s", VLLM_VERSION)
-    log_non_default_args(args)
-
-    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
-        ToolParserManager.import_tool_parser(args.tool_parser_plugin)
-
+def validate_api_server_args(args):
     valid_tool_parses = ToolParserManager.tool_parsers.keys()
     if args.enable_auto_tool_choice \
         and args.tool_call_parser not in valid_tool_parses:
@@ -1276,6 +1268,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
             f"invalid reasoning parser: {args.reasoning_parser} "
             f"(chose from {{ {','.join(valid_reasoning_parses)} }})")

+
+def setup_server(args):
+    """Validate API server args, set up signal handler, create socket
+    ready to serve."""
+
+    logger.info("vLLM API server version %s", VLLM_VERSION)
+    log_non_default_args(args)
+
+    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
+        ToolParserManager.import_tool_parser(args.tool_parser_plugin)
+
+    validate_api_server_args(args)
+
     # workaround to make sure that we bind the port before the engine is set up.
     # This avoids race conditions with ray.
     # see https://github.com/vllm-project/vllm/issues/8204
@@ -1292,22 +1297,41 @@ async def run_server(args, **uvicorn_kwargs) -> None:

     signal.signal(signal.SIGTERM, signal_handler)

-    async with build_async_engine_client(args) as engine_client:
+    addr, port = sock_addr
+    is_ssl = args.ssl_keyfile and args.ssl_certfile
+    host_part = f"[{addr}]" if is_valid_ipv6_address(
+        addr) else addr or "0.0.0.0"
+    listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
+
+    return listen_address, sock
+
+
+async def run_server(args, **uvicorn_kwargs) -> None:
+    """Run a single-worker API server."""
+    listen_address, sock = setup_server(args)
+    await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
+
+
+async def run_server_worker(listen_address,
+                            sock,
+                            args,
+                            client_config=None,
+                            **uvicorn_kwargs) -> None:
+    """Run a single API server worker."""
+
+    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
+        ToolParserManager.import_tool_parser(args.tool_parser_plugin)
+
+    server_index = client_config.get("client_index", 0) if client_config else 0
+
+    async with build_async_engine_client(args, client_config) as engine_client:
         app = build_app(args)

         vllm_config = await engine_client.get_vllm_config()
         await init_app_state(engine_client, vllm_config, app.state, args)

-        def _listen_addr(a: str) -> str:
-            if is_valid_ipv6_address(a):
-                return '[' + a + ']'
-            return a or "0.0.0.0"
-
-        is_ssl = args.ssl_keyfile and args.ssl_certfile
-        logger.info("Starting vLLM API server on http%s://%s:%d",
-                    "s" if is_ssl else "", _listen_addr(sock_addr[0]),
-                    sock_addr[1])
+        logger.info("Starting vLLM API server %d on %s", server_index,
+                    listen_address)

         shutdown_task = await serve_http(
             app,
             sock=sock,
@@ -229,6 +229,11 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
             self.add_adapter(lora)

     def add_adapter(self, lora_request: LoRARequest) -> bool:
+        # Note that this method is not thread-safe. It may be invoked multiple
+        # times for the same adapter when using multiple API servers.
+        # This is ok because it's currently only called from
+        # the single-threaded core engine loop.
+
         if lora_request.lora_int_id not in self.list_adapters():
             # Load the new adapter first to ensure it is actually valid, before
             # evicting any existing adapters.
@@ -2420,6 +2420,7 @@ def make_zmq_socket(
     socket_type: Any,
     bind: Optional[bool] = None,
     identity: Optional[bytes] = None,
+    linger: Optional[int] = None,
 ) -> Union[zmq.Socket, zmq.asyncio.Socket]:  # type: ignore[name-defined]
     """Make a ZMQ socket with the proper bind/connect semantics."""

@@ -2439,7 +2440,7 @@ def make_zmq_socket(
         buf_size = -1  # Use system default buffer size

     if bind is None:
-        bind = socket_type != zmq.PUSH
+        bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB)

     if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
         socket.setsockopt(zmq.RCVHWM, 0)
@@ -2452,6 +2453,9 @@ def make_zmq_socket(
     if identity is not None:
         socket.setsockopt(zmq.IDENTITY, identity)

+    if linger is not None:
+        socket.setsockopt(zmq.LINGER, linger)
+
     # Determine if the path is a TCP socket with an IPv6 address.
     # Enable IPv6 on the zmq socket if so.
     scheme, host, _ = split_zmq_path(path)
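
The new linger option above simply forwards to ZMQ's LINGER socket option so a socket can be closed without blocking on undelivered messages. A hedged usage sketch follows; the leading context and path arguments are assumed from the rest of the make_zmq_socket signature (only the tail of it appears in the hunk above), and the address is illustrative.

    import zmq

    from vllm.utils import make_zmq_socket

    ctx = zmq.Context()
    # Drop unsent messages immediately on close (assumed calling convention:
    # context, path, socket type, then keyword options such as linger).
    sock = make_zmq_socket(ctx, "tcp://127.0.0.1:5555", zmq.PUSH, linger=0)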
@ -45,7 +45,7 @@ class SchedulerInterface(ABC):
|
|||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
model_runner_output: "ModelRunnerOutput",
|
model_runner_output: "ModelRunnerOutput",
|
||||||
) -> "EngineCoreOutputs":
|
) -> dict[int, "EngineCoreOutputs"]:
|
||||||
"""Update the scheduler state based on the model runner output.
|
"""Update the scheduler state based on the model runner output.
|
||||||
|
|
||||||
This method is called after the model runner has processed the scheduled
|
This method is called after the model runner has processed the scheduled
|
||||||
@ -55,7 +55,8 @@ class SchedulerInterface(ABC):
|
|||||||
for each request.
|
for each request.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A EngineCoreOutputs object containing the outputs for each request.
|
A dict of client index to EngineCoreOutputs object containing the
|
||||||
|
outputs for each request originating from that client.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -126,6 +127,11 @@ class SchedulerInterface(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_request_counts(self) -> tuple[int, int]:
|
||||||
|
"""Returns (num_running_reqs, num_waiting_reqs)."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def make_stats(self) -> Optional["SchedulerStats"]:
|
def make_stats(self) -> Optional["SchedulerStats"]:
|
||||||
"""Make a SchedulerStats object for logging.
|
"""Make a SchedulerStats object for logging.
|
||||||
|
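To make the new return shape concrete, here is a minimal standalone sketch (toy code, not vLLM API) of the per-client bucketing that a dict[int, EngineCoreOutputs] result implies: outputs are grouped by the client index carried on each request, so every front-end only receives results for the requests it submitted:

    from collections import defaultdict

    # (request_id, client_index) pairs standing in for this step's outputs.
    finished = [("req-0", 0), ("req-1", 1), ("req-2", 0)]

    grouped: dict[int, list[str]] = defaultdict(list)
    for req_id, client_index in finished:
        grouped[client_index].append(req_id)

    # Analogous to {client_index: EngineCoreOutputs(outputs=outs), ...}
    print(dict(grouped))  # {0: ['req-0', 'req-2'], 1: ['req-1']}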
@ -58,7 +58,8 @@ class Scheduler(SchedulerInterface):
         # request ids should be included in the EngineCoreOutputs returned
         # by update_from_outputs(). This is currently used in the multi-engine
         # case to track request lifetimes efficiently.
-        self.include_finished_set = include_finished_set
+        self.finished_req_ids_dict: Optional[dict[int, set[str]]] = (
+            defaultdict(set) if include_finished_set else None)

         # Scheduling constraints.
         self.max_num_running_reqs = self.scheduler_config.max_num_seqs

@ -693,7 +694,7 @@ class Scheduler(SchedulerInterface):
         self,
         scheduler_output: SchedulerOutput,
         model_runner_output: ModelRunnerOutput,
-    ) -> EngineCoreOutputs:
+    ) -> dict[int, EngineCoreOutputs]:
         sampled_token_ids = model_runner_output.sampled_token_ids
         spec_token_ids = model_runner_output.spec_token_ids
         logprobs = model_runner_output.logprobs

@ -701,7 +702,7 @@ class Scheduler(SchedulerInterface):
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens

         new_running: list[Request] = []
-        outputs: list[EngineCoreOutput] = []
+        outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
         spec_decoding_stats: Optional[SpecDecodingStats] = None

         # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below

@ -797,7 +798,7 @@ class Scheduler(SchedulerInterface):
             if new_token_ids or kv_transfer_params:

                 # Add EngineCoreOutput for this Request.
-                outputs.append(
+                outputs[request.client_index].append(
                     EngineCoreOutput(
                         request_id=req_id,
                         new_token_ids=new_token_ids,

@ -828,17 +829,38 @@ class Scheduler(SchedulerInterface):
             self._cached_reqs_data[req_data.req_id].append(req_data)

         self.running = new_running
-        engine_core_outputs = EngineCoreOutputs(
-            outputs=outputs,
-            scheduler_stats=self.make_stats(spec_decoding_stats),
-        )
-        if self.include_finished_set:
-            #TODO currently sending duplicates here, improve this
-            engine_core_outputs.finished_requests = (
-                scheduler_output.finished_req_ids | self.finished_req_ids)
+
+        # Create EngineCoreOutputs for all clients that have requests with
+        # outputs in this step.
+        engine_core_outputs = {
+            client_index: EngineCoreOutputs(outputs=outs)
+            for client_index, outs in outputs.items()
+        }
+
+        finished_req_ids = self.finished_req_ids_dict
+        if finished_req_ids is not None:
+            # Include ids of requests that finished since last outputs
+            # were sent.
+            for client_index, finished_set in finished_req_ids.items():
+                # Set finished request set in EngineCoreOutputs for this client.
+                if (eco := engine_core_outputs.get(client_index)) is not None:
+                    eco.finished_requests = finished_set
+                else:
+                    engine_core_outputs[client_index] = EngineCoreOutputs(
+                        finished_requests=finished_set)
+            finished_req_ids.clear()
+
+        if engine_core_outputs:
+            # Return stats to only one of the front-ends.
+            next(iter(engine_core_outputs.values())).scheduler_stats = (
+                self.make_stats(spec_decoding_stats))

         return engine_core_outputs

+    def get_request_counts(self) -> tuple[int, int]:
+        """Returns (num_running_reqs, num_waiting_reqs)."""
+        return len(self.running), len(self.waiting)
+
     def add_request(self, request: Request) -> None:
         self.waiting.append(request)
         self.requests[request.request_id] = request
@ -880,8 +902,11 @@ class Scheduler(SchedulerInterface):

         delay_free_blocks, kv_xfer_params = self._connector_finished(request)
         self.encoder_cache_manager.free(request)
-        self._cached_reqs_data.pop(request.request_id, None)
-        self.finished_req_ids.add(request.request_id)
+        request_id = request.request_id
+        self._cached_reqs_data.pop(request_id, None)
+        self.finished_req_ids.add(request_id)
+        if self.finished_req_ids_dict is not None:
+            self.finished_req_ids_dict[request.client_index].add(request_id)

         if not delay_free_blocks:
             self._free_blocks(request)
@ -44,10 +44,6 @@ class EngineCoreRequest(
         omit_defaults=True,  # type: ignore[call-arg]
         gc=False):  # type: ignore[call-arg]

-    # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
-    # but this object is currently not playing well with msgspec
-    # due to circular imports and typing we have in data.py
-
     request_id: str
     prompt_token_ids: list[int]
     mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]

@ -59,6 +55,10 @@ class EngineCoreRequest(
     lora_request: Optional[LoRARequest]
     cache_salt: Optional[str]

+    # Index of the client, used to ensure outputs are sent back to the same
+    # client for this request when scaling out the front-end.
+    client_index: int = 0
+
     # Used in DP case to indicate which wave of requests this is expected to
     # belong to, to cover a race condition where the request is sent before
     # a wave finished notification is received.
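A sketch of how the new client_index field is meant to be used downstream (toy code, not vLLM API): the engine side keeps one output channel per front-end process and picks the channel by client index, so a request tagged by server N flows back to server N:

    import queue

    # Hypothetical per-client output channels (queues stand in for ZMQ sockets).
    output_channels = [queue.Queue(), queue.Queue()]  # two API servers

    def emit(client_index: int, payload: str) -> None:
        # Mirrors the sockets[client_index].send_multipart(...) pattern used
        # later in the engine process.
        output_channels[client_index].put(payload)

    emit(1, "tokens for a request submitted via API server 1")
    print(output_channels[1].get())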
@ -36,6 +36,7 @@ from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
                                      setup_default_loggers)
+from vllm.v1.metrics.prometheus import shutdown_prometheus
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats

 logger = init_logger(__name__)

@ -54,6 +55,8 @@ class AsyncLLM(EngineClient):
         log_requests: bool = True,
         start_engine_loop: bool = True,
         stat_loggers: Optional[list[StatLoggerFactory]] = None,
+        client_addresses: Optional[dict[str, str]] = None,
+        client_index: int = 0,
     ) -> None:
         """
         Create an AsyncLLM.

@ -124,6 +127,8 @@ class AsyncLLM(EngineClient):
             vllm_config=vllm_config,
             executor_class=executor_class,
             log_stats=self.log_stats,
+            client_addresses=client_addresses,
+            client_index=client_index,
         )
         if self.stat_loggers:
             for stat_logger in self.stat_loggers[0]:

@ -145,6 +150,8 @@ class AsyncLLM(EngineClient):
         stat_loggers: Optional[list[StatLoggerFactory]] = None,
         disable_log_requests: bool = False,
         disable_log_stats: bool = False,
+        client_addresses: Optional[dict[str, str]] = None,
+        client_index: int = 0,
     ) -> "AsyncLLM":
         if not envs.VLLM_USE_V1:
             raise ValueError(

@ -162,6 +169,8 @@ class AsyncLLM(EngineClient):
             log_requests=not disable_log_requests,
             log_stats=not disable_log_stats,
             usage_context=usage_context,
+            client_addresses=client_addresses,
+            client_index=client_index,
         )

     @classmethod

@ -195,6 +204,8 @@ class AsyncLLM(EngineClient):
     def shutdown(self):
         """Shutdown, cleaning up the background proc and IPC."""

+        shutdown_prometheus()
+
         if engine_core := getattr(self, "engine_core", None):
             engine_core.shutdown()

@ -398,7 +409,6 @@ class AsyncLLM(EngineClient):
                 # TODO(rob): make into a coroutine and launch it in
                 # background thread once Prometheus overhead is non-trivial.
                 if stat_loggers:
-                    assert outputs.scheduler_stats is not None
                     AsyncLLM._record_stats(
                         stat_loggers[outputs.engine_index],
                         scheduler_stats=outputs.scheduler_stats,

@ -422,7 +432,7 @@ class AsyncLLM(EngineClient):
     @staticmethod
     def _record_stats(
         stat_loggers: list[StatLoggerBase],
-        scheduler_stats: SchedulerStats,
+        scheduler_stats: Optional[SchedulerStats],
         iteration_stats: Optional[IterationStats],
     ):
         """static so that it can be used from the output_handler task
252 vllm/v1/engine/coordinator.py Normal file
@ -0,0 +1,252 @@
# SPDX-License-Identifier: Apache-2.0
import multiprocessing
import time
import weakref
from typing import Optional

import msgspec.msgpack
import zmq

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
from vllm.v1.serial_utils import MsgpackDecoder
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown

logger = init_logger(__name__)


class DPCoordinator:
    """Coordinator process used for data-parallel deployments (DP>1).

    Intermediates between multiple DP engine rank processes and one or more
    front-end API server processes.

    * Collects stats from each DP engine (currently just waiting and running
      queue lengths), and publishes these to all front-ends for use in
      load-balancing decisions.

    * Keeps track of the current DP "request wave" number and running state
      of the engines. This is received from the DP rank 0 engine and published
      to the front-end processes along with the current load stats.

      The engines alternate between a global running/paused state. The global
      "request wave" number is a count of the number of times that the workers
      collectively move from a running state to a paused state. This transition
      is synchronized via the all-reduce operation performed in the
      DPEngineCoreProc._has_global_unfinished_reqs method.

    * Broadcasts the START_DP_WAVE message to engines to move them from paused
      to running state when one engine receives a new request. This can happen
      in two cases:
      1) A front-end sending a new request while the engines are paused will
         concurrently notify the coordinator.
      2) An engine receiving a request for a stale request wave while in paused
         state will notify the coordinator.

    Engines will move into running state when receiving a new request or
    START_DP_WAVE message.
    """

    def __init__(self, parallel_config: ParallelConfig):

        # Assume coordinator is colocated with front-end procs.
        front_publish_address = get_open_zmq_ipc_path()

        dp_size = parallel_config.data_parallel_size
        assert dp_size > 1, "Coordinator only used for data parallel"

        local_only = dp_size == parallel_config.data_parallel_size_local
        host = parallel_config.data_parallel_master_ip
        back_publish_address = get_engine_client_zmq_addr(local_only, host)
        back_output_address = get_engine_client_zmq_addr(local_only, host)

        context = get_mp_context()
        self.proc: multiprocessing.Process = context.Process(
            target=CoordinatorProc.run_coordinator,
            name="VLLM_DP_Coordinator",
            kwargs={
                "engine_count": parallel_config.data_parallel_size,
                "front_publish_address": front_publish_address,
                "back_output_address": back_output_address,
                "back_publish_address": back_publish_address,
            },
            daemon=True)
        self.proc.start()

        self.stats_publish_address = front_publish_address
        self.coord_in_address = back_publish_address
        self.coord_out_address = back_output_address
        self._finalizer = weakref.finalize(self, shutdown, [self.proc])

    def get_stats_publish_address(self) -> str:
        return self.stats_publish_address

    def get_engine_socket_addresses(self) -> tuple[str, str]:
        """Returns tuple of ZMQ input address, output address."""
        return self.coord_in_address, self.coord_out_address

    def close(self):
        self._finalizer()


class EngineState:

    def __init__(self):
        self.request_counts = [0, 0]  # [waiting, running]


class CoordinatorProc:

    def __init__(self, engine_count: int):

        self.ctx = zmq.Context()

        self.engines = [EngineState() for _ in range(engine_count)]

        self.current_wave = 0
        self.engines_running = False
        self.stats_changed = False

    @staticmethod
    def run_coordinator(
        engine_count: int,
        front_publish_address: str,
        back_output_address: str,
        back_publish_address: str,
    ):
        coordinator = CoordinatorProc(engine_count=engine_count)
        try:
            coordinator.process_input_socket(
                front_publish_address,
                back_output_address,
                back_publish_address,
            )
        except KeyboardInterrupt:
            logger.info("DP Coordinator process exiting")

    def process_input_socket(self, front_publish_address: str,
                             back_output_address: str,
                             back_publish_address: str):

        decoder = MsgpackDecoder(EngineCoreOutputs)

        with make_zmq_socket(
                path=front_publish_address,  # IPC
                ctx=self.ctx,
                socket_type=zmq.XPUB,
                bind=True,
        ) as publish_front, make_zmq_socket(
                path=back_output_address,  # IPC or TCP
                ctx=self.ctx,
                socket_type=zmq.PULL,
                bind=True,
        ) as output_back, make_zmq_socket(
                path=back_publish_address,  # IPC or TCP
                ctx=self.ctx,
                socket_type=zmq.XPUB,
                bind=True,
        ) as publish_back:

            poller = zmq.Poller()
            poller.register(publish_front, zmq.POLLIN)
            poller.register(output_back, zmq.POLLIN)
            last_publish_time = 0
            while True:
                elapsed = int(time.time() * 1000) - last_publish_time
                # Send at 100 ms interval if the stats have changed,
                # or otherwise every 3 seconds.
                wait_for = 100 if self.stats_changed else 3000
                events = poller.poll(timeout=max(0, wait_for - elapsed))
                if not events:
                    # Poller timeout - publish current stats to front-ends.
                    engine_req_counts_list = self._get_engine_counts()
                    to_publish = (engine_req_counts_list, self.current_wave,
                                  self.engines_running)
                    publish_front.send(msgspec.msgpack.encode(to_publish))
                    last_publish_time = int(time.time() * 1000)
                    self.stats_changed = False
                    continue

                events = dict(events)

                if publish_front in events:
                    buffer = publish_front.recv()
                    if buffer == b'\x01':
                        # Ignore subscription messages.
                        continue

                    # We received a message on the front-end XPUB socket,
                    # from an API server sending a new request while the
                    # engines are paused, so that we can wake the other
                    # engines.
                    engine_to_exclude, wave = msgspec.msgpack.decode(buffer)
                    if wave < self.current_wave:
                        # If the wave number is stale, ensure the message is
                        # handled by all the engines.
                        engine_to_exclude = None
                    if not self.engines_running:
                        self.engines_running = True
                        self.stats_changed = True
                        self._send_start_wave(publish_back, self.current_wave,
                                              engine_to_exclude)

                if output_back in events:
                    # We received a message from one of the engines.

                    buffer = output_back.recv()
                    outputs: EngineCoreOutputs = decoder.decode(buffer)

                    assert not outputs.outputs
                    assert outputs.utility_output is None

                    eng_index = outputs.engine_index
                    if outputs.scheduler_stats:
                        # 1. Updated request load stats - update our local
                        # state with these.
                        stats = self.engines[eng_index].request_counts
                        stats[0] = outputs.scheduler_stats.num_waiting_reqs
                        stats[1] = outputs.scheduler_stats.num_running_reqs
                        self.stats_changed = True

                    if (wave := outputs.wave_complete) is not None:
                        # 2. Notification from rank 0 engine that we've
                        # moved into the global paused state
                        # (engines_running==False)
                        if self.current_wave <= wave:
                            logger.debug("Moving DP wave from %d to %d.",
                                         self.current_wave, wave)
                            self.current_wave = wave + 1
                            self.engines_running = False
                            self.stats_changed = True
                    elif (wave := outputs.start_wave) is not None and (
                            wave > self.current_wave or
                            (wave == self.current_wave
                             and not self.engines_running)):
                        # 3. The engine received request for a non-current wave
                        # so we must ensure that other engines progress to the
                        # next wave (race condition handling).
                        logger.debug(
                            "Starting wave %d after notification of "
                            "stale wave request from engine.", wave)
                        self.current_wave = wave
                        self.engines_running = True
                        self.stats_changed = True
                        self._send_start_wave(publish_back, wave, eng_index)

    @staticmethod
    def _send_start_wave(socket: zmq.Socket, wave: int,
                         exclude_engine_index: Optional[int]):
        """Broadcast the START_DP_WAVE message to all the engines.
        It includes the current wave number and index of engine which
        has already received a request with this wave number and so doesn't
        require additional notification.
        """
        wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index))
        socket.send_multipart(
            (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))

    def _get_engine_counts(self) -> list[list[int]]:
        """Return list of [waiting, running] count lists for each engine."""
        return [e.request_counts for e in self.engines]
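For orientation, the coordinator's front-end publications can be consumed with an ordinary SUB socket; here is a rough standalone sketch (illustrative only, with a placeholder address - this is not how vLLM's front-end actually wires it up) of decoding the (request_counts, current_wave, engines_running) tuples it publishes:

    import msgspec.msgpack
    import zmq

    ctx = zmq.Context()
    sub = ctx.socket(zmq.SUB)
    sub.setsockopt(zmq.SUBSCRIBE, b"")               # receive all publications
    sub.connect("ipc:///tmp/dp-coordinator-stats")   # placeholder address

    counts, wave, running = msgspec.msgpack.decode(sub.recv())
    # counts is a per-engine list of [waiting, running] queue lengths.
    print(counts, wave, running)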
@ -7,6 +7,7 @@ import threading
 import time
 from collections import deque
 from concurrent.futures import Future
+from contextlib import ExitStack
 from inspect import isclass, signature
 from logging import DEBUG
 from typing import Any, Callable, Optional, TypeVar, Union

@ -22,7 +23,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
-from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx
+from vllm.utils import make_zmq_socket, resolve_obj_by_qualname
 from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                          unify_kv_cache_configs)
 from vllm.v1.core.sched.interface import SchedulerInterface

@ -33,10 +34,12 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
 from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger(__name__)

@ -211,7 +214,7 @@ class EngineCore:
             # Re-raise exception
             raise err

-    def step(self) -> tuple[EngineCoreOutputs, bool]:
+    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         """Schedule, execute, and make output.

         Returns tuple of outputs and a flag indicating whether the model

@ -221,10 +224,7 @@ class EngineCore:
         # Check for any requests remaining in the scheduler - unfinished,
         # or finished and not yet removed from the batch.
         if not self.scheduler.has_requests():
-            return EngineCoreOutputs(
-                outputs=[],
-                scheduler_stats=self.scheduler.make_stats(),
-            ), False
+            return {}, False
         scheduler_output = self.scheduler.schedule()
         model_output = self.execute_model(scheduler_output)
         engine_core_outputs = self.scheduler.update_from_output(

@ -234,7 +234,7 @@ class EngineCore:
             scheduler_output.total_num_scheduled_tokens > 0)

     def step_with_batch_queue(
-            self) -> tuple[Optional[EngineCoreOutputs], bool]:
+            self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
         """Schedule and execute batches with the batch queue.
         Note that if nothing to output in this step, None is returned.

@ -276,8 +276,8 @@ class EngineCore:
             # Blocking until the first result is available.
             model_output = future.result()
             self.batch_queue.task_done()
-            engine_core_outputs = self.scheduler.update_from_output(
-                scheduler_output, model_output)
+            engine_core_outputs = (self.scheduler.update_from_output(
+                scheduler_output, model_output))

         return engine_core_outputs, scheduled_batch

@ -362,7 +362,7 @@ class EngineCoreProc(EngineCore):
         self,
         vllm_config: VllmConfig,
         on_head_node: bool,
-        input_address: str,
+        handshake_address: str,
         executor_class: type[Executor],
         log_stats: bool,
         engine_index: int = 0,

@ -375,33 +375,38 @@ class EngineCoreProc(EngineCore):
         # Create input socket.
         input_ctx = zmq.Context()
         identity = engine_index.to_bytes(length=2, byteorder="little")
-        input_socket = make_zmq_socket(input_ctx,
-                                       input_address,
-                                       zmq.DEALER,
-                                       identity=identity,
-                                       bind=False)
-        try:
+        with make_zmq_socket(input_ctx,
+                             handshake_address,
+                             zmq.DEALER,
+                             identity=identity,
+                             linger=5000,
+                             bind=False) as handshake_socket:
+
             # Register engine with front-end.
-            output_address = self.startup_handshake(
-                input_socket, on_head_node, vllm_config.parallel_config)
+            addresses = self.startup_handshake(handshake_socket, on_head_node,
+                                               vllm_config.parallel_config)
+            self.client_count = len(addresses.outputs)

             # Update config which may have changed from the handshake.
             vllm_config.__post_init__()

             # Set up data parallel environment.
+            self.has_coordinator = addresses.coordinator_output is not None
             self._init_data_parallel(vllm_config)

             # Initialize engine core and model.
             super().__init__(vllm_config, executor_class, log_stats,
                              executor_fail_callback)

+            self.engine_index = engine_index
             self.step_fn = (self.step if self.batch_queue is None else
                             self.step_with_batch_queue)
             self.engines_running = False
+            self.last_counts = (0, 0)

             # Send ready message.
             num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
-            input_socket.send(
+            handshake_socket.send(
                 msgspec.msgpack.encode({
                     "status": "READY",
                     "local": on_head_node,

@ -414,26 +419,26 @@ class EngineCoreProc(EngineCore):
             # model forward pass.
             # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
             self.input_queue = input_queue
-            self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
-            threading.Thread(target=self.process_input_socket,
-                             args=(input_socket, ),
+            self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
+                                                  bytes]]()
+            threading.Thread(target=self.process_input_sockets,
+                             args=(addresses.inputs, addresses.coordinator_input,
+                                   identity),
                              daemon=True).start()
-            input_socket = None
             self.output_thread = threading.Thread(
-                target=self.process_output_socket,
-                args=(output_address, engine_index),
+                target=self.process_output_sockets,
+                args=(addresses.outputs, addresses.coordinator_output,
+                      engine_index),
                 daemon=True)
             self.output_thread.start()
-        finally:
-            if input_socket is not None:
-                input_socket.close(linger=0)

     @staticmethod
-    def startup_handshake(input_socket: zmq.Socket, on_head_node: bool,
-                          parallel_config: ParallelConfig) -> str:
+    def startup_handshake(
+            handshake_socket: zmq.Socket, on_head_node: bool,
+            parallel_config: ParallelConfig) -> EngineZmqAddresses:
+
         # Send registration message.
-        input_socket.send(
+        handshake_socket.send(
             msgspec.msgpack.encode({
                 "status": "HELLO",
                 "local": on_head_node,

@ -441,22 +446,20 @@ class EngineCoreProc(EngineCore):

         # Receive initialization message.
         logger.info("Waiting for init message from front-end.")
-        if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000):
+        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
             raise RuntimeError("Did not receive response from front-end "
                                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                                f"minutes")
-        init_bytes = input_socket.recv()
-        init_message = msgspec.msgpack.decode(init_bytes)
+        init_bytes = handshake_socket.recv()
+        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
+            init_bytes, type=EngineHandshakeMetadata)
         logger.debug("Received init message: %s", init_message)

-        output_socket_address = init_message["output_socket_address"]
-        #TBD(nick) maybe replace IP with configured head node address
-
-        received_parallel_config = init_message["parallel_config"]
+        received_parallel_config = init_message.parallel_config
         for key, value in received_parallel_config.items():
             setattr(parallel_config, key, value)

-        return output_socket_address
+        return init_message.addresses

     @staticmethod
     def run_engine_core(*args,
@ -528,7 +531,7 @@ class EngineCoreProc(EngineCore):
         """Exits when an engine step needs to be performed."""

         waited = False
-        while not self.engines_running and not (self.scheduler.has_requests()):
+        while not self.engines_running and not self.scheduler.has_requests():
             if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                 logger.debug("EngineCore waiting for work.")
                 waited = True

@ -549,8 +552,8 @@ class EngineCoreProc(EngineCore):
         # Step the engine core.
         outputs, model_executed = self.step_fn()
         # Put EngineCoreOutputs into the output queue.
-        if outputs is not None:
-            self.output_queue.put_nowait(outputs)
+        for output in (outputs.items() if outputs else ()):
+            self.output_queue.put_nowait(output)

         return model_executed

@ -563,7 +566,7 @@ class EngineCoreProc(EngineCore):
         elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
         elif request_type == EngineCoreRequestType.UTILITY:
-            call_id, method_name, args = request
+            client_idx, call_id, method_name, args = request
             output = UtilityOutput(call_id)
             try:
                 method = getattr(self, method_name)

@ -574,7 +577,7 @@ class EngineCoreProc(EngineCore):
                 output.failure_message = (f"Call to {method_name} method"
                                           f" failed: {str(e)}")
             self.output_queue.put_nowait(
-                EngineCoreOutputs(utility_output=output))
+                (client_idx, EngineCoreOutputs(utility_output=output)))
         elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
             raise RuntimeError("Executor failed.")
         else:
@ -607,27 +610,68 @@ class EngineCoreProc(EngineCore):
             logger.fatal("vLLM shutdown signal from EngineCore failed "
                          "to send. Please report this issue.")

-    def process_input_socket(self, input_socket: zmq.Socket):
+    def process_input_sockets(self, input_addresses: list[str],
+                              coord_input_address: Optional[str],
+                              identity: bytes):
         """Input socket IO thread."""

         # Msgpack serialization decoding.
         add_request_decoder = MsgpackDecoder(EngineCoreRequest)
         generic_decoder = MsgpackDecoder()

+        with ExitStack() as stack, zmq.Context() as ctx:
+            input_sockets = [
+                stack.enter_context(
+                    make_zmq_socket(ctx,
+                                    input_address,
+                                    zmq.DEALER,
+                                    identity=identity,
+                                    bind=False))
+                for input_address in input_addresses
+            ]
+            if coord_input_address is None:
+                coord_socket = None
+            else:
+                coord_socket = stack.enter_context(
+                    make_zmq_socket(ctx,
+                                    coord_input_address,
+                                    zmq.XSUB,
+                                    identity=identity,
+                                    bind=False))
+                # Send subscription message to coordinator.
+                coord_socket.send(b'\x01')
+
+            # Register sockets with poller.
+            poller = zmq.Poller()
+            for input_socket in input_sockets:
+                # Send initial message to each input socket - this is required
+                # before the front-end ROUTER socket can send input messages
+                # back to us.
+                input_socket.send(b'')
+                poller.register(input_socket, zmq.POLLIN)
+            if coord_socket is not None:
+                poller.register(coord_socket, zmq.POLLIN)
+
             while True:
+                for input_socket, _ in poller.poll():
                     # (RequestType, RequestData)
-                type_frame, *data_frames = input_socket.recv_multipart(copy=False)
-                request_type = EngineCoreRequestType(bytes(type_frame.buffer))
+                    type_frame, *data_frames = input_socket.recv_multipart(
+                        copy=False)
+                    request_type = EngineCoreRequestType(
+                        bytes(type_frame.buffer))

                     # Deserialize the request data.
                     decoder = add_request_decoder if (
-                    request_type == EngineCoreRequestType.ADD) else generic_decoder
+                        request_type
+                        == EngineCoreRequestType.ADD) else generic_decoder
                     request = decoder.decode(data_frames)

                     # Push to input queue for core busy loop.
                     self.input_queue.put_nowait((request_type, request))

-    def process_output_socket(self, output_path: str, engine_index: int):
+    def process_output_sockets(self, output_paths: list[str],
+                               coord_output_path: Optional[str],
+                               engine_index: int):
         """Output socket IO thread."""

         # Msgpack serialization encoding.
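The "send an initial empty message" step above follows from how ZeroMQ ROUTER sockets address peers; a minimal standalone illustration (toy code, not vLLM's actual wiring) of the same DEALER/ROUTER handshake pattern:

    import zmq

    ctx = zmq.Context()
    router = ctx.socket(zmq.ROUTER)            # plays the front-end / API server role
    router.bind("inproc://engine-input-demo")

    dealer = ctx.socket(zmq.DEALER)            # plays the engine role
    dealer.setsockopt(zmq.IDENTITY, b"\x00\x00")   # engine index as identity
    dealer.connect("inproc://engine-input-demo")

    dealer.send(b"")                           # announce ourselves to the ROUTER
    identity, _ = router.recv_multipart()      # ROUTER now knows this peer's identity
    router.send_multipart([identity, b"ADD", b"request-bytes"])
    print(dealer.recv_multipart())             # [b'ADD', b'request-bytes']

    dealer.close()
    router.close()
    ctx.term()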
@ -641,30 +685,49 @@ class EngineCoreProc(EngineCore):

         # We must set linger to ensure the ENGINE_CORE_DEAD
         # message is sent prior to closing the socket.
-        with zmq_socket_ctx(output_path, zmq.constants.PUSH,
-                            linger=4000) as socket:
+        with ExitStack() as stack, zmq.Context() as ctx:
+            sockets = [
+                stack.enter_context(
+                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000))
+                for output_path in output_paths
+            ]
+            coord_socket = stack.enter_context(
+                make_zmq_socket(
+                    ctx, coord_output_path, zmq.PUSH, bind=False,
+                    linger=4000)) if coord_output_path is not None else None
+            max_reuse_bufs = len(sockets) + 1

             while True:
-                outputs = self.output_queue.get()
-                if outputs == EngineCoreProc.ENGINE_CORE_DEAD:
-                    socket.send(outputs, copy=False)
+                output = self.output_queue.get()
+                if output == EngineCoreProc.ENGINE_CORE_DEAD:
+                    for socket in sockets:
+                        socket.send(output)
                     break
-                assert not isinstance(outputs, bytes)
+                assert not isinstance(output, bytes)
+                client_index, outputs = output
                 outputs.engine_index = engine_index

+                if client_index == -1:
+                    # Don't reuse buffer for coordinator message
+                    # which will be very small.
+                    assert coord_socket is not None
+                    coord_socket.send_multipart(encoder.encode(outputs))
+                    continue
+
                 # Reclaim buffers that zmq is finished with.
                 while pending and pending[-1][0].done:
                     reuse_buffers.append(pending.pop()[2])

                 buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                 buffers = encoder.encode_into(outputs, buffer)
-                tracker = socket.send_multipart(buffers,
-                                                copy=False,
-                                                track=True)
+                tracker = sockets[client_index].send_multipart(buffers,
                                                                copy=False,
                                                                track=True)
                 if not tracker.done:
                     ref = outputs if len(buffers) > 1 else None
                     pending.appendleft((tracker, ref, buffer))
-                elif len(reuse_buffers) < 2:
-                    # Keep at most 2 buffers to reuse.
+                elif len(reuse_buffers) < max_reuse_bufs:
+                    # Limit the number of buffers to reuse.
                     reuse_buffers.append(buffer)

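The buffer-reuse logic above leans on pyzmq's message tracking; as a rough standalone sketch of that mechanism (toy code, not from this change), send with copy=False and track=True returns a MessageTracker, and the caller must not touch the buffer again until the tracker reports done:

    import zmq

    ctx = zmq.Context()
    push = ctx.socket(zmq.PUSH)
    pull = ctx.socket(zmq.PULL)
    push.bind("inproc://outputs-demo")
    pull.connect("inproc://outputs-demo")

    buffer = bytearray(b"encoded-engine-outputs")
    tracker = push.send(buffer, copy=False, track=True)

    print(pull.recv())   # b'encoded-engine-outputs'
    tracker.wait()       # block until zmq no longer references `buffer`
    assert tracker.done  # now the bytearray is safe to reuse

    push.close()
    pull.close()
    ctx.term()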
@ -676,7 +739,7 @@ class DPEngineCoreProc(EngineCoreProc):
         self,
         vllm_config: VllmConfig,
         on_head_node: bool,
-        input_address: str,
+        handshake_address: str,
         executor_class: type[Executor],
         log_stats: bool,
     ):

@ -691,10 +754,11 @@ class DPEngineCoreProc(EngineCoreProc):
         # Counts forward-passes of the model so that we can synchronize
         # finished with DP peers every N steps.
         self.counter = 0
+        self.current_wave = 0

         # Initialize the engine.
         dp_rank = vllm_config.parallel_config.data_parallel_rank
-        super().__init__(vllm_config, on_head_node, input_address,
+        super().__init__(vllm_config, on_head_node, handshake_address,
                          executor_class, log_stats, dp_rank)

     def _init_data_parallel(self, vllm_config: VllmConfig):

@ -726,7 +790,6 @@ class DPEngineCoreProc(EngineCoreProc):

         self.dp_rank = dp_rank
         self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
-        self.current_wave = 0

     def shutdown(self):
         super().shutdown()
@ -734,22 +797,23 @@ class DPEngineCoreProc(EngineCoreProc):
             stateless_destroy_torch_distributed_process_group(dp_group)

     def add_request(self, request: EngineCoreRequest):
-        if request.current_wave != self.current_wave:
+        if self.has_coordinator and request.current_wave != self.current_wave:
             if request.current_wave > self.current_wave:
                 self.current_wave = request.current_wave
             elif not self.engines_running:
                 # Request received for an already-completed wave, notify
                 # front-end that we need to start the next one.
                 self.output_queue.put_nowait(
-                    EngineCoreOutputs(start_wave=self.current_wave))
+                    (-1, EngineCoreOutputs(start_wave=self.current_wave)))

         super().add_request(request)

     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
         if request_type == EngineCoreRequestType.START_DP_WAVE:
-            new_wave: int = request
-            if new_wave >= self.current_wave:
+            new_wave, exclude_eng_index = request
+            if exclude_eng_index != self.engine_index and (
+                    new_wave >= self.current_wave):
                 self.current_wave = new_wave
                 if not self.engines_running:
                     logger.debug("EngineCore starting idle loop for wave %d.",

@ -758,6 +822,18 @@ class DPEngineCoreProc(EngineCoreProc):
         else:
             super()._handle_client_request(request_type, request)

+    def _maybe_publish_request_counts(self):
+        if not self.has_coordinator:
+            return
+
+        # Publish our request counts (if they've changed).
+        counts = self.scheduler.get_request_counts()
+        if counts != self.last_counts:
+            self.last_counts = counts
+            stats = SchedulerStats(*counts)
+            self.output_queue.put_nowait(
+                (-1, EngineCoreOutputs(scheduler_stats=stats)))
+
     def run_busy_loop(self):
         """Core busy loop of the EngineCore for data parallel case."""

@ -768,6 +844,8 @@ class DPEngineCoreProc(EngineCoreProc):

             # 2) Step the engine core.
             executed = self._process_engine_step()
+            self._maybe_publish_request_counts()
+
             local_unfinished_reqs = self.scheduler.has_unfinished_requests()
             if not executed:
                 if not local_unfinished_reqs and not self.engines_running:

@ -788,7 +866,8 @@ class DPEngineCoreProc(EngineCoreProc):
                     logger.debug("Wave %d finished, pausing engine loop.",
                                  self.current_wave)
                     self.output_queue.put_nowait(
-                        EngineCoreOutputs(wave_complete=self.current_wave))
+                        (-1,
+                         EngineCoreOutputs(wave_complete=self.current_wave)))
                     self.current_wave += 1

     def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
@ -2,6 +2,7 @@
 import asyncio
 import contextlib
 import queue
+import sys
 import uuid
 import weakref
 from abc import ABC, abstractmethod

@ -9,26 +10,28 @@ from collections import deque
 from collections.abc import Awaitable, Sequence
 from concurrent.futures import Future
 from dataclasses import dataclass
-from enum import Enum, auto
 from threading import Thread
 from typing import Any, Callable, Optional, TypeVar, Union

-import msgspec
+import msgspec.msgpack
 import zmq
 import zmq.asyncio

-from vllm.config import ParallelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.utils import (get_open_port, get_open_zmq_inproc_path,
-                        get_open_zmq_ipc_path, get_tcp_uri, make_zmq_socket)
+from vllm.utils import (get_open_zmq_inproc_path, make_zmq_socket,
+                        zmq_socket_ctx)
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType, UtilityOutput)
+from vllm.v1.engine.coordinator import DPCoordinator
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
 from vllm.v1.engine.exceptions import EngineDeadError
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
-from vllm.v1.utils import CoreEngineProcManager
+from vllm.v1.utils import (CoreEngine, CoreEngineProcManager,
+                           EngineZmqAddresses, get_engine_client_zmq_addr,
+                           wait_for_engine_startup)

 logger = init_logger(__name__)

@ -36,8 +39,6 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]]

 _R = TypeVar('_R')  # Return type for collective_rpc

-STARTUP_POLL_PERIOD_MS = 10000
-

 class EngineCoreClient(ABC):
     """

@ -207,7 +208,7 @@ class InprocClient(EngineCoreClient):

     def get_output(self) -> EngineCoreOutputs:
         outputs, _ = self.engine_core.step()
-        return outputs
+        return outputs.get(0) or EngineCoreOutputs()

     def add_request(self, request: EngineCoreRequest) -> None:
         self.engine_core.add_request(request)
@ -266,24 +267,6 @@ class InprocClient(EngineCoreClient):
         return self.engine_core.collective_rpc(method, timeout, args, kwargs)


-class CoreEngineState(Enum):
-    NEW = auto()
-    CONNECTED = auto()
-    READY = auto()
-
-
-class CoreEngine:
-    """One per data parallel rank."""
-
-    def __init__(self, index: int = 0, local: bool = True):
-        self.local = local
-        self.index = index
-        self.identity = index.to_bytes(length=2, byteorder="little")
-
-        self.state = CoreEngineState.NEW
-        self.num_reqs_in_flight = 0
-
-
 @dataclass
 class BackgroundResources:
     """Used as a finalizer for clean shutdown, avoiding

@ -291,9 +274,12 @@ class BackgroundResources:

     ctx: Union[zmq.Context]
     local_engine_manager: Optional[CoreEngineProcManager] = None
+    coordinator: Optional[DPCoordinator] = None
     output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
     input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
+    first_req_send_socket: Optional[zmq.asyncio.Socket] = None
     output_queue_task: Optional[asyncio.Task] = None
+    stats_update_task: Optional[asyncio.Task] = None
     shutdown_path: Optional[str] = None

     # Set if any of the engines are dead. Here so that the output
@ -306,16 +292,21 @@ class BackgroundResources:
         self.engine_dead = True
         if self.local_engine_manager is not None:
             self.local_engine_manager.close()
+        if self.coordinator is not None:
+            self.coordinator.close()

         if self.output_queue_task is not None:
             self.output_queue_task.cancel()
+        if self.stats_update_task is not None:
+            self.stats_update_task.cancel()

         # ZMQ context termination can hang if the sockets
         # aren't explicitly closed first.
-        if self.output_socket is not None:
-            self.output_socket.close(linger=0)
-        if self.input_socket is not None:
-            self.input_socket.close(linger=0)
+        for socket in (self.output_socket, self.input_socket,
+                       self.first_req_send_socket):
+            if socket is not None:
+                socket.close(linger=0)

         if self.shutdown_path is not None:
             # We must ensure that the sync output socket is
             # closed cleanly in its own thread.

@ -350,6 +341,7 @@ class MPClient(EngineCoreClient):
         vllm_config: VllmConfig,
         executor_class: type[Executor],
         log_stats: bool,
+        client_addresses: Optional[dict[str, str]] = None,
     ):
         self.vllm_config = vllm_config
         # Serialization setup.

@ -369,8 +361,8 @@ class MPClient(EngineCoreClient):
         try:
             parallel_config = vllm_config.parallel_config
             local_engine_count = parallel_config.data_parallel_size_local
-            start_index = parallel_config.data_parallel_rank
             local_start_index = parallel_config.data_parallel_rank_local
+            dp_size = parallel_config.data_parallel_size

             # SPMD mode is where there is an LLM instance per DP rank and
             # one core engine per LLM, see

@ -382,42 +374,53 @@ class MPClient(EngineCoreClient):
                     CoreEngine(index=local_start_index, local=True)
                 ]
             else:
-                assert start_index == 0
+                assert parallel_config.data_parallel_rank == 0
                 local_start_index = 0
                 self.core_engines = [
                     CoreEngine(index=i, local=(i < local_engine_count))
-                    for i in range(parallel_config.data_parallel_size)
+                    for i in range(dp_size)
                 ]

-            input_address, output_address = self._get_zmq_addresses(
-                parallel_config, spmd_mode)
+            local_only = spmd_mode or local_engine_count == dp_size
+
+            self.stats_update_address: Optional[str] = None
+            if client_addresses is not None:
+                input_address = client_addresses["input_address"]
+                output_address = client_addresses["output_address"]
+                self.stats_update_address = client_addresses.get(
"stats_update_address")
|
||||||
|
else:
|
||||||
|
host = parallel_config.data_parallel_master_ip
|
||||||
|
input_address = get_engine_client_zmq_addr(local_only, host)
|
||||||
|
output_address = get_engine_client_zmq_addr(local_only, host)
|
||||||
|
|
||||||
# Create input and output sockets.
|
# Create input and output sockets.
|
||||||
self.input_socket = self.resources.input_socket = make_zmq_socket(
|
self.input_socket = self.resources.input_socket = make_zmq_socket(
|
||||||
self.ctx, input_address, zmq.ROUTER, bind=True)
|
self.ctx, input_address, zmq.ROUTER, bind=True)
|
||||||
|
|
||||||
self.resources.output_socket = make_zmq_socket(
|
self.resources.output_socket = make_zmq_socket(
|
||||||
self.ctx, output_address, zmq.constants.PULL)
|
self.ctx, output_address, zmq.PULL)
|
||||||
# Start local engines.
|
|
||||||
if local_engine_count:
|
if client_addresses is None:
|
||||||
# In server mode, start_index and local_start_index will
|
self._init_engines_direct(vllm_config, local_only,
|
||||||
# both be 0.
|
local_start_index, input_address,
|
||||||
self.resources.local_engine_manager = CoreEngineProcManager(
|
output_address, executor_class,
|
||||||
EngineCoreProc.run_engine_core,
|
log_stats)
|
||||||
vllm_config=vllm_config,
|
coordinator = self.resources.coordinator
|
||||||
executor_class=executor_class,
|
if coordinator:
|
||||||
log_stats=log_stats,
|
self.stats_update_address = (
|
||||||
input_address=input_address,
|
coordinator.get_stats_publish_address())
|
||||||
on_head_node=True,
|
|
||||||
local_engine_count=local_engine_count,
|
# Wait for ready messages from each engine on the input socket.
|
||||||
start_index=start_index,
|
identities = set(e.identity for e in self.core_engines)
|
||||||
local_start_index=local_start_index)
|
sync_input_socket = zmq.Socket.shadow(self.input_socket)
|
||||||
|
while identities:
|
||||||
|
if not sync_input_socket.poll(timeout=600_000):
|
||||||
|
raise TimeoutError("Timed out waiting for engines to send"
|
||||||
|
"initial message on input socket.")
|
||||||
|
identity, _ = sync_input_socket.recv_multipart()
|
||||||
|
identities.remove(identity)
|
||||||
|
|
||||||
self.core_engine = self.core_engines[0]
|
self.core_engine = self.core_engines[0]
|
||||||
|
|
||||||
# Wait for engine core process(es) to start.
|
|
||||||
self._wait_for_engine_startup(output_address, parallel_config)
|
|
||||||
|
|
||||||
self.utility_results: dict[int, AnyFuture] = {}
|
self.utility_results: dict[int, AnyFuture] = {}
|
||||||
|
|
||||||
# Request objects which may contain pytorch-allocated tensors
|
# Request objects which may contain pytorch-allocated tensors
|
||||||
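The client_addresses branch above is the hook for the new many-to-many server/engine wiring: an external launcher creates the ZMQ queues and hands their addresses to each client, instead of letting the client allocate its own. A minimal illustrative sketch of that dictionary (not part of the diff; the literal addresses are invented, and the constructor call is left commented out because it needs a fully built VllmConfig):

    # Hypothetical addresses; in practice they come from get_engine_client_zmq_addr().
    client_addresses = {
        "input_address": "tcp://127.0.0.1:5001",         # requests: server -> engines
        "output_address": "tcp://127.0.0.1:6001",        # responses: engines -> server
        "stats_update_address": "tcp://127.0.0.1:7000",  # optional DP stats feed
    }
    # client = AsyncMPClient(vllm_config, executor_class, log_stats=False,
    #                        client_addresses=client_addresses, client_index=0)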
@ -430,116 +433,67 @@ class MPClient(EngineCoreClient):

         if not success:
             self._finalizer()

-    @staticmethod
-    def _get_zmq_addresses(parallel_config: ParallelConfig,
-                           spmd_mode: bool) -> tuple[str, str]:
-        """Returns (input_address, output_address)."""
-        dp_size = parallel_config.data_parallel_size
-        local_engine_count = parallel_config.data_parallel_size_local
-
-        if local_engine_count == dp_size or spmd_mode:
-            input_address = get_open_zmq_ipc_path()
-            output_address = get_open_zmq_ipc_path()
-        else:
-            host = parallel_config.data_parallel_master_ip
-            input_port = parallel_config.data_parallel_rpc_port
-            output_port = get_open_port()
-            input_address = get_tcp_uri(host, input_port)
-            output_address = get_tcp_uri(host, output_port)
-
-        return input_address, output_address
-
-    def _wait_for_engine_startup(self, output_address: str,
-                                 parallel_config: ParallelConfig):
-        # Get a sync handle to the socket which can be sync or async.
-        sync_input_socket = zmq.Socket.shadow(self.input_socket)
-
-        # Wait for engine core process(es) to send ready messages.
-        local_count = parallel_config.data_parallel_size_local
-        remote_count = len(self.core_engines) - local_count
-        # [local, remote] counts
-        conn_pending, start_pending = [local_count, remote_count], [0, 0]
-
-        poller = zmq.Poller()
-        poller.register(sync_input_socket, zmq.POLLIN)
-        proc_manager = self.resources.local_engine_manager
-        if proc_manager is not None:
-            for sentinel in proc_manager.sentinels():
-                poller.register(sentinel, zmq.POLLIN)
-        while any(conn_pending) or any(start_pending):
-            events = poller.poll(STARTUP_POLL_PERIOD_MS)
-            if not events:
-                if any(conn_pending):
-                    logger.debug(
-                        "Waiting for %d local, %d remote core engine proc(s) "
-                        "to connect.", *conn_pending)
-                if any(start_pending):
-                    logger.debug(
-                        "Waiting for %d local, %d remote core engine proc(s) "
-                        "to start.", *start_pending)
-                continue
-            if len(events) > 1 or events[0][0] != sync_input_socket:
-                # One of the local core processes exited.
-                finished = proc_manager.finished_procs(
-                ) if proc_manager else {}
-                raise RuntimeError("Engine core initialization failed. "
-                                   "See root cause above. "
-                                   f"Failed core proc(s): {finished}")
-
-            # Receive HELLO and READY messages from the input socket.
-            eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart()
-            eng_index = int.from_bytes(eng_identity, byteorder="little")
-            engine = next(
-                (e for e in self.core_engines if e.identity == eng_identity),
-                None)
-            if engine is None:
-                raise RuntimeError(f"Message from engine with unexpected data "
-                                   f"parallel rank: {eng_index}")
-            msg = msgspec.msgpack.decode(ready_msg_bytes)
-            status, local = msg["status"], msg["local"]
-            if local != engine.local:
-                raise RuntimeError(f"{status} message from "
-                                   f"{'local' if local else 'remote'} "
-                                   f"engine {eng_index}, expected it to be "
-                                   f"{'local' if engine.local else 'remote'}")
-
-            if status == "HELLO" and engine.state == CoreEngineState.NEW:
-
-                # Send init message with DP config info.
-                init_message = self.encoder.encode({
-                    "output_socket_address": output_address,
-                    "parallel_config": {
-                        "data_parallel_master_ip":
-                        parallel_config.data_parallel_master_ip,
-                        "data_parallel_master_port":
-                        parallel_config.data_parallel_master_port,
-                        "data_parallel_size":
-                        parallel_config.data_parallel_size,
-                    },
-                })
-                sync_input_socket.send_multipart((eng_identity, *init_message),
-                                                 copy=False)
-                conn_pending[0 if local else 1] -= 1
-                start_pending[0 if local else 1] += 1
-                engine.state = CoreEngineState.CONNECTED
-            elif status == "READY" and (engine.state
-                                        == CoreEngineState.CONNECTED):
-                # Setup KV cache config with initialization state from
-                # engine core process. Sum values from all engines in DP case.
-                cache_config = self.vllm_config.cache_config
-                num_gpu_blocks = cache_config.num_gpu_blocks or 0
-                num_gpu_blocks += msg['num_gpu_blocks']
-                cache_config.num_gpu_blocks = num_gpu_blocks
-
-                start_pending[0 if local else 1] -= 1
-                engine.state = CoreEngineState.READY
-            else:
-                raise RuntimeError(f"Unexpected {status} message for "
-                                   f"{'local' if local else 'remote'} engine "
-                                   f"{eng_index} in {engine.state} state.")
-
-            logger.debug("%s from %s core engine process %s.", status,
-                         "local" if local else "remote", eng_index)
+    def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
+                             local_start_index: int, input_address: str,
+                             output_address: str,
+                             executor_class: type[Executor], log_stats: bool):
+        """Self-contained client mode, launch engine and coordinator process
+        as needed."""
+
+        parallel_config = vllm_config.parallel_config
+        local_engine_count = parallel_config.data_parallel_size_local
+        start_index = parallel_config.data_parallel_rank
+        host = parallel_config.data_parallel_master_ip
+
+        if len(self.core_engines) > 1:
+            self.resources.coordinator = DPCoordinator(parallel_config)
+
+        handshake_address = get_engine_client_zmq_addr(
+            local_only, host, parallel_config.data_parallel_rpc_port)
+
+        with zmq_socket_ctx(handshake_address, zmq.ROUTER,
+                            bind=True) as handshake_socket:
+
+            # Start local engines.
+            if local_engine_count:
+                # In server mode, start_index and local_start_index will
+                # both be 0.
+                self.resources.local_engine_manager = CoreEngineProcManager(
+                    EngineCoreProc.run_engine_core,
+                    vllm_config=vllm_config,
+                    executor_class=executor_class,
+                    log_stats=log_stats,
+                    handshake_address=handshake_address,
+                    on_head_node=True,
+                    local_engine_count=local_engine_count,
+                    start_index=start_index,
+                    local_start_index=local_start_index)
+
+            # Wait for engine core process(es) to start.
+            self._wait_for_engine_startup(handshake_socket, input_address,
+                                          output_address)
+
+    def _wait_for_engine_startup(self, handshake_socket: zmq.Socket,
+                                 input_address: str, output_address: str):
+        addresses = EngineZmqAddresses(
+            inputs=[input_address],
+            outputs=[output_address],
+        )
+
+        coordinator = self.resources.coordinator
+        if coordinator is not None:
+            addresses.coordinator_input, addresses.coordinator_output = (
+                coordinator.get_engine_socket_addresses())
+
+        wait_for_engine_startup(
+            handshake_socket,
+            addresses,
+            self.core_engines,
+            self.vllm_config.parallel_config,
+            self.vllm_config.cache_config,
+            self.resources.local_engine_manager,
+            coordinator.proc if coordinator else None,
+        )

     def shutdown(self):
         # Terminate background resources.
@ -605,8 +559,8 @@ class SyncMPClient(MPClient):

         try:
             shutdown_socket.bind(shutdown_path)
             poller = zmq.Poller()
-            poller.register(shutdown_socket)
-            poller.register(out_socket)
+            poller.register(shutdown_socket, zmq.POLLIN)
+            poller.register(out_socket, zmq.POLLIN)
             while True:
                 socks = poller.poll()
                 if not socks:

@ -668,7 +622,7 @@ class SyncMPClient(MPClient):

         future: Future[Any] = Future()
         self.utility_results[call_id] = future
         self._send_input(EngineCoreRequestType.UTILITY,
-                         (call_id, method, args))
+                         (0, call_id, method, args))

         return future.result()

@ -730,15 +684,21 @@ class SyncMPClient(MPClient):

 class AsyncMPClient(MPClient):
     """Asyncio-compatible client for multi-proc EngineCore."""

-    def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
-                 log_stats: bool):
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 executor_class: type[Executor],
+                 log_stats: bool,
+                 client_addresses: Optional[dict[str, str]] = None,
+                 client_index: int = 0):
         super().__init__(
             asyncio_mode=True,
             vllm_config=vllm_config,
             executor_class=executor_class,
             log_stats=log_stats,
+            client_addresses=client_addresses,
         )

+        self.client_index = client_index
         self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs,
                                                  Exception]]()
         try:

@ -854,12 +814,13 @@ class AsyncMPClient(MPClient):

         future = asyncio.get_running_loop().create_future()
         self.utility_results[call_id] = future
         message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
-            (call_id, method, args)))
+            (self.client_index, call_id, method, args)))
         await self._send_input_message(message, engine, args)
         self._ensure_output_queue_task()
         return await future

     async def add_request_async(self, request: EngineCoreRequest) -> None:
+        request.client_index = self.client_index
         await self._send_input(EngineCoreRequestType.ADD, request)
         self._ensure_output_queue_task()

@ -921,17 +882,120 @@ class DPAsyncMPClient(AsyncMPClient):

     """Asyncio-compatible client for multi-proc, multi-engine (data parallel)
     EngineCore."""

-    def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
-                 log_stats: bool):
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 executor_class: type[Executor],
+                 log_stats: bool,
+                 client_addresses: Optional[dict[str, str]] = None,
+                 client_index: int = 0):

         self.current_wave = 0
         self.engines_running = False
+        # To route aborts to the correct engine.
         self.reqs_in_flight: dict[str, CoreEngine] = {}

-        super().__init__(vllm_config, executor_class, log_stats)
+        super().__init__(vllm_config, executor_class, log_stats,
+                         client_addresses, client_index)

         assert len(self.core_engines) > 1

+        # List of [waiting, running] pair per engine.
+        self.lb_engines: list[list[int]] = []
+
+        self.first_req_sock_addr = get_open_zmq_inproc_path()
+        self.first_req_send_socket = self.resources.first_req_send_socket = (
+            make_zmq_socket(self.ctx,
+                            self.first_req_sock_addr,
+                            zmq.PAIR,
+                            bind=True))
+        try:
+            # If we are running in an asyncio event loop, start the stats task.
+            # Otherwise, it will be started lazily.
+            asyncio.get_running_loop()
+            self._ensure_stats_update_task()
+        except RuntimeError:
+            pass
+
+    def _ensure_stats_update_task(self):
+        resources = self.resources
+        if resources.stats_update_task is not None:
+            return
+
+        assert self.stats_update_address is not None
+
+        async def run_engine_stats_update_task():
+            with make_zmq_socket(self.ctx, self.stats_update_address,
+                                 zmq.XSUB) as socket, make_zmq_socket(
+                                     self.ctx,
+                                     self.first_req_sock_addr,
+                                     zmq.PAIR,
+                                     bind=False) as first_req_rcv_socket:
+                # Send subscription message.
+                await socket.send(b'\x01')
+
+                poller = zmq.asyncio.Poller()
+                poller.register(socket, zmq.POLLIN)
+                poller.register(first_req_rcv_socket, zmq.POLLIN)
+
+                while True:
+                    events = await poller.poll()
+                    if not self.engines_running and len(events) == 2 or (
+                            events[0][0] == first_req_rcv_socket):
+                        # Send a message to notify the coordinator that
+                        # we're sending a request while the engines are
+                        # paused, so that it can wake the others up
+                        # (to run dummy EP loop).
+                        self.engines_running = True
+                        buf = first_req_rcv_socket.recv(
+                            flags=zmq.NOBLOCK).result()
+                        target_eng_index = int.from_bytes(buf, "little")
+                        msg = msgspec.msgpack.encode(
+                            (target_eng_index, self.current_wave))
+                        await socket.send(msg)
+
+                    buf = None
+                    while True:
+                        # Drain all stats events (we only care about latest).
+                        future: asyncio.Future[bytes] = socket.recv(
+                            flags=zmq.NOBLOCK)
+                        if isinstance(future.exception(), zmq.Again):
+                            break
+                        buf = future.result()
+                    if buf is None:
+                        continue
+
+                    # Update local load-balancing state.
+                    counts, wave, running = msgspec.msgpack.decode(buf)
+                    self.current_wave = wave
+                    self.engines_running = running
+                    self.lb_engines = counts
+
+        resources.stats_update_task = asyncio.create_task(
+            run_engine_stats_update_task())
+
+    def get_core_engine_for_request(self) -> CoreEngine:
+        if not self.lb_engines:
+            return self.core_engines[0]
+        # TODO use P2C alg for larger DP sizes
+        num_engines = len(self.lb_engines)
+        min_counts = [sys.maxsize, sys.maxsize]
+        eng_index = 0
+        for i in range(num_engines):
+            # Start from client_index to help with balancing when engines
+            # are empty.
+            idx = (self.client_index + i) % num_engines
+            counts = self.lb_engines[idx]
+            if counts < min_counts:
+                min_counts = counts
+                eng_index = idx
+        # Adjust local counts for better balancing between stats updates
+        # from the coordinator (which happen every 100ms).
+        if min_counts[0]:
+            min_counts[0] += 1
+        else:
+            min_counts[1] += 1
+        return self.core_engines[eng_index]
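The selection above keeps, for every engine, a [waiting, running] pair published by the DP coordinator and picks the engine with the lexicographically smallest pair, then bumps the chosen engine's local count so that bursts arriving between 100 ms stats updates still spread out. A standalone sketch of that policy (illustrative only; plain lists stand in for the coordinator's stats feed):

    import sys

    def pick_engine(lb_counts: list[list[int]], client_index: int = 0) -> int:
        """Return the index of the least-loaded engine; lb_counts holds
        [num_waiting, num_running] per engine."""
        if not lb_counts:
            return 0
        best, best_idx = [sys.maxsize, sys.maxsize], 0
        for i in range(len(lb_counts)):
            idx = (client_index + i) % len(lb_counts)
            if lb_counts[idx] < best:  # lexicographic: waiting first
                best, best_idx = lb_counts[idx], idx
        # Locally account for the request we are about to send.
        if best[0]:
            best[0] += 1
        else:
            best[1] += 1
        return best_idx

    counts = [[0, 3], [1, 0], [0, 1]]
    assert pick_engine(counts) == 2  # nothing waiting there, fewest running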
     async def call_utility_async(self, method: str, *args) -> Any:
         # Only the result from the first engine is returned.
         return (await asyncio.gather(*[

@ -940,62 +1004,30 @@ class DPAsyncMPClient(AsyncMPClient):

         ]))[0]

     async def add_request_async(self, request: EngineCoreRequest) -> None:
+        self._ensure_stats_update_task()
+
         request.current_wave = self.current_wave
+        request.client_index = self.client_index
+
         chosen_engine = self.get_core_engine_for_request()
         self.reqs_in_flight[request.request_id] = chosen_engine
-        chosen_engine.num_reqs_in_flight += 1

         to_await = self._send_input(EngineCoreRequestType.ADD, request,
                                     chosen_engine)
         if not self.engines_running:
-            # Send request to chosen engine and dp start loop
-            # control message to all other engines.
-            self.engines_running = True
-            to_await = asyncio.gather(
-                to_await,  # type: ignore[assignment]
-                *self._start_wave_coros(exclude_index=chosen_engine.index))
+            # Notify coordinator that we're sending a request
+            await self.first_req_send_socket.send(chosen_engine.identity)

         await to_await

         self._ensure_output_queue_task()

-    def get_core_engine_for_request(self) -> CoreEngine:
-        return min(self.core_engines, key=lambda e: e.num_reqs_in_flight)
-
     @staticmethod
     async def process_engine_outputs(self: "DPAsyncMPClient",
                                      outputs: EngineCoreOutputs):
-        if self.reqs_in_flight:
-            for req_id in outputs.finished_requests or ():
-                if engine := self.reqs_in_flight.pop(req_id, None):
-                    engine.num_reqs_in_flight -= 1
-
-        if outputs.wave_complete is not None:
-            # Current wave is complete, move to next wave number
-            # and mark engines as paused.
-            if self.current_wave <= outputs.wave_complete:
-                self.current_wave = outputs.wave_complete + 1
-                self.engines_running = False
-
-        elif outputs.start_wave is not None and (
-                outputs.start_wave > self.current_wave or
-                (outputs.start_wave == self.current_wave
-                 and not self.engines_running)):
-            # Engine received request for a non-current wave so we must ensure
-            # that other engines progress to the next wave.
-            self.current_wave = outputs.start_wave
-            self.engines_running = True
-            await asyncio.gather(*self._start_wave_coros(
-                exclude_index=outputs.engine_index))
-
-    def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]:
-        logger.debug("Sending start DP wave %d.", self.current_wave)
-        return [
-            self._send_input(EngineCoreRequestType.START_DP_WAVE,
-                             self.current_wave, engine)
-            for engine in self.core_engines if engine.index != exclude_index
-        ]
+        if outputs.finished_requests and self.reqs_in_flight:
+            for req_id in outputs.finished_requests:
+                self.reqs_in_flight.pop(req_id, None)

     async def abort_requests_async(self, request_ids: list[str]) -> None:
         if not request_ids:
@ -12,13 +12,12 @@ from vllm.config import SupportsMetricsInfo, VllmConfig

 from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
 from vllm.v1.engine import FinishReason
+from vllm.v1.metrics.prometheus import unregister_vllm_metrics
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
 from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm

 logger = init_logger(__name__)

-_LOCAL_LOGGING_INTERVAL_SEC = 5.0
-
 StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]


@ -35,7 +34,7 @@ class StatLoggerBase(ABC):

         ...

     @abstractmethod
-    def record(self, scheduler_stats: SchedulerStats,
+    def record(self, scheduler_stats: Optional[SchedulerStats],
                iteration_stats: Optional[IterationStats]):
         ...

@ -78,14 +77,16 @@ class LoggingStatLogger(StatLoggerBase):

         # Compute summary metrics for tracked stats
         return float(np.sum(tracked_stats) / (now - self.last_log_time))

-    def record(self, scheduler_stats: SchedulerStats,
+    def record(self, scheduler_stats: Optional[SchedulerStats],
                iteration_stats: Optional[IterationStats]):
         """Log Stats to standard output."""

         if iteration_stats:
             self._track_iteration_stats(iteration_stats)

-        self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
+        if scheduler_stats is not None:
+            self.prefix_caching_metrics.observe(
+                scheduler_stats.prefix_cache_stats)

             if scheduler_stats.spec_decoding_stats is not None:
                 self.spec_decoding_logging.observe(

@ -131,9 +132,10 @@ class LoggingStatLogger(StatLoggerBase):

         self.spec_decoding_logging.log(log_fn=log_fn)

     def log_engine_initialized(self):
+        if self.vllm_config.cache_config.num_gpu_blocks:
             logger.info(
-                "vllm cache_config_info with initialization " \
-                "after num_gpu_blocks is: %d",
+                "Engine %03d: vllm cache_config_info with initialization "
+                "after num_gpu_blocks is: %d", self.engine_index,
                 self.vllm_config.cache_config.num_gpu_blocks)


@ -144,7 +146,8 @@ class PrometheusStatLogger(StatLoggerBase):

     _spec_decoding_cls = SpecDecodingProm

     def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
-        self._unregister_vllm_metrics()
+        unregister_vllm_metrics()
         self.vllm_config = vllm_config
         self.engine_index = engine_index
         # Use this flag to hide metrics that were deprecated in

@ -169,11 +172,13 @@ class PrometheusStatLogger(StatLoggerBase):

         self.gauge_scheduler_running = self._gauge_cls(
             name="vllm:num_requests_running",
             documentation="Number of requests in model execution batches.",
+            multiprocess_mode="mostrecent",
             labelnames=labelnames).labels(*labelvalues)

         self.gauge_scheduler_waiting = self._gauge_cls(
             name="vllm:num_requests_waiting",
             documentation="Number of requests waiting to be processed.",
+            multiprocess_mode="mostrecent",
             labelnames=labelnames).labels(*labelvalues)

         #
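The multiprocess_mode arguments added above matter once several API server processes share one metric: prometheus_client's multiprocess collector needs to know how to merge the per-process samples of a Gauge. A small illustrative example (assumes prometheus_client >= 0.20, which added the "mostrecent" mode; the metric name is made up):

    from prometheus_client import Gauge

    demo_gauge = Gauge(
        "demo_num_requests_running",
        "Requests currently running; per-process samples merged as most recent.",
        multiprocess_mode="mostrecent",
    )
    demo_gauge.set(3)  # each process writes its own sample; the multiprocess
                       # collector exposes only the newest one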
@ -182,6 +187,7 @@ class PrometheusStatLogger(StatLoggerBase):

         self.gauge_gpu_cache_usage = self._gauge_cls(
             name="vllm:gpu_cache_usage_perc",
             documentation="GPU KV-cache usage. 1 means 100 percent usage.",
+            multiprocess_mode="mostrecent",
             labelnames=labelnames).labels(*labelvalues)

         self.counter_gpu_prefix_cache_queries = self._counter_cls(

@ -242,6 +248,9 @@ class PrometheusStatLogger(StatLoggerBase):

             buckets=build_1_2_5_buckets(max_model_len),
             labelnames=labelnames).labels(*labelvalues)

+        # TODO: This metric might be incorrect in case of using multiple
+        # api_server counts which uses prometheus mp.
+        # See: https://github.com/vllm-project/vllm/pull/18053
         self.histogram_iteration_tokens = \
             self._histogram_cls(
                 name="vllm:iteration_tokens_total",

@ -340,6 +349,9 @@ class PrometheusStatLogger(StatLoggerBase):

         #
         # LoRA metrics
         #

+        # TODO: This metric might be incorrect in case of using multiple
+        # api_server counts which uses prometheus mp.
         self.gauge_lora_info: Optional[prometheus_client.Gauge] = None
         if vllm_config.lora_config is not None:
             self.labelname_max_lora = "max_lora"

@ -350,13 +362,16 @@ class PrometheusStatLogger(StatLoggerBase):

                 self._gauge_cls(
                     name="vllm:lora_requests_info",
                     documentation="Running stats on lora requests.",
+                    multiprocess_mode="sum",
                     labelnames=[
                         self.labelname_max_lora,
                         self.labelname_waiting_lora_adapters,
                         self.labelname_running_lora_adapters,
-                    ])
+                    ],
+                )

     def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):

         metrics_info = config_obj.metrics_info()
         metrics_info["engine"] = self.engine_index

@ -372,12 +387,15 @@ class PrometheusStatLogger(StatLoggerBase):

         info_gauge = self._gauge_cls(
             name=name,
             documentation=documentation,
-            labelnames=metrics_info.keys()).labels(**metrics_info)
+            multiprocess_mode="mostrecent",
+            labelnames=metrics_info.keys(),
+        ).labels(**metrics_info)
         info_gauge.set(1)

-    def record(self, scheduler_stats: SchedulerStats,
+    def record(self, scheduler_stats: Optional[SchedulerStats],
                iteration_stats: Optional[IterationStats]):
         """Log to prometheus."""
+        if scheduler_stats is not None:
             self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
             self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)

@ -445,13 +463,6 @@ class PrometheusStatLogger(StatLoggerBase):

         self.gauge_lora_info.labels(**lora_info_labels)\
             .set_to_current_time()

-    @staticmethod
-    def _unregister_vllm_metrics():
-        # Unregister any existing vLLM collectors (for CI/CD
-        for collector in list(prometheus_client.REGISTRY._collector_to_names):
-            if hasattr(collector, "_name") and "vllm" in collector._name:
-                prometheus_client.REGISTRY.unregister(collector)
-
     def log_engine_initialized(self):
         self.log_metrics_info("cache_config", self.vllm_config.cache_config)
 77  vllm/v1/metrics/prometheus.py  Normal file

@ -0,0 +1,77 @@

+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import tempfile
+from typing import Optional
+
+from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+# Global temporary directory for prometheus multiprocessing
+_prometheus_multiproc_dir: Optional[tempfile.TemporaryDirectory] = None
+
+
+def setup_multiprocess_prometheus():
+    """Set up prometheus multiprocessing directory if not already configured.
+    """
+    global _prometheus_multiproc_dir
+
+    if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
+        # Make TemporaryDirectory for prometheus multiprocessing
+        # Note: global TemporaryDirectory will be automatically
+        # cleaned up upon exit.
+        _prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+        os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name
+        logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s",
+                     _prometheus_multiproc_dir.name)
+    else:
+        logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. "
+                       "This directory must be wiped between vLLM runs or "
+                       "you will find inaccurate metrics. Unset the variable "
+                       "and vLLM will properly handle cleanup.")
+
+
+def get_prometheus_registry():
+    """Get the appropriate prometheus registry based on multiprocessing
+    configuration.
+
+    Returns:
+        Registry: A prometheus registry
+    """
+    if os.getenv("PROMETHEUS_MULTIPROC_DIR") is not None:
+        logger.debug("Using multiprocess registry for prometheus metrics")
+        registry = CollectorRegistry()
+        multiprocess.MultiProcessCollector(registry)
+        return registry
+
+    return REGISTRY
+
+
+def unregister_vllm_metrics():
+    """Unregister any existing vLLM collectors from the prometheus registry.
+
+    This is useful for testing and CI/CD where metrics may be registered
+    multiple times across test runs.
+
+    Also, in case of multiprocess, we need to unregister the metrics from the
+    global registry.
+    """
+    registry = REGISTRY
+    # Unregister any existing vLLM collectors
+    for collector in list(registry._collector_to_names):
+        if hasattr(collector, "_name") and "vllm" in collector._name:
+            registry.unregister(collector)
+
+
+def shutdown_prometheus():
+    """Shutdown prometheus metrics."""
+    try:
+        pid = os.getpid()
+        multiprocess.mark_process_dead(pid)
+        logger.debug("Marked Prometheus metrics for process %d as dead", pid)
+    except Exception as e:
+        logger.error("Error during metrics cleanup: %s", str(e))
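A usage sketch for the helpers above (not from the diff): how a /metrics handler in a multi-server deployment could combine them. generate_latest is prometheus_client's standard exposition function; everything else is an assumption about the surrounding server code.

    import os

    from prometheus_client import generate_latest

    from vllm.v1.metrics.prometheus import (get_prometheus_registry,
                                            setup_multiprocess_prometheus)

    setup_multiprocess_prometheus()  # sets PROMETHEUS_MULTIPROC_DIR if unset
    assert "PROMETHEUS_MULTIPROC_DIR" in os.environ

    def metrics_endpoint() -> bytes:
        # With the env var set, this returns a fresh CollectorRegistry backed
        # by a MultiProcessCollector, merging samples from every API server.
        registry = get_prometheus_registry()
        return generate_latest(registry)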
@ -26,12 +26,13 @@ class Request:

         multi_modal_placeholders: Optional[list[PlaceholderRange]],
         sampling_params: SamplingParams,
         eos_token_id: Optional[int],
-        arrival_time: float,
+        client_index: int = 0,
         lora_request: Optional["LoRARequest"] = None,
         structured_output_request: Optional["StructuredOutputRequest"] = None,
         cache_salt: Optional[str] = None,
     ) -> None:
         self.request_id = request_id
+        self.client_index = client_index
         self.sampling_params = sampling_params
         # Because of LoRA, the eos token id can be different for each request.
         self.eos_token_id = eos_token_id

@ -90,13 +91,13 @@ class Request:

         return cls(
             request_id=request.request_id,
+            client_index=request.client_index,
             prompt_token_ids=request.prompt_token_ids,
             multi_modal_inputs=request.mm_inputs,
             multi_modal_hashes=request.mm_hashes,
             multi_modal_placeholders=request.mm_placeholders,
             sampling_params=request.sampling_params,
             eos_token_id=request.eos_token_id,
-            arrival_time=request.arrival_time,
             lora_request=request.lora_request,
             structured_output_request=StructuredOutputRequest(
                 sampling_params=request.sampling_params),
 301  vllm/v1/utils.py

@ -1,31 +1,41 @@

 # SPDX-License-Identifier: Apache-2.0

-import os
+import argparse
+import multiprocessing
 import time
 import weakref
 from collections import defaultdict
 from collections.abc import Sequence
+from dataclasses import dataclass
+from enum import Enum, auto
 from multiprocessing import Process, connection
-from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union,
-                    overload)
+from multiprocessing.process import BaseProcess
+from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
+                    Union, overload)

+import msgspec
 import torch
+import zmq

-from vllm.config import VllmConfig
+from vllm.config import CacheConfig, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.models.utils import extract_layer_index
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
-from vllm.utils import get_mp_context, kill_process_tree
+from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
+                        get_tcp_uri, kill_process_tree)
 from vllm.v1.executor.abstract import Executor

 if TYPE_CHECKING:
     from vllm.attention.layer import Attention
+    from vllm.v1.engine.coordinator import DPCoordinator

 logger = init_logger(__name__)

 T = TypeVar("T")

+STARTUP_POLL_PERIOD_MS = 10000
+

 class ConstantList(Generic[T], Sequence):

@ -95,6 +105,78 @@ class ConstantList(Generic[T], Sequence):

         return f"ConstantList({self._x})"


+def get_engine_client_zmq_addr(local_only: bool,
+                               host: str,
+                               port: int = 0) -> str:
+    return get_open_zmq_ipc_path() if local_only else (get_tcp_uri(
+        host, port or get_open_port()))
+
+
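get_engine_client_zmq_addr picks the cheapest transport that still reaches every engine: an IPC path when everything stays on one node, otherwise a TCP URI on the given host. A short illustrative sketch of the two cases (the host and port values are examples only):

    # Single-node case: returns something like "ipc:///tmp/<random-suffix>".
    local_addr = get_engine_client_zmq_addr(local_only=True, host="127.0.0.1")

    # Multi-node case: returns "tcp://10.0.0.5:<free port>", or the given port.
    remote_addr = get_engine_client_zmq_addr(local_only=False, host="10.0.0.5")
    pinned_addr = get_engine_client_zmq_addr(False, "10.0.0.5", port=13345)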
+class APIServerProcessManager:
+    """Manages a group of API server processes.
+
+    Handles creation, monitoring, and termination of API server worker
+    processes. Also monitors extra processes to check if they are healthy.
+    """
+
+    def __init__(
+        self,
+        target_server_fn: Callable,
+        listen_address: str,
+        sock: Any,
+        args: argparse.Namespace,
+        num_servers: int,
+        input_addresses: list[str],
+        output_addresses: list[str],
+        stats_update_address: Optional[str] = None,
+    ):
+        """Initialize and start API server worker processes.
+
+        Args:
+            target_server_fn: Function to call for each API server process
+            listen_address: Address to listen for client connections
+            sock: Socket for client connections
+            args: Command line arguments
+            num_servers: Number of API server processes to start
+            input_addresses: Input addresses for each API server
+            output_addresses: Output addresses for each API server
+            stats_update_address: Optional stats update address
+        """
+        self.listen_address = listen_address
+        self.sock = sock
+        self.args = args
+
+        # Start API servers
+        spawn_context = multiprocessing.get_context("spawn")
+        self.processes: list[BaseProcess] = []
+
+        for i, in_addr, out_addr in zip(range(num_servers), input_addresses,
+                                        output_addresses):
+            client_config = {
+                "input_address": in_addr,
+                "output_address": out_addr,
+                "client_index": i
+            }
+            if stats_update_address is not None:
+                client_config["stats_update_address"] = stats_update_address
+
+            proc = spawn_context.Process(target=target_server_fn,
+                                         name=f"ApiServer_{i}",
+                                         args=(listen_address, sock, args,
+                                               client_config))
+            self.processes.append(proc)
+            proc.start()
+
+        logger.info("Started %d API server processes", len(self.processes))
+
+        # Shutdown only the API server processes on garbage collection
+        # The extra processes are managed by their owners
+        self._finalizer = weakref.finalize(self, shutdown, self.processes)
+
+    def close(self) -> None:
+        self._finalizer()
+
+
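A compact usage sketch for APIServerProcessManager (illustrative only; the worker function, addresses, and args value are stand-ins for what the real serving entrypoint passes in):

    import socket

    def serve(listen_address, sock, args, client_config=None):
        # Placeholder worker; a real server would start uvicorn here and use
        # client_config["input_address"] / ["output_address"] to reach engines.
        print("server", client_config["client_index"], "starting")

    manager = APIServerProcessManager(
        target_server_fn=serve,
        listen_address="0.0.0.0:8000",
        sock=socket.socket(),
        args=None,  # stand-in for the parsed argparse.Namespace
        num_servers=2,
        input_addresses=["ipc:///tmp/in0.sock", "ipc:///tmp/in1.sock"],
        output_addresses=["ipc:///tmp/out0.sock", "ipc:///tmp/out1.sock"],
    )
    manager.close()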
 class CoreEngineProcManager:
     """
     Utility class to handle creation, readiness, and shutdown

@ -109,7 +191,7 @@ class CoreEngineProcManager:

         local_start_index: int,
         vllm_config: VllmConfig,
         on_head_node: bool,
-        input_address: str,
+        handshake_address: str,
         executor_class: type[Executor],
         log_stats: bool,
     ):

@ -117,12 +199,12 @@ class CoreEngineProcManager:

         common_kwargs = {
             "vllm_config": vllm_config,
             "on_head_node": on_head_node,
-            "input_address": input_address,
+            "handshake_address": handshake_address,
             "executor_class": executor_class,
             "log_stats": log_stats,
         }

-        self.processes: list[Process] = []
+        self.processes: list[BaseProcess] = []
         for index in range(local_engine_count):
             local_index = local_start_index + index
             global_index = start_index + index

@ -135,8 +217,7 @@ class CoreEngineProcManager:

                     "local_dp_rank": local_index,
                 }))

-        self._finalizer = weakref.finalize(self, shutdown, self.processes,
-                                           input_address)
+        self._finalizer = weakref.finalize(self, shutdown, self.processes)
         try:
             for proc in self.processes:
                 proc.start()

@ -164,9 +245,199 @@ class CoreEngineProcManager:

         }


+class CoreEngineState(Enum):
+    NEW = auto()
+    CONNECTED = auto()
+    READY = auto()
+
+
+class CoreEngine:
+    """One per data parallel rank."""
+
+    def __init__(self, index: int = 0, local: bool = True):
+        self.local = local
+        self.index = index
+        self.identity = index.to_bytes(2, "little")
+
+        self.state = CoreEngineState.NEW
+
+
+@dataclass
+class EngineZmqAddresses:
+    # ZMQ input socket addresses for each front-end client (requests)
+    inputs: list[str]
+    # ZMQ output socket addresses for each front-end client (responses)
+    outputs: list[str]
+    # ZMQ input socket address of DP coordinator if applicable
+    coordinator_input: Optional[str] = None
+    # ZMQ output socket address of DP coordinator if applicable
+    coordinator_output: Optional[str] = None
+
+
+@dataclass
+class EngineHandshakeMetadata:
+    """Metadata sent to each engine process during startup handshake,
+    including addresses of the front-end ZMQ queues that they should
+    connect to.
+    """
+    addresses: EngineZmqAddresses
+    parallel_config: dict[str, Union[int, str]]
+
+
+def wait_for_engine_startup(
+    handshake_socket: zmq.Socket,
+    addresses: EngineZmqAddresses,
+    core_engines: list[CoreEngine],
+    parallel_config: ParallelConfig,
+    cache_config: CacheConfig,
+    proc_manager: Optional[CoreEngineProcManager],
+    coord_process: Optional[Process],
+):
+
+    # Wait for engine core process(es) to send ready messages.
+    local_count = parallel_config.data_parallel_size_local
+    remote_count = len(core_engines) - local_count
+    # [local, remote] counts
+    conn_pending, start_pending = [local_count, remote_count], [0, 0]
+    poller = zmq.Poller()
+    poller.register(handshake_socket, zmq.POLLIN)
+
+    if proc_manager is not None:
+        for sentinel in proc_manager.sentinels():
+            poller.register(sentinel, zmq.POLLIN)
+    if coord_process is not None:
+        poller.register(coord_process.sentinel, zmq.POLLIN)
+    while any(conn_pending) or any(start_pending):
+        events = poller.poll(STARTUP_POLL_PERIOD_MS)
+        if not events:
+            if any(conn_pending):
+                logger.debug(
+                    "Waiting for %d local, %d remote core engine proc(s) "
+                    "to connect.", *conn_pending)
+            if any(start_pending):
+                logger.debug(
+                    "Waiting for %d local, %d remote core engine proc(s) "
+                    "to start.", *start_pending)
+            continue
+        if len(events) > 1 or events[0][0] != handshake_socket:
+            # One of the local core processes exited.
+            finished = proc_manager.finished_procs() if proc_manager else {}
+            if coord_process is not None and coord_process.exitcode is not None:
+                finished[coord_process.name] = coord_process.exitcode
+            raise RuntimeError("Engine core initialization failed. "
+                               "See root cause above. "
+                               f"Failed core proc(s): {finished}")
+
+        # Receive HELLO and READY messages from the input socket.
+        eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
+        eng_index = int.from_bytes(eng_identity, "little")
+        engine = next((e for e in core_engines if e.identity == eng_identity),
+                      None)
+        if engine is None:
+            raise RuntimeError(f"Message from engine with unexpected data "
+                               f"parallel rank: {eng_index}")
+        msg = msgspec.msgpack.decode(ready_msg_bytes)
+        status, local = msg["status"], msg["local"]
+        if local != engine.local:
+            raise RuntimeError(f"{status} message from "
+                               f"{'local' if local else 'remote'} "
+                               f"engine {eng_index}, expected it to be "
+                               f"{'local' if engine.local else 'remote'}")
+
+        if status == "HELLO" and engine.state == CoreEngineState.NEW:
+
+            # Send init message with DP config info.
+            init_message = msgspec.msgpack.encode(
+                EngineHandshakeMetadata(
+                    addresses=addresses,
+                    parallel_config={
+                        "data_parallel_master_ip":
+                        parallel_config.data_parallel_master_ip,
+                        "data_parallel_master_port":
+                        parallel_config.data_parallel_master_port,
+                        "data_parallel_size":
+                        parallel_config.data_parallel_size,
+                    }))
+            handshake_socket.send_multipart((eng_identity, init_message),
+                                            copy=False)
+            conn_pending[0 if local else 1] -= 1
+            start_pending[0 if local else 1] += 1
+            engine.state = CoreEngineState.CONNECTED
+        elif status == "READY" and (engine.state == CoreEngineState.CONNECTED):
+            # Setup KV cache config with initialization state from
+            # engine core process. Sum values from all engines in DP case.
+            num_gpu_blocks = cache_config.num_gpu_blocks or 0
+            num_gpu_blocks += msg["num_gpu_blocks"]
+            cache_config.num_gpu_blocks = num_gpu_blocks
+
+            start_pending[0 if local else 1] -= 1
+            engine.state = CoreEngineState.READY
+        else:
+            raise RuntimeError(f"Unexpected {status} message for "
+                               f"{'local' if local else 'remote'} engine "
+                               f"{eng_index} in {engine.state} state.")
+
+        logger.debug("%s from %s core engine process %s.", status,
+                     "local" if local else "remote", eng_index)
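For orientation, a sketch of what the engine side of this handshake could look like; this is not the actual EngineCoreProc code, only an illustration of the message shapes wait_for_engine_startup expects (HELLO with status/local, then the msgpack-encoded handshake metadata, then READY carrying num_gpu_blocks, which is a made-up value here):

    import msgspec
    import zmq

    def engine_handshake(handshake_address: str, dp_rank: int,
                         local: bool = True) -> dict:
        ctx = zmq.Context()
        sock = ctx.socket(zmq.DEALER)
        # The ROUTER on the client side identifies engines by their DP rank.
        sock.setsockopt(zmq.IDENTITY, dp_rank.to_bytes(2, "little"))
        sock.connect(handshake_address)

        # HELLO: announce ourselves and receive the front-end queue addresses.
        sock.send(msgspec.msgpack.encode({"status": "HELLO", "local": local}))
        metadata = msgspec.msgpack.decode(sock.recv())

        # ... connect to metadata["addresses"], load the model, profile memory,
        # then report readiness with the measured KV-cache capacity:
        sock.send(msgspec.msgpack.encode(
            {"status": "READY", "local": local, "num_gpu_blocks": 8192}))
        return metadata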
+def wait_for_completion_or_failure(
+        api_server_manager: APIServerProcessManager,
+        local_engine_manager: Optional[CoreEngineProcManager] = None,
+        coordinator: Optional["DPCoordinator"] = None) -> None:
+    """Wait for all processes to complete or detect if any fail.
+
+    Raises an exception if any process exits with a non-zero status.
+    """
+
+    try:
+        logger.info("Waiting for API servers to complete ...")
+        # Create a mapping of sentinels to their corresponding processes
+        # for efficient lookup
+        sentinel_to_proc: dict[Any, BaseProcess] = {
+            proc.sentinel: proc
+            for proc in api_server_manager.processes
+        }
+
+        if coordinator:
+            sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc
+
+        if local_engine_manager:
+            for proc in local_engine_manager.processes:
+                sentinel_to_proc[proc.sentinel] = proc
+
+        # Check if any process terminates
+        while sentinel_to_proc:
+            # Wait for any process to terminate
+            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc)
+
+            # Process any terminated processes
+            for sentinel in ready_sentinels:
+                proc = sentinel_to_proc.pop(sentinel)
+
+                # Check if process exited with error
+                if proc.exitcode != 0:
+                    raise RuntimeError(
+                        f"Process {proc.name} (PID: {proc.pid}) "
+                        f"died with exit code {proc.exitcode}")
+    except KeyboardInterrupt:
+        logger.info("Received KeyboardInterrupt, shutting down API servers...")
+    except Exception as e:
+        logger.exception("Exception occurred while running API servers: %s",
+                         str(e))
+        raise
+    finally:
+        logger.info("Terminating remaining processes ...")
+        api_server_manager.close()
+        if coordinator:
+            coordinator.close()
+        if local_engine_manager:
+            local_engine_manager.close()
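wait_for_completion_or_failure leans on multiprocessing sentinels: connection.wait blocks on any number of process handles at once, so a single loop can watch API servers, local engines, and the coordinator together. A stripped-down, self-contained sketch of that pattern with toy worker processes:

    import multiprocessing as mp
    import sys
    import time
    from multiprocessing import connection

    def worker(code: int) -> None:
        time.sleep(0.2)
        sys.exit(code)

    if __name__ == "__main__":
        procs = [mp.Process(target=worker, args=(c,)) for c in (0, 0, 1)]
        for p in procs:
            p.start()
        by_sentinel = {p.sentinel: p for p in procs}
        while by_sentinel:
            for sentinel in connection.wait(by_sentinel):
                p = by_sentinel.pop(sentinel)
                if p.exitcode != 0:
                    raise RuntimeError(f"{p.name} exited with {p.exitcode}")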
 # Note(rob): shutdown function cannot be a bound method,
-# else the gc cannot collect the objedecoupct.
-def shutdown(procs: list[Process], input_address: str):
+# else the gc cannot collect the object.
+def shutdown(procs: list[BaseProcess]):
     # Shutdown the process.
     for proc in procs:
         if proc.is_alive():

@ -185,12 +456,6 @@ def shutdown(procs: list[Process], input_address: str):

         if proc.is_alive() and (pid := proc.pid) is not None:
             kill_process_tree(pid)

-    # Remove zmq ipc socket files.
-    if input_address.startswith("ipc://"):
-        socket_file = input_address[len("ipc://"):]
-        if os and os.path.exists(socket_file):
-            os.remove(socket_file)


 def bind_kv_cache(
     kv_caches: dict[str, torch.Tensor],