[Misc] Add tqdm progress bar during graph capture (#11349)
Signed-off-by: mgoin <michael@neuralmagic.com>
commit b880ffb87e
parent 7801f56ed7
vllm/worker/model_runner.py

@@ -13,6 +13,7 @@ import numpy as np
 import torch
 import torch.distributed
 import torch.nn as nn
+from tqdm import tqdm

 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
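The import above is all this change needs from tqdm: wrapping an iterable leaves iteration untouched and renders a progress bar on stderr as a side effect. A minimal sketch of the call form the later hunk relies on (the capture sizes here are illustrative values, not vLLM's actual defaults):

    from tqdm import tqdm

    # tqdm yields the same items as the underlying iterable; the bar is
    # only a logging side effect, so wrapping can be done conditionally.
    for batch_size in tqdm([8, 4, 2, 1], desc="Capturing CUDA graph shapes"):
        pass  # capture a CUDA graph for this batch size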
@@ -21,7 +22,8 @@ from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.distributed import get_kv_transfer_group, get_pp_group
-from vllm.distributed.parallel_state import graph_capture
+from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
+                                             graph_capture)
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
@@ -1413,8 +1415,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         logger.info("Capturing cudagraphs for decoding. This may lead to "
                     "unexpected consequences if the model is not static. To "
                     "run the model in eager mode, set 'enforce_eager=True' or "
-                    "use '--enforce-eager' in the CLI.")
-        logger.info("If out-of-memory error occurs during cudagraph capture,"
+                    "use '--enforce-eager' in the CLI. "
+                    "If out-of-memory error occurs during cudagraph capture,"
                     " consider decreasing `gpu_memory_utilization` or "
                     "switching to eager mode. You can also reduce the "
                     "`max_num_seqs` as needed to decrease memory usage.")
@@ -1451,8 +1453,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             # memory usage of CUDA graph.
             for virtual_engine in range(
                     self.parallel_config.pipeline_parallel_size):
-                for batch_size in \
-                        self.vllm_config.compilation_config.capture_sizes:
+                # Only rank 0 should print progress bar during capture
+                capture_sizes = (
+                    tqdm(
+                        self.vllm_config.compilation_config.capture_sizes,
+                        desc="Capturing CUDA graph shapes",
+                    ) if get_tensor_model_parallel_rank() == 0 else
+                    self.vllm_config.compilation_config.capture_sizes)
+                for batch_size in capture_sizes:
                     attn_metadata = (
                         self.attn_state.graph_capture_get_metadata_for_batch(
                             batch_size,
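In tensor-parallel runs every rank captures the same graph shapes, so the hunk above gates the bar on get_tensor_model_parallel_rank() == 0 rather than letting N ranks draw interleaved progress bars. A standalone sketch of that pattern, using a hypothetical maybe_tqdm helper and a hard-coded rank in place of vLLM's distributed state:

    from tqdm import tqdm

    def maybe_tqdm(iterable, rank, desc):
        # Hypothetical helper, not part of vLLM: wrap in a progress bar
        # only on rank 0; other ranks iterate over the plain iterable.
        return tqdm(iterable, desc=desc) if rank == 0 else iterable

    for batch_size in maybe_tqdm([8, 4, 2, 1], rank=0,
                                 desc="Capturing CUDA graph shapes"):
        pass  # capture a CUDA graph for this batch size

Because tqdm is transparent to iteration, the loop body stays identical on every rank; only the logging side effect differs.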