Refactor: Move CUDA graph dispatch logic earlier (#27382)

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Yizhou 2025-11-23 05:10:31 +08:00 committed by GitHub
parent 7df331c66b
commit df78aeef08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3740,6 +3740,31 @@ class GPUModelRunner(
dp_rank = self.parallel_config.data_parallel_rank
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
# filter out the valid batch descriptor
_cg_mode, batch_descriptor = (
self.cudagraph_dispatcher.dispatch(
BatchDescriptor(
num_tokens=num_tokens_after_padding,
uniform_decode=uniform_decode,
has_lora=activate_lora and self.lora_config is not None,
)
)
if not is_profile
else (CUDAGraphMode.NONE, None)
)
if cudagraph_runtime_mode is not None:
# we allow forcing NONE when the dispatcher disagrees to support
# warm ups for cudagraph capture
assert (
cudagraph_runtime_mode == CUDAGraphMode.NONE
or cudagraph_runtime_mode == _cg_mode
), (
f"Cudagraph runtime mode mismatch at dummy_run. "
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
)
else:
cudagraph_runtime_mode = _cg_mode
attn_metadata: PerLayerAttnMetadata | None = None
# If force_attention is True, we always capture attention. Otherwise,
@ -3814,31 +3839,6 @@ class GPUModelRunner(
num_tokens_after_padding, None, False
)
# filter out the valid batch descriptor
_cg_mode, batch_descriptor = (
self.cudagraph_dispatcher.dispatch(
BatchDescriptor(
num_tokens=num_tokens_after_padding,
uniform_decode=uniform_decode,
has_lora=activate_lora and self.lora_config is not None,
)
)
if not is_profile
else (CUDAGraphMode.NONE, None)
)
if cudagraph_runtime_mode is not None:
# we allow forcing NONE when the dispatcher disagrees to support
# warm ups for cudagraph capture
assert (
cudagraph_runtime_mode == CUDAGraphMode.NONE
or cudagraph_runtime_mode == _cg_mode
), (
f"Cudagraph runtime mode mismatch at dummy_run. "
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
)
else:
cudagraph_runtime_mode = _cg_mode
if ubatch_slices is not None:
# Adjust values to reflect a single ubatch.
# TODO(sage,lucas): this is cruft that should be addressed in