[Perf] Change default CUDAGraphMode from PIECEWISE to FULL_AND_PIECEWISE (#25444)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-09-23 15:29:26 -04:00 committed by GitHub
parent 63400259d0
commit 24fab45d96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 32 additions and 7 deletions

View File

@@ -509,8 +509,15 @@ class VllmConfig:
if self.compilation_config.cudagraph_mode is None: if self.compilation_config.cudagraph_mode is None:
if envs.VLLM_USE_V1 and self.compilation_config.level \ if envs.VLLM_USE_V1 and self.compilation_config.level \
== CompilationLevel.PIECEWISE: == CompilationLevel.PIECEWISE:
# default to full and piecewise for most models
self.compilation_config.cudagraph_mode = \ self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE CUDAGraphMode.FULL_AND_PIECEWISE
# pooling model does not support full cudagraphs
if self.model_config is not None and \
self.model_config.pooler_config is not None:
self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else: else:
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

View File

@@ -228,15 +228,14 @@ class CompilationConfig:
The mode of the cudagraph: The mode of the cudagraph:
- NONE, no cudagraph capture. - NONE, no cudagraph capture.
- PIECEWISE. (v1 default) - PIECEWISE.
- FULL. - FULL.
- FULL_DECODE_ONLY. - FULL_DECODE_ONLY.
- FULL_AND_PIECEWISE. - FULL_AND_PIECEWISE. (v1 default)
PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph
incompatible ops (i.e. some attention ops) outside the cudagraph incompatible ops (i.e. some attention ops) outside the cudagraph
for general flexibility. for general flexibility.
This is the default mode.
FULL mode: Capture full cudagraph for all batches. Can be good for small FULL mode: Capture full cudagraph for all batches. Can be good for small
models or workloads with small prompts; not supported by many backends. models or workloads with small prompts; not supported by many backends.
@@ -249,7 +248,7 @@ class CompilationConfig:
FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
piecewise cudagraph for prefill and mixed prefill-decode batches. piecewise cudagraph for prefill and mixed prefill-decode batches.
This is like the most performant mode for most models. This is the most performant mode for most models and is the default.
Currently, the cudagraph mode is only used for the v1 engine. Currently, the cudagraph mode is only used for the v1 engine.
Note that the cudagraph logic is generally orthogonal to the Note that the cudagraph logic is generally orthogonal to the

View File

@@ -2947,8 +2947,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# TODO(luka) better system for describing dummy batches # TODO(luka) better system for describing dummy batches
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
else: else:
# Make sure max_model_len is used at the graph capture time. seq_lens = max_query_len
seq_lens = self.max_model_len
self.seq_lens.np[:num_reqs] = seq_lens self.seq_lens.np[:num_reqs] = seq_lens
self.seq_lens.np[num_reqs:] = 0 self.seq_lens.np[num_reqs:] = 0
self.seq_lens.copy_to_gpu() self.seq_lens.copy_to_gpu()
@@ -3541,6 +3540,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
CUDAGraphMode.FULL_DECODE_ONLY CUDAGraphMode.FULL_DECODE_ONLY
logger.warning(msg) logger.warning(msg)
# check that if we are doing decode full-cudagraphs it is supported
if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and min_cg_support == AttentionCGSupport.NEVER):
msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_builder_name} backend (support: "
f"{min_cg_support})")
if (self.compilation_config.level == CompilationLevel.PIECEWISE and
(self.compilation_config.splitting_ops_contain_attention()
or self.compilation_config.use_inductor_graph_partition)):
msg += "; setting cudagraph_mode=PIECEWISE because "\
"attention is compiled piecewise"
cudagraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else:
msg += "; setting cudagraph_mode=NONE because "\
"attention is not compiled piecewise"
cudagraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.NONE
logger.warning(msg)
# check that if we are doing spec-decode + decode full-cudagraphs it is # check that if we are doing spec-decode + decode full-cudagraphs it is
# supported # supported
if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL