Reduce the Cuda Graph memory footprint when running with DBO (#25779)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
Sage Moore 2025-09-26 15:29:56 -07:00 committed by simon-mo
parent b761df963c
commit bb79c4da2f
2 changed files with 32 additions and 28 deletions


@@ -3468,8 +3468,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # We skip EPLB here since we don't want to record dummy metrics
         for num_tokens in compilation_cases:
             # We currently only capture ubatched graphs when it's a FULL
-            # cudagraph and for uniform decode batches.
-            capture_ubatched_graph = self.parallel_config.enable_dbo \
+            # cudagraph, a uniform decode batch, and the number of tokens
+            # is above the threshold. Otherwise we just capture a non-ubatched
+            # version of the graph.
+            allow_microbatching = self.parallel_config.enable_dbo \
                 and cudagraph_runtime_mode == CUDAGraphMode.FULL \
                 and uniform_decode \
                 and check_ubatch_thresholds(
@@ -3478,37 +3480,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     uniform_decode=uniform_decode,
                 )
-            # Currently we capture both microbatched and non-microbatched
-            # graphs when capture_ubatched_graph is True; this is because
-            # occasionally we will be forced out of microbatching due to other
-            # DP ranks not microbatching (usually caused by an empty second
-            # microbatch; once we resolve this, we can remove the
-            # non-microbatched graph capture).
-            allow_microbatching_options = [True, False] if \
-                capture_ubatched_graph else [False]
-            for allow_microbatching in allow_microbatching_options:
-                for _ in range(
-                        self.compilation_config.cudagraph_num_of_warmups):
-                    # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
-                    # But be careful: warming up with `NONE` is orthogonal to
-                    # whether we want to warm up attention or not. This is
-                    # different from the case where `FULL` implies capturing
-                    # attention while `PIECEWISE` implies no attention.
-                    force_attention = (
-                        cudagraph_runtime_mode == CUDAGraphMode.FULL)
-                    self._dummy_run(num_tokens,
-                                    cudagraph_runtime_mode=CUDAGraphMode.NONE,
-                                    force_attention=force_attention,
-                                    uniform_decode=uniform_decode,
-                                    allow_microbatching=allow_microbatching,
-                                    skip_eplb=True,
-                                    remove_lora=False)
-                self._dummy_run(num_tokens,
-                                cudagraph_runtime_mode=cudagraph_runtime_mode,
-                                uniform_decode=uniform_decode,
-                                allow_microbatching=allow_microbatching,
-                                skip_eplb=True,
-                                remove_lora=False)
+            for _ in range(self.compilation_config.cudagraph_num_of_warmups):
+                # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
+                # But be careful: warming up with `NONE` is orthogonal to
+                # whether we want to warm up attention or not. This is
+                # different from the case where `FULL` implies capturing
+                # attention while `PIECEWISE` implies no attention.
+                force_attention = (
+                    cudagraph_runtime_mode == CUDAGraphMode.FULL)
+                self._dummy_run(num_tokens,
+                                cudagraph_runtime_mode=CUDAGraphMode.NONE,
+                                force_attention=force_attention,
+                                uniform_decode=uniform_decode,
+                                allow_microbatching=allow_microbatching,
+                                skip_eplb=True,
+                                remove_lora=False)
+            self._dummy_run(num_tokens,
+                            cudagraph_runtime_mode=cudagraph_runtime_mode,
+                            uniform_decode=uniform_decode,
+                            allow_microbatching=allow_microbatching,
+                            skip_eplb=True,
+                            remove_lora=False)
 
         self.maybe_remove_all_loras(self.lora_config)
 
     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
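
The net effect of the hunk above: a DBO-eligible shape used to be captured twice (once microbatched, once not), and now every shape gets exactly one graph, which is where the memory saving comes from. A minimal sketch of the policy change, using a toy bookkeeping class; GraphCaptureBudget, dbo_eligible, and the token threshold are illustrative names, not vLLM's API:

from dataclasses import dataclass, field

@dataclass
class GraphCaptureBudget:
    # Toy model: maps (num_tokens, microbatched?) -> captured graph handle.
    captured: dict[tuple[int, bool], str] = field(default_factory=dict)

    def capture_old(self, num_tokens: int, dbo_eligible: bool) -> None:
        # Old policy: DBO-eligible shapes were captured BOTH ways.
        for microbatched in ([True, False] if dbo_eligible else [False]):
            self.captured[(num_tokens, microbatched)] = "graph"

    def capture_new(self, num_tokens: int, dbo_eligible: bool) -> None:
        # New policy: exactly one graph per shape, microbatched iff eligible.
        self.captured[(num_tokens, dbo_eligible)] = "graph"

old, new = GraphCaptureBudget(), GraphCaptureBudget()
for n in (8, 16, 32):  # pretend shapes of 16+ tokens pass the ubatch threshold
    old.capture_old(n, dbo_eligible=n >= 16)
    new.capture_new(n, dbo_eligible=n >= 16)
print(len(old.captured), len(new.captured))  # 5 vs. 3 graphs captured

With three shapes, two of which pass the microbatching threshold, the old policy captures five graphs while the new one captures three.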


@@ -330,6 +330,18 @@ class UBatchWrapper:
         # If there's no ubatching, just run the runnable object
         if ubatch_slices is None:
+            # This is to account for the case where ubatching was aborted.
+            # When we capture full graphs we only capture one graph per shape,
+            # meaning that if we have a ubatched cudagraph for the current
+            # num_tokens, we don't have a non-ubatched one. Without this
+            # check, the cudagraph wrapper will try to capture a cudagraph
+            # for this shape during a normal run.
+            if cudagraph_runtime_mode is CUDAGraphMode.FULL:
+                assert batch_descriptor is not None
+                if batch_descriptor.num_tokens in self.cudagraphs:
+                    cudagraph_runtime_mode = CUDAGraphMode.NONE
+
             if cudagraph_runtime_mode in (CUDAGraphMode.NONE,
                                           CUDAGraphMode.PIECEWISE):
                 return self.runnable(*args, **kwargs)
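
This hunk is the runtime counterpart of the capture change: when ubatching is aborted for a step but the only FULL graph captured for this num_tokens is the ubatched one, the wrapper downgrades to eager execution rather than letting the cudagraph wrapper capture a second graph for the same shape. A minimal sketch of that fallback, assuming captured ubatched shapes are tracked in a plain set; resolve_runtime_mode and ubatched_graph_shapes are illustrative names, and the CUDAGraphMode enum here is a stand-in for vLLM's:

from enum import Enum

class CUDAGraphMode(Enum):  # stand-in for vLLM's enum of the same name
    NONE = 0
    PIECEWISE = 1
    FULL = 2

def resolve_runtime_mode(mode: CUDAGraphMode, num_tokens: int,
                         ubatched_graph_shapes: set[int]) -> CUDAGraphMode:
    # Ubatching was aborted: if only a ubatched FULL graph exists for this
    # shape, fall back to NONE so no second graph gets captured.
    if mode is CUDAGraphMode.FULL and num_tokens in ubatched_graph_shapes:
        return CUDAGraphMode.NONE
    return mode

# Shape 16 has a ubatched graph, so a FULL run degrades to eager replay-free
# execution; shape 8 has no ubatched graph and keeps its requested mode.
assert resolve_runtime_mode(CUDAGraphMode.FULL, 16, {16}) is CUDAGraphMode.NONE
assert resolve_runtime_mode(CUDAGraphMode.FULL, 8, {16}) is CUDAGraphMode.FULL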