[BugFix] Fix assert batch_descriptor.num_tokens == num_tokens_padded (#30173)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Authored by Lucas Wilkinson on 2025-12-09 10:36:12 -05:00; committed by GitHub
parent 5dcd593baf
commit 56037dfa2f
6 changed files with 65 additions and 33 deletions


@@ -161,10 +161,10 @@ class TestCudagraphDispatcher:
         assert rt_mode == CUDAGraphMode.NONE
         assert key == BatchDescriptor(num_tokens=15)

-        # 4. Cascade attention should have a fall back mode
+        # 4. disable_full should have a fall back mode (e.g., cascade attention)
         desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
         rt_mode, key = dispatcher.dispatch(
-            num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
+            num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
         )
         if "PIECEWISE" in cudagraph_mode_str:  # string contains check
             assert rt_mode == CUDAGraphMode.PIECEWISE


@@ -292,7 +292,7 @@ def set_forward_context(
     if num_tokens_across_dp is None:
         assert ubatch_slices is None
         assert num_tokens is not None
-        _, num_tokens_across_dp = coordinate_batch_across_dp(
+        _, num_tokens_across_dp, _ = coordinate_batch_across_dp(
             num_tokens_unpadded=num_tokens,
             parallel_config=vllm_config.parallel_config,
             allow_microbatching=False,


@@ -145,7 +145,7 @@ class CudagraphDispatcher:
         num_tokens: int,
         uniform_decode: bool,
         has_lora: bool,
-        use_cascade_attn: bool = False,
+        disable_full: bool = False,
     ) -> tuple[CUDAGraphMode, BatchDescriptor]:
         """
         Given conditions (e.g., batch descriptor and if using cascade attention),
@@ -165,7 +165,7 @@
         )
         relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()

-        if not use_cascade_attn:
+        if not disable_full:
             # check if key exists for full cudagraph
             if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
                 return CUDAGraphMode.FULL, batch_desc
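The rename is more than cosmetic: any condition that invalidates full-graph replay can now push the dispatcher past the FULL keys, with cascade attention being just one such condition. A minimal standalone sketch of the lookup order, using hypothetical `Mode`/`Desc`/`MiniDispatcher` stand-ins rather than vLLM's actual classes:

```python
from dataclasses import dataclass
from enum import IntEnum


class Mode(IntEnum):
    # Mirrors CUDAGraphMode ordering: NONE < PIECEWISE < FULL
    NONE = 0
    PIECEWISE = 1
    FULL = 2


@dataclass(frozen=True)
class Desc:
    # Stand-in for BatchDescriptor; frozen so it can be used as a set/dict key.
    num_tokens: int
    uniform: bool = False


class MiniDispatcher:
    def __init__(self, full_keys: set[Desc], piecewise_keys: set[Desc]):
        self.keys = {Mode.FULL: full_keys, Mode.PIECEWISE: piecewise_keys}

    def dispatch(
        self, num_tokens: int, uniform_decode: bool, disable_full: bool = False
    ) -> tuple[Mode, Desc]:
        desc = Desc(num_tokens, uniform_decode)
        relaxed = Desc(num_tokens, False)  # relaxed mixed-batch key
        if not disable_full:  # the check this hunk renames
            for d in (desc, relaxed):
                if d in self.keys[Mode.FULL]:
                    return Mode.FULL, d
        if relaxed in self.keys[Mode.PIECEWISE]:
            return Mode.PIECEWISE, relaxed
        return Mode.NONE, desc  # eager fallback


d = MiniDispatcher(full_keys={Desc(8)}, piecewise_keys={Desc(8)})
assert d.dispatch(8, uniform_decode=False)[0] is Mode.FULL
assert d.dispatch(8, uniform_decode=False, disable_full=True)[0] is Mode.PIECEWISE
```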


@@ -1258,7 +1258,7 @@ class EagleProposer:
         num_tokens_padded: int,
     ) -> tuple[int, torch.Tensor]:
         # TODO(Flechman): support DBO ubatching
-        should_ubatch, num_toks_across_dp = coordinate_batch_across_dp(
+        should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
             num_tokens_unpadded=num_tokens_unpadded,
             parallel_config=self.vllm_config.parallel_config,
             allow_microbatching=False,


@@ -40,16 +40,18 @@ def _run_ar(
     should_dp_pad: bool,
     orig_num_tokens_per_ubatch: int,
     padded_num_tokens_per_ubatch: int,
+    cudagraph_mode: int,
     parallel_config: ParallelConfig,
 ) -> torch.Tensor:
     dp_size = parallel_config.data_parallel_size
     dp_rank = parallel_config.data_parallel_rank
     device, group = _get_device_and_group(parallel_config)
-    tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
+    tensor = torch.zeros(5, dp_size, device=device, dtype=torch.int32)
     tensor[0][dp_rank] = orig_num_tokens_per_ubatch
     tensor[1][dp_rank] = padded_num_tokens_per_ubatch
     tensor[2][dp_rank] = 1 if should_ubatch else 0
     tensor[3][dp_rank] = 1 if should_dp_pad else 0
+    tensor[4][dp_rank] = cudagraph_mode
     dist.all_reduce(tensor, group=group)
     return tensor
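The packing trick above is worth spelling out: every rank writes its values into its own column of a zeroed `5 x dp_size` tensor, so a sum all-reduce behaves like an all-gather. A single-process simulation with hypothetical values (summing the per-rank tensors stands in for `dist.all_reduce`):

```python
import torch

dp_size = 2
ROWS = 5  # orig_tokens, padded_tokens, should_ubatch, should_dp_pad, cudagraph_mode


def rank_contribution(dp_rank, orig, padded, ubatch, dp_pad, cg_mode):
    t = torch.zeros(ROWS, dp_size, dtype=torch.int32)
    t[0][dp_rank] = orig
    t[1][dp_rank] = padded
    t[2][dp_rank] = 1 if ubatch else 0
    t[3][dp_rank] = 1 if dp_pad else 0
    t[4][dp_rank] = cg_mode  # the row this commit adds
    return t


# Hypothetical ranks: rank 0 dispatched FULL (2), rank 1 PIECEWISE (1).
reduced = rank_contribution(0, 96, 128, False, True, 2) + rank_contribution(
    1, 48, 64, False, True, 1
)
# Every rank now sees every other rank's values:
# tensor([[ 96,  48],
#         [128,  64],
#         [  0,   0],
#         [  1,   1],
#         [  2,   1]], dtype=torch.int32)
assert int(reduced[4, :].min().item()) == 1  # synced mode: min across ranks
```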
@@ -89,13 +91,23 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch.Tensor:
     return num_tokens_across_dp.cpu()


+def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
+    """
+    Synchronize cudagraph_mode across DP ranks by taking the minimum.
+    If any rank has NONE (0), all ranks use NONE.
+    This ensures all ranks send consistent values (all padded or all unpadded).
+    """
+    return int(tensor[4, :].min().item())
+
+
 def _synchronize_dp_ranks(
     num_tokens_unpadded: int,
     num_tokens_padded: int,
     should_attempt_ubatching: bool,
     should_attempt_dp_padding: bool,
+    cudagraph_mode: int,
     parallel_config: ParallelConfig,
-) -> tuple[bool, torch.Tensor | None]:
+) -> tuple[bool, torch.Tensor | None, int]:
     """
     1. Decides if each DP rank is going to microbatch. Either all ranks
        run with microbatching or none of them do.
@@ -104,10 +116,13 @@ def _synchronize_dp_ranks(
        When running microbatched or if should_attempt_dp_padding is True, all
        ranks will be padded out so that they run with the same number of tokens
+    3. Synchronizes cudagraph_mode across ranks by taking the minimum.

     Returns: tuple[
         should_ubatch: Are all DP ranks going to microbatch
         num_tokens_after_padding: A tensor containing the total number of
             tokens per-microbatch for each DP rank including any DP padding.
+        synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
     ]
     """
@@ -121,6 +136,7 @@ def _synchronize_dp_ranks(
         should_dp_pad=should_attempt_dp_padding,
         orig_num_tokens_per_ubatch=num_tokens_unpadded,
         padded_num_tokens_per_ubatch=num_tokens_padded,
+        cudagraph_mode=cudagraph_mode,
         parallel_config=parallel_config,
     )
@@ -148,7 +164,10 @@ def _synchronize_dp_ranks(
             should_dp_pad,
         )

-    return should_ubatch, num_tokens_after_padding
+    # Synchronize cudagraph_mode across ranks (take min)
+    synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
+
+    return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode


 def coordinate_batch_across_dp(
@@ -159,7 +178,8 @@ def coordinate_batch_across_dp(
     num_tokens_padded: int | None = None,
     uniform_decode: bool | None = None,
     num_scheduled_tokens_per_request: np.ndarray | None = None,
-) -> tuple[bool, torch.Tensor | None]:
+    cudagraph_mode: int = 0,
+) -> tuple[bool, torch.Tensor | None, int]:
     """
     Coordinates amongst all DP ranks to determine if and how the full batch
     should be split into microbatches.
@@ -175,6 +195,7 @@ def coordinate_batch_across_dp(
             only contains single token decodes
         num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
             number of tokens per request.
+        cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL)

     Returns: tuple[
         ubatch_slices: if this is set then all DP ranks have agreed to
@@ -183,12 +204,13 @@ def coordinate_batch_across_dp(
             tokens per-microbatch for each DP rank including padding. Will be
             padded up to the max value across all DP ranks when allow_dp_padding
             is True.
+        synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
     ]
     """

     if parallel_config.data_parallel_size == 1:
         # Early exit.
-        return False, None
+        return False, None, cudagraph_mode

     # If the caller has explicitly enabled microbatching.
     should_attempt_ubatching = False
@@ -204,12 +226,15 @@ def coordinate_batch_across_dp(
     if num_tokens_padded is None:
         num_tokens_padded = num_tokens_unpadded

-    (should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks(
-        num_tokens_unpadded,
-        num_tokens_padded,
-        should_attempt_ubatching,
-        allow_dp_padding,
-        parallel_config,
+    (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode) = (
+        _synchronize_dp_ranks(
+            num_tokens_unpadded,
+            num_tokens_padded,
+            should_attempt_ubatching,
+            allow_dp_padding,
+            cudagraph_mode,
+            parallel_config,
+        )
     )

-    return (should_ubatch, num_tokens_after_padding)
+    return (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode)


@@ -2788,17 +2788,19 @@ class GPUModelRunner(
         )

         dispatch_cudagraph = (
-            lambda num_tokens: self.cudagraph_dispatcher.dispatch(
+            lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch(
                 num_tokens=num_tokens,
                 has_lora=has_lora,
-                use_cascade_attn=use_cascade_attn,
                 uniform_decode=uniform_decode,
+                disable_full=disable_full,
             )
             if not force_eager
             else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
         )

-        cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded)
+        cudagraph_mode, batch_descriptor = dispatch_cudagraph(
+            num_tokens_padded, use_cascade_attn
+        )
         num_tokens_padded = batch_descriptor.num_tokens

         # Extra coordination when running data-parallel since we need to coordinate
@@ -2813,23 +2815,28 @@ class GPUModelRunner(
                 self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
             )

-            should_ubatch, num_tokens_across_dp = coordinate_batch_across_dp(
-                num_tokens_unpadded=num_tokens,
-                parallel_config=self.parallel_config,
-                allow_microbatching=allow_microbatching,
-                allow_dp_padding=allow_dp_padding,
-                num_tokens_padded=num_tokens_padded,
-                uniform_decode=uniform_decode,
-                num_scheduled_tokens_per_request=num_scheduled_tokens_np,
+            should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
+                coordinate_batch_across_dp(
+                    num_tokens_unpadded=num_tokens,
+                    parallel_config=self.parallel_config,
+                    allow_microbatching=allow_microbatching,
+                    allow_dp_padding=allow_dp_padding,
+                    num_tokens_padded=num_tokens_padded,
+                    uniform_decode=uniform_decode,
+                    num_scheduled_tokens_per_request=num_scheduled_tokens_np,
+                    cudagraph_mode=cudagraph_mode.value,
+                )
             )

-            # Extract DP padding if there is any
+            # Extract DP-synced values
             if num_tokens_across_dp is not None:
                 dp_rank = self.parallel_config.data_parallel_rank
                 num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
-                # Re-dispatch with DP padding
-                cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded)
+                # Re-dispatch with DP padding so we have the correct batch_descriptor
+                cudagraph_mode, batch_descriptor = dispatch_cudagraph(
+                    num_tokens_padded,
+                    disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value,
+                )

                 # Assert to make sure the agreed upon token count is correct otherwise
                 # num_tokens_across_dp will no longer be valid
                 assert batch_descriptor.num_tokens == num_tokens_padded
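Taken together, the hunks implement a dispatch / sync / re-dispatch handshake, and the assert in the title is exactly what used to fire: a rank whose peers cannot run FULL graphs would round the agreed token count back up to a capture size. A compact simulation of the failure and the fix (hypothetical capture sizes and a toy `local_dispatch`; this mirrors the control flow above without calling vLLM):

```python
NONE, PIECEWISE, FULL = 0, 1, 2
FULL_CAPTURE_SIZES = [64, 128, 256]  # hypothetical FULL-graph capture sizes


def local_dispatch(num_tokens: int, disable_full: bool) -> tuple[int, int]:
    # Toy model of CudagraphDispatcher.dispatch: a FULL hit rounds the batch
    # up to the nearest captured size; otherwise the token count is kept.
    if not disable_full:
        for size in FULL_CAPTURE_SIZES:
            if num_tokens <= size:
                return FULL, size
    return PIECEWISE, num_tokens


# Phase 1: local dispatch. Rank 0 runs eager (NONE, 200 tokens, unrounded);
# rank 1 lands a FULL graph (90 tokens rounded up to 128).
modes = [NONE, local_dispatch(90, disable_full=False)[0]]   # [0, 2]
padded = [200, local_dispatch(90, disable_full=False)[1]]   # [200, 128]

# Phase 2: DP sync (rows 1 and 4 of the all-reduced tensor).
num_tokens_padded = max(padded)     # 200: every rank must run 200 tokens
synced_cudagraph_mode = min(modes)  # NONE: rank 1 must give up FULL

# Phase 3: rank 1 re-dispatches at the agreed count. Without the fix, FULL
# rounds 200 up to 256 and the assert fires:
assert local_dispatch(200, disable_full=False)[1] == 256
# With the fix, FULL is disabled whenever the synced mode is <= PIECEWISE:
_, n = local_dispatch(200, disable_full=synced_cudagraph_mode <= PIECEWISE)
assert n == num_tokens_padded  # batch_descriptor.num_tokens == num_tokens_padded
```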