From dbdea93f46860d5c5190557259ff06d48efcbd83 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Fri, 26 Sep 2025 15:29:56 -0700
Subject: [PATCH] Reduce the Cuda Graph memory footprint when running with DBO
 (#25779)

Signed-off-by: Sage Moore
Signed-off-by: yewentao256
---
 vllm/v1/worker/gpu_model_runner.py   | 48 ++++++++++++----------------
 vllm/v1/worker/gpu_ubatch_wrapper.py | 12 +++++++
 2 files changed, 32 insertions(+), 28 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 2fac708905d07..4fd4f9128c6eb 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3477,8 +3477,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # We skip EPLB here since we don't want to record dummy metrics
         for num_tokens in compilation_cases:
             # We currently only capture ubatched graphs when its a FULL
-            # cudagraph and for uniform decode batches.
-            capture_ubatched_graph = self.parallel_config.enable_dbo \
+            # cudagraph, a uniform decode batch, and the number of tokens
+            # is above the threshold. Otherwise we just capture a non-ubatched
+            # version of the graph
+            allow_microbatching = self.parallel_config.enable_dbo \
                 and cudagraph_runtime_mode == CUDAGraphMode.FULL \
                 and uniform_decode \
                 and check_ubatch_thresholds(
@@ -3487,37 +3489,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     uniform_decode=uniform_decode,
                 )
 
-            # Currently we capture both microbatched and non-microbatched
-            # graphs when capture_ubatched_graph is True, this is because
-            # occasionally we will be forced out of microbatching due to other
-            # DP ranks not microbatching (usually caused by an empty second
-            # microbatch; once we resolve this, we can remove the
-            # non-microbatched graph capture).
-            allow_microbatching_options = [True, False] if \
-                capture_ubatched_graph else [False]
-            for allow_microbatching in allow_microbatching_options:
-                for _ in range(
-                        self.compilation_config.cudagraph_num_of_warmups):
-                    # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
-                    # But be careful, warm up with `NONE`is orthogonal to
-                    # if we want to warm up attention or not. This is
-                    # different from the case where `FULL` implies capture
-                    # attention while `PIECEWISE` implies no attention.
-                    force_attention = (
-                        cudagraph_runtime_mode == CUDAGraphMode.FULL)
-                    self._dummy_run(num_tokens,
-                                    cudagraph_runtime_mode=CUDAGraphMode.NONE,
-                                    force_attention=force_attention,
-                                    uniform_decode=uniform_decode,
-                                    allow_microbatching=allow_microbatching,
-                                    skip_eplb=True,
-                                    remove_lora=False)
+            for _ in range(self.compilation_config.cudagraph_num_of_warmups):
+                # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
+                # But be careful, warm up with `NONE`is orthogonal to
+                # if we want to warm up attention or not. This is
+                # different from the case where `FULL` implies capture
+                # attention while `PIECEWISE` implies no attention.
+                force_attention = (
+                    cudagraph_runtime_mode == CUDAGraphMode.FULL)
                 self._dummy_run(num_tokens,
-                                cudagraph_runtime_mode=cudagraph_runtime_mode,
+                                cudagraph_runtime_mode=CUDAGraphMode.NONE,
+                                force_attention=force_attention,
                                 uniform_decode=uniform_decode,
                                 allow_microbatching=allow_microbatching,
                                 skip_eplb=True,
                                 remove_lora=False)
+            self._dummy_run(num_tokens,
+                            cudagraph_runtime_mode=cudagraph_runtime_mode,
+                            uniform_decode=uniform_decode,
+                            allow_microbatching=allow_microbatching,
+                            skip_eplb=True,
+                            remove_lora=False)
         self.maybe_remove_all_loras(self.lora_config)
 
     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py
index 5e4c1d32ab6cd..39be8c74102e4 100644
--- a/vllm/v1/worker/gpu_ubatch_wrapper.py
+++ b/vllm/v1/worker/gpu_ubatch_wrapper.py
@@ -330,6 +330,18 @@ class UBatchWrapper:
 
         # If there's no ubatching, just run the runnable object
         if ubatch_slices is None:
+
+            # This is to account for the case where ubatching was aborted.
+            # When we capture full graphs we only capture one graph per shape,
+            # meaning that if we have a ubatched cudagraph for the current
+            # num_tokens, we don't have a non-ubatched one. Without this
+            # check, the cudagraph wrapper will try to capture a cudagraph
+            # for this shape during a normal run.
+            if cudagraph_runtime_mode is CUDAGraphMode.FULL:
+                assert batch_descriptor is not None
+                if batch_descriptor.num_tokens in self.cudagraphs:
+                    cudagraph_runtime_mode = CUDAGraphMode.NONE
+
             if cudagraph_runtime_mode in (CUDAGraphMode.NONE,
                                           CUDAGraphMode.PIECEWISE):
                 return self.runnable(*args, **kwargs)
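
Reviewer note (not part of the patch): the gpu_ubatch_wrapper.py hunk handles the case
where DBO microbatching is aborted at runtime for a shape whose only captured FULL
cudagraph is the ubatched one, by falling back to eager execution instead of letting the
cudagraph wrapper capture a second graph for the same num_tokens. A minimal, standalone
sketch of that decision follows; CUDAGraphMode here is a stand-in enum and
resolve_runtime_mode is a hypothetical helper, not vLLM's actual API.

    from enum import Enum

    class CUDAGraphMode(Enum):
        # Stand-in for vLLM's CUDAGraphMode; only the members used here.
        NONE = 0
        PIECEWISE = 1
        FULL = 2

    def resolve_runtime_mode(requested_mode, num_tokens, ubatched_graphs):
        # Mirrors the new check in UBatchWrapper.__call__: when ubatching was
        # aborted but a ubatched FULL graph already exists for this token
        # count, run eagerly rather than trying to capture a second,
        # non-ubatched graph for the same shape.
        if requested_mode is CUDAGraphMode.FULL and num_tokens in ubatched_graphs:
            return CUDAGraphMode.NONE
        return requested_mode

    # A ubatched graph was captured for 256 tokens, so an aborted-ubatching run at
    # 256 tokens drops to eager; 128 tokens has no ubatched graph and keeps FULL.
    assert resolve_runtime_mode(CUDAGraphMode.FULL, 256, {256: object()}) is CUDAGraphMode.NONE
    assert resolve_runtime_mode(CUDAGraphMode.FULL, 128, {256: object()}) is CUDAGraphMode.FULL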