[BugFix] Fix DP/EP hang (#25906)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
Lucas Wilkinson 2025-09-30 00:18:59 -04:00 committed by simon-mo
parent 9471879bd4
commit 03df0fb5d2

@@ -2989,13 +2989,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # We currently only microbatch if the number of tokens is
         # over a certain threshold.
         if self.parallel_config.enable_dbo and allow_microbatching:
-            ubatch_slices, num_tokens_after_padding = ubatch_split(
+            ubatch_slices, ubatch_num_tokens_after_padding = ubatch_split(
                 num_scheduled_tokens,
                 total_num_scheduled_tokens,
                 total_num_scheduled_tokens,
                 uniform_decode=uniform_decode,
                 vllm_config=self.vllm_config,
             )
+            # Currently when DBO is enabled `ubatch_split` returns
+            # the num_tokens_after_padding for a single ubatch, but we have 2
+            # TODO(sage,lucas): this is cruft that should be addressed in the
+            # padding refactor.
+            if ubatch_num_tokens_after_padding is not None:
+                num_tokens_after_padding = ubatch_num_tokens_after_padding * 2
 
         # If we failed to microbatch, currently need to resynchronize
         # TODO(lucas,sage): we should be able to avoid this second sync by
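
The first hunk is the core of the fix: when dual batch overlap (DBO) is enabled, `ubatch_split` reports the padded token count for a single microbatch, while the caller tracks the padded size of the whole batch, so the returned value must be doubled before use. Below is a minimal sketch of why the two quantities differ; `round_up` and `split_into_ubatches` are illustrative stand-ins, not vLLM's actual `ubatch_split`.

```python
# Hypothetical sketch, not vLLM's implementation: split a batch into two
# microbatches, padding each one independently to a capture-friendly size.
from typing import Optional, Tuple


def round_up(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple (e.g. a cudagraph capture size).
    return ((x + multiple - 1) // multiple) * multiple


def split_into_ubatches(total_tokens: int,
                        pad_multiple: int) -> Tuple[int, Optional[int]]:
    """Return (padded size of the whole batch, per-ubatch padded size)."""
    if total_tokens < 2:
        # Too small to microbatch: pad the batch as a single unit.
        return round_up(total_tokens, pad_multiple), None
    # Each of the 2 ubatches is padded on its own...
    per_ubatch = round_up((total_tokens + 1) // 2, pad_multiple)
    # ...so the full batch occupies 2x the per-ubatch padded size, which is
    # the quantity the caller must track (hence the explicit `* 2` above).
    return per_ubatch * 2, per_ubatch


full, per_ubatch = split_into_ubatches(total_tokens=300, pad_multiple=128)
assert per_ubatch == 256 and full == 512
assert full != round_up(300, 128)  # padding the whole batch once gives 384
```

Because each half is padded independently, doubling the per-ubatch size does not in general equal padding the whole batch once, so the caller cannot reuse its old `num_tokens_after_padding` and must write back the doubled value.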
@@ -3112,8 +3118,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # filter out the valid batch descriptor
             _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
-                BatchDescriptor(num_tokens=num_tokens,
-                                uniform_decode=uniform_decode))
+                BatchDescriptor(num_tokens=num_tokens_after_padding,
+                                uniform_decode=uniform_decode)) \
                 if not is_profile else (CUDAGraphMode.NONE, None)
             if cudagraph_runtime_mode is not None:
                 # we allow forcing NONE when the dispatcher disagrees to support
                 # warm ups for cudagraph capture
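
The second hunk keys the cudagraph dispatch on `num_tokens_after_padding` instead of the raw `num_tokens`. Graphs are captured for padded sizes, so a lookup on the raw count can miss even when a matching graph exists; under data/expert parallelism, ranks with different raw counts but the same padded size would then take different execution paths and hang in the subsequent collectives. A toy dispatcher showing the miss; the class and its keying are assumptions for illustration, not vLLM's `CudagraphDispatcher` API:

```python
# Toy descriptor-keyed dispatch (hypothetical, for illustration only).
from dataclasses import dataclass
from enum import Enum


class CUDAGraphMode(Enum):
    NONE = 0
    FULL = 1


@dataclass(frozen=True)
class BatchDescriptor:
    num_tokens: int
    uniform_decode: bool


class ToyDispatcher:
    def __init__(self, capture_sizes):
        # Graphs exist only for the padded capture sizes.
        self._keys = {BatchDescriptor(n, True) for n in capture_sizes}

    def dispatch(self, desc: BatchDescriptor):
        if desc in self._keys:
            return CUDAGraphMode.FULL, desc
        return CUDAGraphMode.NONE, None  # miss: fall back to eager mode


d = ToyDispatcher(capture_sizes=[256, 512])
# Two DP ranks with 300 and 290 raw tokens both pad to 512. Keying on the
# raw count misses; keying on the shared padded count hits on every rank.
assert d.dispatch(BatchDescriptor(300, True))[0] is CUDAGraphMode.NONE
assert d.dispatch(BatchDescriptor(512, True))[0] is CUDAGraphMode.FULL
```

With the fix, every rank performs the lookup with the same padded size, so all ranks agree on the runtime mode and stay in lockstep.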
@@ -3125,7 +3132,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 cudagraph_runtime_mode = _cg_mode
 
             if ubatch_slices is not None:
-                num_tokens = num_tokens // 2
+                # Adjust values to reflect a single ubatch.
+                # TODO(sage,lucas): this is cruft that should be addressed in
+                # the padding refactor.
+                num_tokens_after_padding = ubatch_slices[0].num_tokens
+                if num_tokens_across_dp is not None:
+                    num_tokens_across_dp[:] = num_tokens_after_padding
+
         with self.maybe_randomize_inputs(input_ids), set_forward_context(
                 attn_metadata,
                 self.vllm_config,
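
The last hunk makes the per-forward bookkeeping describe a single ubatch once microbatching is committed: `num_tokens_after_padding` becomes the first slice's token count, and `num_tokens_across_dp` is overwritten in place so every DP rank enters the forward pass with the same per-ubatch value. A sketch of that in-place update, assuming `num_tokens_across_dp` is a 1-D tensor with one padded token count per DP rank; the `dp_size` and values are made up for illustration:

```python
import torch

dp_size = 4
# One padded token count per DP rank, agreed on before microbatching.
num_tokens_across_dp = torch.full((dp_size,), 512, dtype=torch.int32)

per_ubatch_tokens = 256  # e.g. ubatch_slices[0].num_tokens
# Slice-assign rather than rebind: other code may already hold a reference
# to this tensor, and every entry must reflect the single-ubatch size.
num_tokens_across_dp[:] = per_ubatch_tokens
assert torch.equal(num_tokens_across_dp,
                   torch.full((dp_size,), 256, dtype=torch.int32))
```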