From 7a30fa8708ce7592a3a1adeab61425a28a3bbff3 Mon Sep 17 00:00:00 2001
From: Zazzle516 <77725591+Zazzle516@users.noreply.github.com>
Date: Fri, 12 Sep 2025 07:18:09 +0800
Subject: [PATCH] [Doc] Clarify cudagraph capture size logic and default
 behavior in scheduler (#18698)

Signed-off-by: Zazzle516 <2405677060@qq.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 vllm/config/__init__.py | 49 ++++++++++++++++++++++++++-------------------
 1 file changed, 30 insertions(+), 19 deletions(-)

diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 24eaf2e360ab..8026d4c9e202 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -3579,30 +3579,41 @@ class VllmConfig:
 
     def _set_cudagraph_sizes(self):
         """
-        cudagraph batchsize padding logic:
+        vLLM defines the default candidate list of batch sizes for CUDA graph
+        capture as:
 
-        `[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible
-        batch sizes that cudagraph will capture.
-
-        Depending on the engine's configuration of `max_num_seqs`, the
-        candidate batch sizes to capture cudagraph will shrink to the subset
-        which just cover the range of `[1, max_num_seqs]`. In the common case,
-        `max_num_seqs` is 256, and the cudagraph batch sizes will be
-        `[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`.
-
-        However, if users specify the cudagraph capture sizes through
-        compilation config, we will use the specified sizes instead.
+        ```python
+        max_graph_size = min(max_num_seqs * 2, 512)
+        # 1, 2, 4, then multiples of 8 up to max_graph_size
+        cuda_graph_sizes = [1, 2, 4, 8, 16, 24, 32, 40, ..., max_graph_size]
+        ```
 
         In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
         will be the final sizes to capture cudagraph (in descending order).
 
-        During runtime, if batchsize is larger than
-        `vllm_config.compilation_config.cudagraph_capture_sizes`,
-        no cudagraph will be used.
-        If the batch size is no larger than
-        `vllm_config.compilation_config.cudagraph_capture_sizes`,
-        we can quickly find the padded graph size for a given batch size by
-        looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
+        These sizes are used to capture and reuse CUDA graphs for
+        performance-critical paths (e.g., decoding). Capturing them enables
+        significantly faster kernel dispatch by avoiding Python overhead. The
+        list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on
+        most GPUs), which caps the total number of tokens allowed in a
+        batch. Since each sequence may have a variable number of tokens, the
+        maximum usable batch size will depend on actual sequence lengths.
+
+        Example:
+            With `max_num_batched_tokens = 8192` and typical sequences
+            averaging ~32 tokens, most practical batch sizes fall below 256.
+            However, the system will still allow capture sizes up to 512 if
+            shape and memory permit.
+
+        Note:
+            If users explicitly specify cudagraph capture sizes in the
+            compilation config, those will override this default logic.
+            At runtime:
+
+            - If batch size <= the largest `cudagraph_capture_sizes` entry,
+              the closest padded CUDA graph will be used.
+            - If batch size > the largest entry, no cudagraph will be
+              used.
         """
         # calculate the default `batch_size_capture_list`
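
For reference, the default candidate-list construction described in the new docstring can be reproduced end to end. The sketch below is illustrative only — `default_capture_sizes` is a hypothetical helper, not a vLLM API — and assumes the `min(max_num_seqs * 2, 512)` cap quoted in the patch:

```python
def default_capture_sizes(max_num_seqs: int) -> list[int]:
    """Hypothetical helper: default CUDA graph capture sizes.

    Candidates are 1, 2, 4, then multiples of 8, truncated to the
    cap min(max_num_seqs * 2, 512) described in the docstring above.
    """
    max_graph_size = min(max_num_seqs * 2, 512)
    candidates = [1, 2, 4] + [8 * i for i in range(1, 1025)]
    return [size for size in candidates if size <= max_graph_size]

# With max_num_seqs = 256 the cap is 512, so the list runs
# [1, 2, 4, 8, 16, 24, ..., 512]; vLLM captures in descending order.
print(default_capture_sizes(256)[:6])  # [1, 2, 4, 8, 16, 24]
print(default_capture_sizes(256)[-1])  # 512
```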
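
The runtime bullets (pad up to the closest captured size, or skip cudagraph entirely) amount to a lookup table in the spirit of the `bs_to_padded_graph_size` mapping the old docstring mentioned. A minimal sketch — `build_padding_table` is a hypothetical name, not vLLM's implementation:

```python
def build_padding_table(capture_sizes: list[int]) -> list[int]:
    """Map each batch size b to the smallest captured size >= b."""
    sizes = sorted(capture_sizes)
    return [next(s for s in sizes if s >= b) for b in range(sizes[-1] + 1)]

table = build_padding_table([1, 2, 4, 8, 16])
assert table[3] == 4    # a batch of 3 is padded up to the captured size 4
assert table[9] == 16   # a batch of 9 is padded up to 16
# A batch of 17 exceeds the largest captured size (16), so it falls
# outside the table and would run eagerly, without a CUDA graph.
```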
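
Finally, the Note about user overrides can be exercised from the offline API. A hedged sketch, assuming a recent vLLM build in which `vllm.LLM` accepts a `compilation_config` mapping (verify against your installed version's signature):

```python
from vllm import LLM

# Explicit capture sizes bypass the default candidate-list logic above.
# The model name is only an example; substitute your own.
llm = LLM(
    model="facebook/opt-125m",
    compilation_config={"cudagraph_capture_sizes": [1, 2, 4, 8]},
)
```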