mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-28 05:47:14 +08:00
[V1] Only print cudagraph tqdm on rank 0 with is_global_first_rank (#19516)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
27949354fa
commit
be250bbc67
@ -1315,6 +1315,37 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
|
||||
return [x == 1 for x in aggregated_data.tolist()]
|
||||
|
||||
|
||||
def is_global_first_rank() -> bool:
|
||||
"""
|
||||
Check if the current process is the first rank globally across all
|
||||
parallelism strategies (PP, TP, DP, EP, etc.).
|
||||
|
||||
Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
|
||||
or `get_pp_group().is_first_rank`, this function checks the global rank
|
||||
across all parallelism dimensions.
|
||||
|
||||
Returns:
|
||||
bool: True if this is the global first rank (rank 0), False otherwise.
|
||||
Returns True if distributed is not initialized (single process).
|
||||
"""
|
||||
try:
|
||||
# If world group is available, use it for the most accurate check
|
||||
global _WORLD
|
||||
if _WORLD is not None:
|
||||
return _WORLD.is_first_rank
|
||||
|
||||
# If torch distributed is not initialized, assume single process
|
||||
if not torch.distributed.is_initialized():
|
||||
return True
|
||||
|
||||
# Fallback to torch's global rank
|
||||
return torch.distributed.get_rank() == 0
|
||||
|
||||
except Exception:
|
||||
# If anything goes wrong, assume this is the first rank
|
||||
return True
|
||||
|
||||
|
||||
def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
|
||||
"""
|
||||
Returns the total number of nodes in the process group.
|
||||
|
||||
@ -26,7 +26,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_pp_group, get_tp_group, graph_capture,
|
||||
get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
|
||||
prepare_communication_buffer_for_model)
|
||||
from vllm.forward_context import (DPMetadata, get_forward_context,
|
||||
set_forward_context)
|
||||
@ -2285,9 +2285,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# can reuse the memory pool allocated for the large shapes.
|
||||
with graph_capture(device=self.device):
|
||||
full_cg = self.full_cuda_graph
|
||||
for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes),
|
||||
desc="Capturing CUDA graphs",
|
||||
total=len(self.cudagraph_batch_sizes)):
|
||||
# Only rank 0 should print progress bar during capture
|
||||
compilation_cases = reversed(self.cudagraph_batch_sizes)
|
||||
if is_global_first_rank():
|
||||
compilation_cases = tqdm(list(compilation_cases),
|
||||
desc="Capturing CUDA graph shapes")
|
||||
for num_tokens in compilation_cases:
|
||||
# We skip EPLB here since we don't want to record dummy metrics
|
||||
for _ in range(
|
||||
self.compilation_config.cudagraph_num_of_warmups):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user