fix[DP][v1]: Prevent hangs from mismatched worker configurations (#26218)

Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
Ayush Satyam 2025-10-08 11:25:08 +05:30 committed by GitHub
parent 0d4f48fa10
commit 5e65d6b2ad
3 changed files with 46 additions and 11 deletions


@@ -336,6 +336,9 @@ class ParallelConfig:
         graph from input ids/embeddings to the final hidden states,
         excluding anything before input ids/embeddings and after
         the final hidden states.
+
+        This hash is also used for DP worker configuration validation
+        to prevent hangs from mismatched collective communication patterns.
         """
         factors: list[Any] = []
         factors.append(self.pipeline_parallel_size)
@@ -343,6 +346,12 @@ class ParallelConfig:
         factors.append(self.enable_expert_parallel)
         factors.append(self.data_parallel_size)
         factors.append(envs.VLLM_ALL2ALL_BACKEND)
+        factors.append(self.enable_eplb)
+        if self.enable_eplb:
+            factors.append(self.eplb_config.log_balancedness)
+            factors.append(self.eplb_config.window_size)
+            factors.append(self.eplb_config.step_interval)
+            factors.append(self.eplb_config.num_redundant_experts)
         return hashlib.sha256(str(factors).encode()).hexdigest()

     def __post_init__(self) -> None:
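For illustration, here is a minimal standalone sketch of the hashing pattern used above (not the vLLM code itself; the parameter names are made up): two DP workers whose EPLB settings differ end up with different digests, which is exactly what the startup handshake can then detect.

import hashlib
from typing import Any


def config_hash(pipeline_parallel_size: int, data_parallel_size: int,
                enable_eplb: bool, eplb_window_size: int) -> str:
    # Mirror the pattern above: collect the factors that shape collective
    # communication, then hash their string representation.
    factors: list[Any] = [pipeline_parallel_size, data_parallel_size, enable_eplb]
    if enable_eplb:
        factors.append(eplb_window_size)
    return hashlib.sha256(str(factors).encode()).hexdigest()


# Same parallel layout, but one worker enables EPLB and the other does not:
# the digests differ, so the mismatch is detectable before any collective runs.
assert config_hash(1, 2, True, 1000) != config_hash(1, 2, False, 1000)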


@@ -681,17 +681,21 @@ class EngineCoreProc(EngineCore):
             # external LB case for our colocated front-end to use (coordinator
             # only runs with rank 0).
             dp_stats_address = self.frontend_stats_publish_address
-        handshake_socket.send(
-            msgspec.msgpack.encode(
-                {
-                    "status": "READY",
-                    "local": local_client,
-                    "headless": headless,
-                    "num_gpu_blocks": num_gpu_blocks,
-                    "dp_stats_address": dp_stats_address,
-                }
-            )
-        )
+        # Include config hash for DP configuration validation
+        ready_msg = {
+            "status": "READY",
+            "local": local_client,
+            "headless": headless,
+            "num_gpu_blocks": num_gpu_blocks,
+            "dp_stats_address": dp_stats_address,
+        }
+
+        if vllm_config.parallel_config.data_parallel_size > 1:
+            ready_msg["parallel_config_hash"] = (
+                vllm_config.parallel_config.compute_hash()
+            )
+
+        handshake_socket.send(msgspec.msgpack.encode(ready_msg))

     @staticmethod
     def startup_handshake(
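A hedged sketch of how such a READY payload round-trips through msgspec's msgpack codec (assumes msgspec is installed; the values below are placeholders, not ones taken from a real engine):

import msgspec

ready_msg = {
    "status": "READY",
    "local": True,
    "headless": False,
    "num_gpu_blocks": 4096,   # placeholder value
    "dp_stats_address": None, # placeholder value
}

data_parallel_size = 2  # stand-in for parallel_config.data_parallel_size
if data_parallel_size > 1:
    # Only attach the hash when more than one DP worker is involved.
    ready_msg["parallel_config_hash"] = "0f1e2d"  # placeholder digest

decoded = msgspec.msgpack.decode(msgspec.msgpack.encode(ready_msg))
assert decoded["parallel_config_hash"] == "0f1e2d"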


@@ -73,6 +73,7 @@ class EngineHandshakeMetadata:

     addresses: EngineZmqAddresses
     parallel_config: dict[str, Union[int, str, list[int]]]
+    parallel_config_hash: Optional[str] = None


 class CoreEngineProcManager:
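Because the new field defaults to None, a handshake payload produced before the field existed still decodes cleanly. A small sketch under the assumption that the metadata type behaves like a msgspec.Struct; the Meta class below is a simplified stand-in, not the real EngineHandshakeMetadata:

from typing import Optional

import msgspec


class Meta(msgspec.Struct):
    # Simplified stand-in for EngineHandshakeMetadata.
    parallel_config: dict
    parallel_config_hash: Optional[str] = None


# A payload encoded without the new field (as an older sender would produce)
# still decodes; the missing field simply falls back to its default.
old_payload = msgspec.msgpack.encode({"parallel_config": {"data_parallel_size": 2}})
decoded = msgspec.msgpack.decode(old_payload, type=Meta)
assert decoded.parallel_config_hash is None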
@@ -867,7 +868,8 @@ def wait_for_engine_startup(
             )

         if status == "HELLO" and engine.state == CoreEngineState.NEW:
-            # Send init message with DP config info.
+            # Send init message with DP config info and config hash.
+            # The config hash ensures all DP workers have compatible configs.
             init_message = msgspec.msgpack.encode(
                 EngineHandshakeMetadata(
                     addresses=addresses,
@@ -880,6 +882,9 @@ def wait_for_engine_startup(
                             "data_parallel_size",
                         )
                     },
+                    parallel_config_hash=parallel_config.compute_hash()
+                    if parallel_config.data_parallel_size > 1
+                    else None,
                 )
             )
             handshake_socket.send_multipart((eng_identity, init_message), copy=False)
@@ -900,6 +905,23 @@ def wait_for_engine_startup(
             if addresses.frontend_stats_publish_address is None:
                 addresses.frontend_stats_publish_address = msg.get("dp_stats_address")

+            # Validate config hash consistency across DP workers
+            if parallel_config.data_parallel_size > 1:
+                worker_config_hash = msg.get("parallel_config_hash")
+                expected_hash = parallel_config.compute_hash()
+                if worker_config_hash != expected_hash:
+                    raise RuntimeError(
+                        f"Configuration mismatch detected for engine "
+                        f"{eng_index}. All DP workers must have identical "
+                        f"configurations for parameters that affect collective "
+                        f"communication (e.g., enable_eplb, "
+                        f"eplb_config.log_balancedness). "
+                        f"Worker hash: {worker_config_hash}, "
+                        f"Expected hash: {expected_hash}. "
+                        f"Please ensure all workers are started with the same "
+                        f"command-line arguments."
+                    )
+
             start_pending[0 if local else 1] -= 1
             engine.state = CoreEngineState.READY
         else:
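Putting the two halves together, the front-end side of the handshake amounts to comparing the digest reported by each worker against a locally computed one and failing fast instead of hanging later inside a collective. A simplified, self-contained sketch (validate_ready_msg and its arguments are illustrative names, not the actual vLLM API):

def validate_ready_msg(msg: dict, expected_hash: str, data_parallel_size: int) -> None:
    if data_parallel_size <= 1:
        return  # nothing to cross-check in the single-worker case
    worker_hash = msg.get("parallel_config_hash")  # None if the worker never sent one
    if worker_hash != expected_hash:
        raise RuntimeError(
            f"Configuration mismatch: worker hash {worker_hash!r} "
            f"does not match expected hash {expected_hash!r}."
        )


validate_ready_msg({"status": "READY", "parallel_config_hash": "abc"}, "abc", 2)  # passes
try:
    validate_ready_msg({"status": "READY"}, "abc", 2)  # missing hash -> mismatch
except RuntimeError as exc:
    print(exc)  # reported at startup rather than surfacing as a hang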