From b880ffb87e0bcde5e3693203b480df49e46d67bc Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Thu, 19 Dec 2024 23:35:18 -0500
Subject: [PATCH] [Misc] Add tqdm progress bar during graph capture (#11349)

Signed-off-by: mgoin
---
 vllm/worker/model_runner.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 6ff98a8f1bab2..2b545d1b28bd2 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -13,6 +13,7 @@ import numpy as np
 import torch
 import torch.distributed
 import torch.nn as nn
+from tqdm import tqdm
 
 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
@@ -21,7 +22,8 @@ from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.distributed import get_kv_transfer_group, get_pp_group
-from vllm.distributed.parallel_state import graph_capture
+from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
+                                             graph_capture)
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
@@ -1413,8 +1415,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         logger.info("Capturing cudagraphs for decoding. This may lead to "
                     "unexpected consequences if the model is not static. To "
                     "run the model in eager mode, set 'enforce_eager=True' or "
-                    "use '--enforce-eager' in the CLI.")
-        logger.info("If out-of-memory error occurs during cudagraph capture,"
+                    "use '--enforce-eager' in the CLI. "
+                    "If out-of-memory error occurs during cudagraph capture,"
                     " consider decreasing `gpu_memory_utilization` or "
                     "switching to eager mode. You can also reduce the "
                     "`max_num_seqs` as needed to decrease memory usage.")
@@ -1451,8 +1453,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             # memory usage of CUDA graph.
             for virtual_engine in range(
                     self.parallel_config.pipeline_parallel_size):
-                for batch_size in \
-                        self.vllm_config.compilation_config.capture_sizes:
+                # Only rank 0 should print progress bar during capture
+                capture_sizes = (
+                    tqdm(
+                        self.vllm_config.compilation_config.capture_sizes,
+                        desc="Capturing CUDA graph shapes",
+                    ) if get_tensor_model_parallel_rank() == 0 else
+                    self.vllm_config.compilation_config.capture_sizes)
+                for batch_size in capture_sizes:
                     attn_metadata = (
                         self.attn_state.graph_capture_get_metadata_for_batch(
                             batch_size,
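
Reviewer note (not part of the patch): the core change is the pattern of wrapping an iterable in tqdm only on tensor-parallel rank 0, so a multi-GPU run prints a single progress bar during CUDA graph capture instead of one per worker. Below is a minimal, self-contained sketch of that pattern; worker_rank, capture_graph_for_batch, and the size list are hypothetical stand-ins for illustration, not vLLM APIs.

# Illustrative sketch only, not vLLM code: show a progress bar on rank 0 and
# iterate silently on every other worker, mirroring the diff above.
import os

from tqdm import tqdm


def worker_rank() -> int:
    # Hypothetical stand-in for get_tensor_model_parallel_rank(); here we just
    # read the RANK variable that torchrun-style launchers export.
    return int(os.environ.get("RANK", "0"))


def capture_graph_for_batch(batch_size: int) -> None:
    # Placeholder for the per-size work (vLLM captures one CUDA graph here).
    pass


def capture_all(batch_sizes: list[int]) -> None:
    # Only rank 0 wraps the list in tqdm; other ranks iterate the plain list,
    # so the console shows exactly one progress bar.
    sizes = (tqdm(batch_sizes, desc="Capturing CUDA graph shapes")
             if worker_rank() == 0 else batch_sizes)
    for batch_size in sizes:
        capture_graph_for_batch(batch_size)


if __name__ == "__main__":
    capture_all([1, 2, 4, 8, 16, 32])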