fix[DP][v1]: Prevent hangs from mismatched worker configurations (#26218)

Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
Ayush Satyam authored on 2025-10-08 11:25:08 +05:30, committed by GitHub
parent 0d4f48fa10
commit 5e65d6b2ad
3 changed files with 46 additions and 11 deletions


@@ -336,6 +336,9 @@ class ParallelConfig:
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
This hash is also used for DP worker configuration validation
to prevent hangs from mismatched collective communication patterns.
"""
factors: list[Any] = []
factors.append(self.pipeline_parallel_size)
@@ -343,6 +346,12 @@ class ParallelConfig:
factors.append(self.enable_expert_parallel)
factors.append(self.data_parallel_size)
factors.append(envs.VLLM_ALL2ALL_BACKEND)
factors.append(self.enable_eplb)
if self.enable_eplb:
factors.append(self.eplb_config.log_balancedness)
factors.append(self.eplb_config.window_size)
factors.append(self.eplb_config.step_interval)
factors.append(self.eplb_config.num_redundant_experts)
return hashlib.sha256(str(factors).encode()).hexdigest()
def __post_init__(self) -> None:
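
For orientation, a self-contained sketch (not the vLLM implementation) of the hashing pattern used in compute_hash above: every field that changes the collective communication pattern is appended to a list in a fixed order, and the stringified list is hashed with SHA-256. ToyParallelConfig and its fields are hypothetical stand-ins.

import hashlib
from dataclasses import dataclass

@dataclass
class ToyParallelConfig:  # hypothetical stand-in for ParallelConfig
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    data_parallel_size: int = 1
    enable_eplb: bool = False

    def compute_hash(self) -> str:
        # Append factors in a fixed order so identical configs always
        # produce identical digests across workers.
        factors = [
            self.pipeline_parallel_size,
            self.tensor_parallel_size,
            self.data_parallel_size,
            self.enable_eplb,
        ]
        return hashlib.sha256(str(factors).encode()).hexdigest()

# Two workers launched with the same flags agree; a differing flag does not.
assert ToyParallelConfig().compute_hash() == ToyParallelConfig().compute_hash()
assert ToyParallelConfig().compute_hash() != ToyParallelConfig(enable_eplb=True).compute_hash()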


@@ -681,17 +681,21 @@ class EngineCoreProc(EngineCore):
# external LB case for our colocated front-end to use (coordinator
# only runs with rank 0).
dp_stats_address = self.frontend_stats_publish_address
- handshake_socket.send(
- msgspec.msgpack.encode(
- {
- "status": "READY",
- "local": local_client,
- "headless": headless,
- "num_gpu_blocks": num_gpu_blocks,
- "dp_stats_address": dp_stats_address,
- }
- )
- )
+ # Include config hash for DP configuration validation
+ ready_msg = {
+ "status": "READY",
+ "local": local_client,
+ "headless": headless,
+ "num_gpu_blocks": num_gpu_blocks,
+ "dp_stats_address": dp_stats_address,
+ }
+ if vllm_config.parallel_config.data_parallel_size > 1:
+ ready_msg["parallel_config_hash"] = (
+ vllm_config.parallel_config.compute_hash()
+ )
+ handshake_socket.send(msgspec.msgpack.encode(ready_msg))
@staticmethod
def startup_handshake(
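
A minimal sketch of the READY payload construction shown above, with the ZMQ plumbing omitted and dummy values substituted; build_ready_msg is a hypothetical helper, not a vLLM function.

import msgspec

def build_ready_msg(num_gpu_blocks: int, dp_size: int, config_hash: str) -> bytes:
    # Shape mirrors the READY message in the diff; values here are dummies.
    ready_msg = {
        "status": "READY",
        "local": True,
        "headless": False,
        "num_gpu_blocks": num_gpu_blocks,
        "dp_stats_address": None,
    }
    # Attach the hash only when data parallelism is in use, matching the diff.
    if dp_size > 1:
        ready_msg["parallel_config_hash"] = config_hash
    return msgspec.msgpack.encode(ready_msg)

decoded = msgspec.msgpack.decode(build_ready_msg(8192, dp_size=2, config_hash="abc123"))
assert decoded["parallel_config_hash"] == "abc123"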


@@ -73,6 +73,7 @@ class EngineHandshakeMetadata:
addresses: EngineZmqAddresses
parallel_config: dict[str, Union[int, str, list[int]]]
parallel_config_hash: Optional[str] = None
class CoreEngineProcManager:
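
Defaulting the new parallel_config_hash field to None keeps the handshake decodable when a sender omits it. A short sketch, assuming a msgspec.Struct-style metadata type; ToyHandshakeMetadata is a hypothetical stand-in, not the vLLM class.

from typing import Optional
import msgspec

class ToyHandshakeMetadata(msgspec.Struct):  # hypothetical stand-in
    parallel_config: dict
    parallel_config_hash: Optional[str] = None

# A payload from an older sender that has never heard of the new field
old_style = msgspec.msgpack.encode({"parallel_config": {"data_parallel_size": 2}})
meta = msgspec.msgpack.decode(old_style, type=ToyHandshakeMetadata)
assert meta.parallel_config_hash is None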
@@ -867,7 +868,8 @@ def wait_for_engine_startup(
)
if status == "HELLO" and engine.state == CoreEngineState.NEW:
- # Send init message with DP config info.
+ # Send init message with DP config info and config hash.
+ # The config hash ensures all DP workers have compatible configs.
init_message = msgspec.msgpack.encode(
EngineHandshakeMetadata(
addresses=addresses,
@@ -880,6 +882,9 @@
"data_parallel_size",
)
},
parallel_config_hash=parallel_config.compute_hash()
if parallel_config.data_parallel_size > 1
else None,
)
)
handshake_socket.send_multipart((eng_identity, init_message), copy=False)
@@ -900,6 +905,23 @@
if addresses.frontend_stats_publish_address is None:
addresses.frontend_stats_publish_address = msg.get("dp_stats_address")
# Validate config hash consistency across DP workers
if parallel_config.data_parallel_size > 1:
worker_config_hash = msg.get("parallel_config_hash")
expected_hash = parallel_config.compute_hash()
if worker_config_hash != expected_hash:
raise RuntimeError(
f"Configuration mismatch detected for engine "
f"{eng_index}. All DP workers must have identical "
f"configurations for parameters that affect collective "
f"communication (e.g., enable_eplb, "
f"eplb_config.log_balancedness). "
f"Worker hash: {worker_config_hash}, "
f"Expected hash: {expected_hash}. "
f"Please ensure all workers are started with the same "
f"command-line arguments."
)
start_pending[0 if local else 1] -= 1
engine.state = CoreEngineState.READY
else:
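
Finally, a hedged end-to-end sketch of the failure mode this check targets: instead of mismatched workers hanging later inside a collective operation, the launcher compares digests during startup and fails fast. validate_dp_workers is illustrative, not the vLLM code path.

def validate_dp_workers(expected_hash: str, worker_hashes: dict[int, str]) -> None:
    # Fail fast on a hash mismatch rather than hanging later in a collective op.
    for eng_index, worker_hash in worker_hashes.items():
        if worker_hash != expected_hash:
            raise RuntimeError(
                f"Configuration mismatch detected for engine {eng_index}: "
                f"worker hash {worker_hash!r} != expected {expected_hash!r}. "
                f"Start all DP workers with identical command-line arguments."
            )

# One worker was launched with a different EPLB setting, so its digest differs.
validate_dp_workers("aaa111", {0: "aaa111", 1: "bbb222"})  # raises RuntimeError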