add cross-node dp arg validation

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-07-22 13:48:54 +01:00
parent 82f9292b84
commit f27a85d435
2 changed files with 18 additions and 3 deletions

View File

@ -483,8 +483,10 @@ class EngineCoreProc(EngineCore):
"""
input_ctx = zmq.Context()
is_local = local_client and client_handshake_address is None
headless = not local_client
handshake = self._perform_handshake(input_ctx, handshake_address,
identity, is_local, vllm_config,
identity, is_local, headless,
vllm_config,
vllm_config.parallel_config)
if client_handshake_address is None:
with handshake as addresses:
@ -492,7 +494,7 @@ class EngineCoreProc(EngineCore):
else:
assert local_client
local_handshake = self._perform_handshake(
input_ctx, client_handshake_address, identity, True,
input_ctx, client_handshake_address, identity, True, False,
vllm_config)
with handshake as addresses, local_handshake as client_addresses:
addresses.inputs = client_addresses.inputs
@ -509,6 +511,7 @@ class EngineCoreProc(EngineCore):
handshake_address: str,
identity: bytes,
local_client: bool,
headless: bool,
vllm_config: VllmConfig,
parallel_config_to_update: Optional[ParallelConfig] = None,
) -> Generator[EngineZmqAddresses, None, None]:
@ -533,6 +536,7 @@ class EngineCoreProc(EngineCore):
msgspec.msgpack.encode({
"status": "READY",
"local": local_client,
"headless": headless,
"num_gpu_blocks": num_gpu_blocks,
"dp_stats_address": dp_stats_address,
}))

View File

@ -725,13 +725,24 @@ def wait_for_engine_startup(
raise RuntimeError(f"Message from engine with unexpected data "
f"parallel rank: {eng_index}")
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local = msg["status"], msg["local"]
status, local, headless = msg["status"], msg["local"], msg["headless"]
if local != engine.local:
raise RuntimeError(f"{status} message from "
f"{'local' if local else 'remote'} "
f"engine {eng_index}, expected it to be "
f"{'local' if engine.local else 'remote'}")
# Remote engines must be headless iff we aren't in hybrid dp lb mode.
# `headless` is the flag the engine reported in its handshake message, so
# a mismatch means the remote node was launched with the wrong CLI args.
# The check fires when the two booleans are EQUAL: headless while in
# hybrid-lb mode, or not headless while outside hybrid-lb mode.
if not local and headless == parallel_config.data_parallel_hybrid_lb:
    if headless:
        raise RuntimeError(f"Remote engine {eng_index} must not use "
                           f"--headless in --data-parallel-hybrid-lb "
                           f"mode")
    else:
        # Trailing space after "unless" matters: adjacent literals are
        # concatenated verbatim, so without it the text reads "unlessin".
        raise RuntimeError(f"Remote engine {eng_index} must use "
                           f"--headless unless "
                           f"in --data-parallel-hybrid-lb mode")
if status == "HELLO" and engine.state == CoreEngineState.NEW:
# Send init message with DP config info.