mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:25:44 +08:00
fix[DP][v1]: Prevent hangs from mismatched worker configurations (#26218)
Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
This commit is contained in:
parent
0d4f48fa10
commit
5e65d6b2ad
@@ -336,6 +336,9 @@ class ParallelConfig:
|
|||||||
graph from input ids/embeddings to the final hidden states,
|
graph from input ids/embeddings to the final hidden states,
|
||||||
excluding anything before input ids/embeddings and after
|
excluding anything before input ids/embeddings and after
|
||||||
the final hidden states.
|
the final hidden states.
|
||||||
|
|
||||||
|
This hash is also used for DP worker configuration validation
|
||||||
|
to prevent hangs from mismatched collective communication patterns.
|
||||||
"""
|
"""
|
||||||
factors: list[Any] = []
|
factors: list[Any] = []
|
||||||
factors.append(self.pipeline_parallel_size)
|
factors.append(self.pipeline_parallel_size)
|
||||||
@@ -343,6 +346,12 @@ class ParallelConfig:
|
|||||||
factors.append(self.enable_expert_parallel)
|
factors.append(self.enable_expert_parallel)
|
||||||
factors.append(self.data_parallel_size)
|
factors.append(self.data_parallel_size)
|
||||||
factors.append(envs.VLLM_ALL2ALL_BACKEND)
|
factors.append(envs.VLLM_ALL2ALL_BACKEND)
|
||||||
|
factors.append(self.enable_eplb)
|
||||||
|
if self.enable_eplb:
|
||||||
|
factors.append(self.eplb_config.log_balancedness)
|
||||||
|
factors.append(self.eplb_config.window_size)
|
||||||
|
factors.append(self.eplb_config.step_interval)
|
||||||
|
factors.append(self.eplb_config.num_redundant_experts)
|
||||||
return hashlib.sha256(str(factors).encode()).hexdigest()
|
return hashlib.sha256(str(factors).encode()).hexdigest()
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
|||||||
@@ -681,17 +681,21 @@ class EngineCoreProc(EngineCore):
|
|||||||
# external LB case for our colocated front-end to use (coordinator
|
# external LB case for our colocated front-end to use (coordinator
|
||||||
# only runs with rank 0).
|
# only runs with rank 0).
|
||||||
dp_stats_address = self.frontend_stats_publish_address
|
dp_stats_address = self.frontend_stats_publish_address
|
||||||
handshake_socket.send(
|
|
||||||
msgspec.msgpack.encode(
|
# Include config hash for DP configuration validation
|
||||||
{
|
ready_msg = {
|
||||||
"status": "READY",
|
"status": "READY",
|
||||||
"local": local_client,
|
"local": local_client,
|
||||||
"headless": headless,
|
"headless": headless,
|
||||||
"num_gpu_blocks": num_gpu_blocks,
|
"num_gpu_blocks": num_gpu_blocks,
|
||||||
"dp_stats_address": dp_stats_address,
|
"dp_stats_address": dp_stats_address,
|
||||||
}
|
}
|
||||||
|
if vllm_config.parallel_config.data_parallel_size > 1:
|
||||||
|
ready_msg["parallel_config_hash"] = (
|
||||||
|
vllm_config.parallel_config.compute_hash()
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
handshake_socket.send(msgspec.msgpack.encode(ready_msg))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def startup_handshake(
|
def startup_handshake(
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ class EngineHandshakeMetadata:
|
|||||||
|
|
||||||
addresses: EngineZmqAddresses
|
addresses: EngineZmqAddresses
|
||||||
parallel_config: dict[str, Union[int, str, list[int]]]
|
parallel_config: dict[str, Union[int, str, list[int]]]
|
||||||
|
parallel_config_hash: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class CoreEngineProcManager:
|
class CoreEngineProcManager:
|
||||||
@@ -867,7 +868,8 @@ def wait_for_engine_startup(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if status == "HELLO" and engine.state == CoreEngineState.NEW:
|
if status == "HELLO" and engine.state == CoreEngineState.NEW:
|
||||||
# Send init message with DP config info.
|
# Send init message with DP config info and config hash.
|
||||||
|
# The config hash ensures all DP workers have compatible configs.
|
||||||
init_message = msgspec.msgpack.encode(
|
init_message = msgspec.msgpack.encode(
|
||||||
EngineHandshakeMetadata(
|
EngineHandshakeMetadata(
|
||||||
addresses=addresses,
|
addresses=addresses,
|
||||||
@@ -880,6 +882,9 @@ def wait_for_engine_startup(
|
|||||||
"data_parallel_size",
|
"data_parallel_size",
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
parallel_config_hash=parallel_config.compute_hash()
|
||||||
|
if parallel_config.data_parallel_size > 1
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
handshake_socket.send_multipart((eng_identity, init_message), copy=False)
|
handshake_socket.send_multipart((eng_identity, init_message), copy=False)
|
||||||
@@ -900,6 +905,23 @@ def wait_for_engine_startup(
|
|||||||
if addresses.frontend_stats_publish_address is None:
|
if addresses.frontend_stats_publish_address is None:
|
||||||
addresses.frontend_stats_publish_address = msg.get("dp_stats_address")
|
addresses.frontend_stats_publish_address = msg.get("dp_stats_address")
|
||||||
|
|
||||||
|
# Validate config hash consistency across DP workers
|
||||||
|
if parallel_config.data_parallel_size > 1:
|
||||||
|
worker_config_hash = msg.get("parallel_config_hash")
|
||||||
|
expected_hash = parallel_config.compute_hash()
|
||||||
|
if worker_config_hash != expected_hash:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Configuration mismatch detected for engine "
|
||||||
|
f"{eng_index}. All DP workers must have identical "
|
||||||
|
f"configurations for parameters that affect collective "
|
||||||
|
f"communication (e.g., enable_eplb, "
|
||||||
|
f"eplb_config.log_balancedness). "
|
||||||
|
f"Worker hash: {worker_config_hash}, "
|
||||||
|
f"Expected hash: {expected_hash}. "
|
||||||
|
f"Please ensure all workers are started with the same "
|
||||||
|
f"command-line arguments."
|
||||||
|
)
|
||||||
|
|
||||||
start_pending[0 if local else 1] -= 1
|
start_pending[0 if local else 1] -= 1
|
||||||
engine.state = CoreEngineState.READY
|
engine.state = CoreEngineState.READY
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user