[BugFix] Fix assert batch_descriptor.num_tokens == num_tokens_padded (#30173)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-12-09 10:36:12 -05:00 committed by GitHub
parent 5dcd593baf
commit 56037dfa2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 65 additions and 33 deletions

View File

@@ -161,10 +161,10 @@ class TestCudagraphDispatcher:
assert rt_mode == CUDAGraphMode.NONE assert rt_mode == CUDAGraphMode.NONE
assert key == BatchDescriptor(num_tokens=15) assert key == BatchDescriptor(num_tokens=15)
# 4. Cascade attention should have a fall back mode # 4. disable_full should have a fall back mode (e.g., cascade attention)
desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False) desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
rt_mode, key = dispatcher.dispatch( rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
) )
if "PIECEWISE" in cudagraph_mode_str: # string contains check if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE assert rt_mode == CUDAGraphMode.PIECEWISE

View File

@@ -292,7 +292,7 @@ def set_forward_context(
if num_tokens_across_dp is None: if num_tokens_across_dp is None:
assert ubatch_slices is None assert ubatch_slices is None
assert num_tokens is not None assert num_tokens is not None
_, num_tokens_across_dp = coordinate_batch_across_dp( _, num_tokens_across_dp, _ = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_unpadded=num_tokens,
parallel_config=vllm_config.parallel_config, parallel_config=vllm_config.parallel_config,
allow_microbatching=False, allow_microbatching=False,

View File

@@ -145,7 +145,7 @@ class CudagraphDispatcher:
num_tokens: int, num_tokens: int,
uniform_decode: bool, uniform_decode: bool,
has_lora: bool, has_lora: bool,
use_cascade_attn: bool = False, disable_full: bool = False,
) -> tuple[CUDAGraphMode, BatchDescriptor]: ) -> tuple[CUDAGraphMode, BatchDescriptor]:
""" """
Given conditions(e.g.,batch descriptor and if using cascade attention), Given conditions(e.g.,batch descriptor and if using cascade attention),
@@ -165,7 +165,7 @@ class CudagraphDispatcher:
) )
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs() relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
if not use_cascade_attn: if not disable_full:
# check if key exists for full cudagraph # check if key exists for full cudagraph
if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]: if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_desc return CUDAGraphMode.FULL, batch_desc

View File

