mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-29 08:17:04 +08:00
[Misc] Respect no_use_tqdm_on_load flag while capturing CUDA graph (#20834)
Signed-off-by: Linkun <github@lkchen.net>
This commit is contained in:
parent
147afb448b
commit
f56d2996ca
@@ -2270,8 +2270,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# Only rank 0 should print progress bar during capture
|
# Only rank 0 should print progress bar during capture
|
||||||
compilation_cases = reversed(self.cudagraph_batch_sizes)
|
compilation_cases = reversed(self.cudagraph_batch_sizes)
|
||||||
if is_global_first_rank():
|
if is_global_first_rank():
|
||||||
compilation_cases = tqdm(list(compilation_cases),
|
compilation_cases = tqdm(
|
||||||
desc="Capturing CUDA graph shapes")
|
list(compilation_cases),
|
||||||
|
disable=not self.load_config.use_tqdm_on_load,
|
||||||
|
desc="Capturing CUDA graph shapes")
|
||||||
for num_tokens in compilation_cases:
|
for num_tokens in compilation_cases:
|
||||||
# We skip EPLB here since we don't want to record dummy metrics
|
# We skip EPLB here since we don't want to record dummy metrics
|
||||||
for _ in range(
|
for _ in range(
|
||||||
|
|||||||
@@ -1587,6 +1587,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||||||
if get_tensor_model_parallel_rank() == 0:
|
if get_tensor_model_parallel_rank() == 0:
|
||||||
compilation_cases = tqdm(
|
compilation_cases = tqdm(
|
||||||
list(compilation_cases),
|
list(compilation_cases),
|
||||||
|
disable=not self.load_config.use_tqdm_on_load,
|
||||||
desc="Capturing CUDA graph shapes")
|
desc="Capturing CUDA graph shapes")
|
||||||
for batch_size, use_inputs_embeds in compilation_cases:
|
for batch_size, use_inputs_embeds in compilation_cases:
|
||||||
attn_metadata = (
|
attn_metadata = (
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user