mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 18:35:42 +08:00
[V1] Only print cudagraph tqdm on rank 0 with is_global_first_rank (#19516)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
27949354fa
commit
be250bbc67
@ -1315,6 +1315,37 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
|
|||||||
return [x == 1 for x in aggregated_data.tolist()]
|
return [x == 1 for x in aggregated_data.tolist()]
|
||||||
|
|
||||||
|
|
||||||
|
def is_global_first_rank() -> bool:
|
||||||
|
"""
|
||||||
|
Check if the current process is the first rank globally across all
|
||||||
|
parallelism strategies (PP, TP, DP, EP, etc.).
|
||||||
|
|
||||||
|
Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
|
||||||
|
or `get_pp_group().is_first_rank`, this function checks the global rank
|
||||||
|
across all parallelism dimensions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if this is the global first rank (rank 0), False otherwise.
|
||||||
|
Returns True if distributed is not initialized (single process).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# If world group is available, use it for the most accurate check
|
||||||
|
global _WORLD
|
||||||
|
if _WORLD is not None:
|
||||||
|
return _WORLD.is_first_rank
|
||||||
|
|
||||||
|
# If torch distributed is not initialized, assume single process
|
||||||
|
if not torch.distributed.is_initialized():
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Fallback to torch's global rank
|
||||||
|
return torch.distributed.get_rank() == 0
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# If anything goes wrong, assume this is the first rank
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
|
def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
|
||||||
"""
|
"""
|
||||||
Returns the total number of nodes in the process group.
|
Returns the total number of nodes in the process group.
|
||||||
|
|||||||
@ -26,7 +26,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
|||||||
has_kv_transfer_group)
|
has_kv_transfer_group)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_pp_group, get_tp_group, graph_capture,
|
get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
|
||||||
prepare_communication_buffer_for_model)
|
prepare_communication_buffer_for_model)
|
||||||
from vllm.forward_context import (DPMetadata, get_forward_context,
|
from vllm.forward_context import (DPMetadata, get_forward_context,
|
||||||
set_forward_context)
|
set_forward_context)
|
||||||
@ -2285,9 +2285,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# can reuse the memory pool allocated for the large shapes.
|
# can reuse the memory pool allocated for the large shapes.
|
||||||
with graph_capture(device=self.device):
|
with graph_capture(device=self.device):
|
||||||
full_cg = self.full_cuda_graph
|
full_cg = self.full_cuda_graph
|
||||||
for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes),
|
# Only rank 0 should print progress bar during capture
|
||||||
desc="Capturing CUDA graphs",
|
compilation_cases = reversed(self.cudagraph_batch_sizes)
|
||||||
total=len(self.cudagraph_batch_sizes)):
|
if is_global_first_rank():
|
||||||
|
compilation_cases = tqdm(list(compilation_cases),
|
||||||
|
desc="Capturing CUDA graph shapes")
|
||||||
|
for num_tokens in compilation_cases:
|
||||||
# We skip EPLB here since we don't want to record dummy metrics
|
# We skip EPLB here since we don't want to record dummy metrics
|
||||||
for _ in range(
|
for _ in range(
|
||||||
self.compilation_config.cudagraph_num_of_warmups):
|
self.compilation_config.cudagraph_num_of_warmups):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user