@@ -1258,7 +1258,7 @@ class EagleProposer:
num_tokens_padded: int, num_tokens_padded: int,
) -> tuple[int, torch.Tensor]: ) -> tuple[int, torch.Tensor]:
# TODO(Flechman): support DBO ubatching # TODO(Flechman): support DBO ubatching
should_ubatch, num_toks_across_dp = coordinate_batch_across_dp( should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens_unpadded, num_tokens_unpadded=num_tokens_unpadded,
parallel_config=self.vllm_config.parallel_config, parallel_config=self.vllm_config.parallel_config,
allow_microbatching=False, allow_microbatching=False,

View File

@@ -40,16 +40,18 @@ def _run_ar(
should_dp_pad: bool, should_dp_pad: bool,
orig_num_tokens_per_ubatch: int, orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int, padded_num_tokens_per_ubatch: int,
cudagraph_mode: int,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
) -> torch.Tensor: ) -> torch.Tensor:
dp_size = parallel_config.data_parallel_size dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank dp_rank = parallel_config.data_parallel_rank
device, group = _get_device_and_group(parallel_config) device, group = _get_device_and_group(parallel_config)
tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32) tensor = torch.zeros(5, dp_size, device=device, dtype=torch.int32)
tensor[0][dp_rank] = orig_num_tokens_per_ubatch tensor[0][dp_rank] = orig_num_tokens_per_ubatch
tensor[1][dp_rank] = padded_num_tokens_per_ubatch tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0 tensor[2][dp_rank] = 1 if should_ubatch else 0
tensor[3][dp_rank] = 1 if should_dp_pad else 0 tensor[3][dp_rank] = 1 if should_dp_pad else 0
tensor[4][dp_rank] = cudagraph_mode
dist.all_reduce(tensor, group=group) dist.all_reduce(tensor, group=group)
return tensor return tensor
@@ -89,13 +91,23 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
return num_tokens_across_dp.cpu() return num_tokens_across_dp.cpu()
def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
"""
Synchronize cudagraph_mode across DP ranks by taking the minimum.
If any rank has NONE (0), all ranks use NONE.
This ensures all ranks send consistent values (all padded or all unpadded).
"""
return int(tensor[4, :].min().item())
def _synchronize_dp_ranks( def _synchronize_dp_ranks(
num_tokens_unpadded: int, num_tokens_unpadded: int,
num_tokens_padded: int, num_tokens_padded: int,
should_attempt_ubatching: bool, should_attempt_ubatching: bool,
should_attempt_dp_padding: bool, should_attempt_dp_padding: bool,
cudagraph_mode: int,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
) -> tuple[bool, torch.Tensor | None]: ) -> tuple[bool, torch.Tensor | None, int]:
""" """
1. Decides if each DP rank is going to microbatch. Either all ranks 1. Decides if each DP rank is going to microbatch. Either all ranks
run with microbatching or none of them do. run with microbatching or none of them do.
@@ -104,10 +116,13 @@ def _synchronize_dp_ranks(
When running microbatched or if should_attempt_dp_padding is True, all When running microbatched or if should_attempt_dp_padding is True, all
ranks will be padded out so that the run with the same number of tokens ranks will be padded out so that the run with the same number of tokens
3. Synchronizes cudagraph_mode across ranks by taking the minimum.
Returns: tuple[ Returns: tuple[
should_ubatch: Are all DP ranks going to microbatch should_ubatch: Are all DP ranks going to microbatch
num_tokens_after_padding: A tensor containing the total number of num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including any DP padding. tokens per-microbatch for each DP rank including any DP padding.
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
] ]
""" """
@@ -121,6 +136,7 @@ def _synchronize_dp_ranks(
should_dp_pad=should_attempt_dp_padding, should_dp_pad=should_attempt_dp_padding,
orig_num_tokens_per_ubatch=num_tokens_unpadded, orig_num_tokens_per_ubatch=num_tokens_unpadded,
padded_num_tokens_per_ubatch=num_tokens_padded, padded_num_tokens_per_ubatch=num_tokens_padded,
cudagraph_mode=cudagraph_mode,
parallel_config=parallel_config, parallel_config=parallel_config,
) )
@@ -148,7 +164,10 @@ def _synchronize_dp_ranks(
should_dp_pad, should_dp_pad,
) )
return should_ubatch, num_tokens_after_padding # Synchronize cudagraph_mode across ranks (take min)
synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode
def coordinate_batch_across_dp( def coordinate_batch_across_dp(
@@ -159,7 +178,8 @@ def coordinate_batch_across_dp(
num_tokens_padded: int | None = None, num_tokens_padded: int | None = None,
uniform_decode: bool | None = None, uniform_decode: bool | None = None,
num_scheduled_tokens_per_request: np.ndarray | None = None, num_scheduled_tokens_per_request: np.ndarray | None = None,
) -> tuple[bool, torch.Tensor | None]: cudagraph_mode: int = 0,
) -> tuple[bool, torch.Tensor | None, int]:
""" """
Coordinates amongst all DP ranks to determine if and how the full batch Coordinates amongst all DP ranks to determine if and how the full batch
should be split into microbatches. should be split into microbatches.
@@ -175,6 +195,7 @@ def coordinate_batch_across_dp(
only contains single token decodes only contains single token decodes
num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
number of tokens per request. number of tokens per request.
cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL)
Returns: tuple[ Returns: tuple[
ubatch_slices: if this is set then all DP ranks have agreed to ubatch_slices: if this is set then all DP ranks have agreed to
@@ -183,12 +204,13 @@ def coordinate_batch_across_dp(
tokens per-microbatch for each DP rank including padding. Will be tokens per-microbatch for each DP rank including padding. Will be
padded up to the max value across all DP ranks when allow_dp_padding padded up to the max value across all DP ranks when allow_dp_padding
is True. is True.
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
] ]
""" """
if parallel_config.data_parallel_size == 1: if parallel_config.data_parallel_size == 1:
# Early exit. # Early exit.
return False, None return False, None, cudagraph_mode
# If the caller has explicitly enabled microbatching. # If the caller has explicitly enabled microbatching.
should_attempt_ubatching = False should_attempt_ubatching = False
@@ -204,12 +226,15 @@ def coordinate_batch_across_dp(
if num_tokens_padded is None: if num_tokens_padded is None:
num_tokens_padded = num_tokens_unpadded num_tokens_padded = num_tokens_unpadded
(should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks( (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode) = (
num_tokens_unpadded, _synchronize_dp_ranks(
num_tokens_padded, num_tokens_unpadded,
should_attempt_ubatching, num_tokens_padded,
allow_dp_padding, should_attempt_ubatching,
parallel_config, allow_dp_padding,
cudagraph_mode,
parallel_config,
)
) )
return (should_ubatch, num_tokens_after_padding) return (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode)

View File

@@ -2788,17 +2788,19 @@ class GPUModelRunner(
) )
dispatch_cudagraph = ( dispatch_cudagraph = (
lambda num_tokens: self.cudagraph_dispatcher.dispatch( lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch(
num_tokens=num_tokens, num_tokens=num_tokens,
has_lora=has_lora, has_lora=has_lora,
use_cascade_attn=use_cascade_attn,
uniform_decode=uniform_decode, uniform_decode=uniform_decode,
disable_full=disable_full,
) )
if not force_eager if not force_eager
else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded)) else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
) )
cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded) cudagraph_mode, batch_descriptor = dispatch_cudagraph(
num_tokens_padded, use_cascade_attn
)
num_tokens_padded = batch_descriptor.num_tokens num_tokens_padded = batch_descriptor.num_tokens
# Extra coordination when running data-parallel since we need to coordinate # Extra coordination when running data-parallel since we need to coordinate
@@ -2813,23 +2815,28 @@ class GPUModelRunner(
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
) )
should_ubatch, num_tokens_across_dp = coordinate_batch_across_dp( should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
num_tokens_unpadded=num_tokens, coordinate_batch_across_dp(
parallel_config=self.parallel_config, num_tokens_unpadded=num_tokens,
allow_microbatching=allow_microbatching, parallel_config=self.parallel_config,
allow_dp_padding=allow_dp_padding, allow_microbatching=allow_microbatching,
num_tokens_padded=num_tokens_padded, allow_dp_padding=allow_dp_padding,
uniform_decode=uniform_decode, num_tokens_padded=num_tokens_padded,
num_scheduled_tokens_per_request=num_scheduled_tokens_np, uniform_decode=uniform_decode,
num_scheduled_tokens_per_request=num_scheduled_tokens_np,
cudagraph_mode=cudagraph_mode.value,
)
) )
# Extract DP padding if there is any # Extract DP-synced values
if num_tokens_across_dp is not None: if num_tokens_across_dp is not None:
dp_rank = self.parallel_config.data_parallel_rank dp_rank = self.parallel_config.data_parallel_rank
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item()) num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
# Re-dispatch with DP padding so we have the correct batch_descriptor
# Re-dispatch with DP padding cudagraph_mode, batch_descriptor = dispatch_cudagraph(
cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded) num_tokens_padded,
disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value,
)
# Assert to make sure the agreed upon token count is correct otherwise # Assert to make sure the agreed upon token count is correct otherwise
# num_tokens_across_dp will no-longer be valid # num_tokens_across_dp will no-longer be valid
assert batch_descriptor.num_tokens == num_tokens_padded assert batch_descriptor.num_tokens == num_tokens_padded