[V1] Only print cudagraph tqdm on rank 0 with is_global_first_rank (#19516)

Signed-off-by: mgoin <mgoin64@gmail.com>
Michael Goin 2025-07-01 15:02:09 +09:00 committed by GitHub
parent 27949354fa
commit be250bbc67
2 changed files with 38 additions and 4 deletions

vllm/distributed/parallel_state.py

@@ -1315,6 +1315,37 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
     return [x == 1 for x in aggregated_data.tolist()]
 
 
+def is_global_first_rank() -> bool:
+    """
+    Check if the current process is the first rank globally across all
+    parallelism strategies (PP, TP, DP, EP, etc.).
+
+    Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
+    or `get_pp_group().is_first_rank`, this function checks the global rank
+    across all parallelism dimensions.
+
+    Returns:
+        bool: True if this is the global first rank (rank 0), False otherwise.
+              Returns True if distributed is not initialized (single process).
+    """
+    try:
+        # If the world group is available, use it for the most accurate check
+        global _WORLD
+        if _WORLD is not None:
+            return _WORLD.is_first_rank
+
+        # If torch.distributed is not initialized, assume a single process
+        if not torch.distributed.is_initialized():
+            return True
+
+        # Fall back to torch's global rank
+        return torch.distributed.get_rank() == 0
+    except Exception:
+        # If anything goes wrong, assume this is the first rank
+        return True
+
+
 def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
     """
     Returns the total number of nodes in the process group.
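
For context, a minimal sketch of how this helper can be used to silence per-rank output; the `log_rank0` wrapper is hypothetical and not part of this commit:

    # Hypothetical wrapper (not in this commit): print only on global rank 0,
    # regardless of how the TP/PP/DP/EP groups are laid out.
    from vllm.distributed.parallel_state import is_global_first_rank

    def log_rank0(message: str) -> None:
        # Every rank may call this; only global rank 0 actually prints.
        if is_global_first_rank():
            print(message)

    log_rank0("Capturing CUDA graphs...")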

vllm/v1/worker/gpu_model_runner.py

@@ -26,7 +26,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.distributed.parallel_state import (
-    get_pp_group, get_tp_group, graph_capture,
+    get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
     prepare_communication_buffer_for_model)
 from vllm.forward_context import (DPMetadata, get_forward_context,
                                   set_forward_context)
@@ -2285,9 +2285,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # can reuse the memory pool allocated for the large shapes.
         with graph_capture(device=self.device):
             full_cg = self.full_cuda_graph
-            for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes),
-                                   desc="Capturing CUDA graphs",
-                                   total=len(self.cudagraph_batch_sizes)):
+            # Only rank 0 should print progress bar during capture
+            compilation_cases = reversed(self.cudagraph_batch_sizes)
+            if is_global_first_rank():
+                compilation_cases = tqdm(list(compilation_cases),
+                                         desc="Capturing CUDA graph shapes")
+            for num_tokens in compilation_cases:
                 # We skip EPLB here since we don't want to record dummy metrics
                 for _ in range(
                         self.compilation_config.cudagraph_num_of_warmups):
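
The same gating pattern in isolation — a minimal, runnable sketch assuming torch and tqdm are installed; `rank0_tqdm` and the sample batch sizes are illustrative, not part of the commit:

    # Illustrative sketch of the rank-gated tqdm pattern used above.
    import torch.distributed as dist
    from tqdm import tqdm

    def rank0_tqdm(iterable, **kwargs):
        # Wrap in tqdm only on global rank 0 (or when running single-process),
        # so the progress bar is printed exactly once across all ranks.
        if not dist.is_initialized() or dist.get_rank() == 0:
            return tqdm(list(iterable), **kwargs)
        return iterable

    for num_tokens in rank0_tqdm(reversed([512, 256, 128]),
                                 desc="Capturing CUDA graph shapes"):
        pass  # per-shape capture work would go here

Note that non-zero ranks still iterate over the exact same shapes; only the progress-bar wrapper is skipped, so capture behavior stays identical across ranks.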