mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:15:26 +08:00
[BugFix] Fix DP/EP hang (#25906)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
parent
9471879bd4
commit
03df0fb5d2
@ -2989,13 +2989,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# We currently only microbatch if the number of tokens is
|
||||
# over a certain threshold.
|
||||
if self.parallel_config.enable_dbo and allow_microbatching:
|
||||
ubatch_slices, num_tokens_after_padding = ubatch_split(
|
||||
ubatch_slices, ubatch_num_tokens_after_padding = ubatch_split(
|
||||
num_scheduled_tokens,
|
||||
total_num_scheduled_tokens,
|
||||
total_num_scheduled_tokens,
|
||||
uniform_decode=uniform_decode,
|
||||
vllm_config=self.vllm_config,
|
||||
)
|
||||
# Currently when DBO is enabled `ubatch_split` returns
|
||||
# the num_tokens_after_padding for a single ubatch, but we have 2
|
||||
# TODO(sage,lucas): this is cruft that should be addressed in the
|
||||
# padding refactor.
|
||||
if ubatch_num_tokens_after_padding is not None:
|
||||
num_tokens_after_padding = ubatch_num_tokens_after_padding * 2
|
||||
|
||||
# If we failed to microbatch, currently need to resynchronize
|
||||
# TODO(lucas,sage): we should be able to avoid this second sync by
|
||||
@ -3112,8 +3118,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
# filter out the valid batch descriptor
|
||||
_cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
|
||||
BatchDescriptor(num_tokens=num_tokens,
|
||||
uniform_decode=uniform_decode))
|
||||
BatchDescriptor(num_tokens=num_tokens_after_padding,
|
||||
uniform_decode=uniform_decode)) \
|
||||
if not is_profile else (CUDAGraphMode.NONE, None)
|
||||
if cudagraph_runtime_mode is not None:
|
||||
# we allow forcing NONE when the dispatcher disagrees to support
|
||||
# warm ups for cudagraph capture
|
||||
@ -3125,7 +3132,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
cudagraph_runtime_mode = _cg_mode
|
||||
|
||||
if ubatch_slices is not None:
|
||||
num_tokens = num_tokens // 2
|
||||
# Adjust values to reflect a single ubatch.
|
||||
# TODO(sage,lucas): this is cruft that should be addressed in
|
||||
# the padding refactor.
|
||||
num_tokens_after_padding = ubatch_slices[0].num_tokens
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[:] = num_tokens_after_padding
|
||||
|
||||
with self.maybe_randomize_inputs(input_ids), set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user