[BugFix] Fix assert batch_descriptor.num_tokens == num_tokens_padded (#30173)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 5dcd593baf
commit 56037dfa2f
@@ -161,10 +161,10 @@ class TestCudagraphDispatcher:
         assert rt_mode == CUDAGraphMode.NONE
         assert key == BatchDescriptor(num_tokens=15)

-        # 4. Cascade attention should have a fall back mode
+        # 4. disable_full should have a fall back mode (e.g., cascade attention)
         desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
         rt_mode, key = dispatcher.dispatch(
-            num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
+            num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
         )
         if "PIECEWISE" in cudagraph_mode_str:  # string contains check
             assert rt_mode == CUDAGraphMode.PIECEWISE
@@ -292,7 +292,7 @@ def set_forward_context(
     if num_tokens_across_dp is None:
         assert ubatch_slices is None
         assert num_tokens is not None
-        _, num_tokens_across_dp = coordinate_batch_across_dp(
+        _, num_tokens_across_dp, _ = coordinate_batch_across_dp(
            num_tokens_unpadded=num_tokens,
            parallel_config=vllm_config.parallel_config,
            allow_microbatching=False,
@@ -145,7 +145,7 @@ class CudagraphDispatcher:
         num_tokens: int,
         uniform_decode: bool,
         has_lora: bool,
-        use_cascade_attn: bool = False,
+        disable_full: bool = False,
     ) -> tuple[CUDAGraphMode, BatchDescriptor]:
         """
         Given conditions(e.g.,batch descriptor and if using cascade attention),
@@ -165,7 +165,7 @@ class CudagraphDispatcher:
         )
         relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()

-        if not use_cascade_attn:
+        if not disable_full:
             # check if key exists for full cudagraph
             if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
                 return CUDAGraphMode.FULL, batch_desc
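The hunks above fold the old cascade-attention special case into a generic disable_full flag: any caller that cannot (or should not) run a full cudagraph skips the FULL-key lookup and falls through to the piecewise/eager paths. A minimal sketch of that dispatch shape, with hypothetical names and none of vLLM's real bookkeeping:

from enum import IntEnum


class Mode(IntEnum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


def dispatch(num_tokens: int, full_keys: set[int], piecewise_keys: set[int],
             disable_full: bool = False) -> tuple[Mode, int]:
    """Pick a cudagraph mode for a batch of num_tokens (toy version)."""
    if not disable_full and num_tokens in full_keys:
        return Mode.FULL, num_tokens       # an exact full-graph capture exists
    if num_tokens in piecewise_keys:
        return Mode.PIECEWISE, num_tokens  # generic fallback (e.g., cascade attention)
    return Mode.NONE, num_tokens           # run eagerly


# disable_full forces the fallback even though a FULL key exists for 8 tokens.
assert dispatch(8, full_keys={8}, piecewise_keys={8}, disable_full=True)[0] == Mode.PIECEWISE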
@@ -1258,7 +1258,7 @@ class EagleProposer:
         num_tokens_padded: int,
     ) -> tuple[int, torch.Tensor]:
         # TODO(Flechman): support DBO ubatching
-        should_ubatch, num_toks_across_dp = coordinate_batch_across_dp(
+        should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
             num_tokens_unpadded=num_tokens_unpadded,
             parallel_config=self.vllm_config.parallel_config,
             allow_microbatching=False,
@@ -40,16 +40,18 @@ def _run_ar(
     should_dp_pad: bool,
     orig_num_tokens_per_ubatch: int,
     padded_num_tokens_per_ubatch: int,
+    cudagraph_mode: int,
     parallel_config: ParallelConfig,
 ) -> torch.Tensor:
     dp_size = parallel_config.data_parallel_size
     dp_rank = parallel_config.data_parallel_rank
     device, group = _get_device_and_group(parallel_config)
-    tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
+    tensor = torch.zeros(5, dp_size, device=device, dtype=torch.int32)
     tensor[0][dp_rank] = orig_num_tokens_per_ubatch
     tensor[1][dp_rank] = padded_num_tokens_per_ubatch
     tensor[2][dp_rank] = 1 if should_ubatch else 0
     tensor[3][dp_rank] = 1 if should_dp_pad else 0
+    tensor[4][dp_rank] = cudagraph_mode
     dist.all_reduce(tensor, group=group)
     return tensor
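_run_ar grows the coordination tensor from 4 rows to 5 so each rank can publish its cudagraph mode next to its token counts. Because every rank writes only its own column of a zero-initialized (rows, dp_size) tensor, the SUM all-reduce effectively all-gathers every rank's values. A sketch of that layout, simulated locally without torch.distributed and with made-up per-rank values:

import torch

dp_size = 3
per_rank_values = [
    # (orig_tokens, padded_tokens, should_ubatch, should_dp_pad, cudagraph_mode)
    (7, 8, 0, 1, 2),   # rank 0: wants a FULL graph
    (5, 8, 0, 1, 1),   # rank 1: can only run PIECEWISE
    (8, 8, 0, 1, 2),   # rank 2: wants a FULL graph
]

contributions = []
for dp_rank, values in enumerate(per_rank_values):
    t = torch.zeros(5, dp_size, dtype=torch.int32)
    for row, value in enumerate(values):
        t[row][dp_rank] = value        # each rank fills only its own column
    contributions.append(t)

reduced = torch.stack(contributions).sum(dim=0)  # stands in for dist.all_reduce (SUM)
assert reduced[4].tolist() == [2, 1, 2]          # column j now holds rank j's mode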
@@ -89,13 +91,23 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
     return num_tokens_across_dp.cpu()


+def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
+    """
+    Synchronize cudagraph_mode across DP ranks by taking the minimum.
+    If any rank has NONE (0), all ranks use NONE.
+    This ensures all ranks send consistent values (all padded or all unpadded).
+    """
+    return int(tensor[4, :].min().item())
+
+
 def _synchronize_dp_ranks(
     num_tokens_unpadded: int,
     num_tokens_padded: int,
     should_attempt_ubatching: bool,
     should_attempt_dp_padding: bool,
+    cudagraph_mode: int,
     parallel_config: ParallelConfig,
-) -> tuple[bool, torch.Tensor | None]:
+) -> tuple[bool, torch.Tensor | None, int]:
     """
     1. Decides if each DP rank is going to microbatch. Either all ranks
     run with microbatching or none of them do.
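The new _post_process_cudagraph_mode encodes the agreement policy as an element-wise minimum over the gathered mode row: since the modes are ordered NONE(0) < PIECEWISE(1) < FULL(2), the minimum is the most conservative mode every rank can still run, and a single NONE rank pulls everyone down to NONE. A direct illustration on a hypothetical gathered row:

import torch

# Row 4 of the all-reduced tensor: one mode per DP rank (values are made up).
gathered_modes = torch.tensor([2, 0, 2], dtype=torch.int32)  # rank 1 must run eagerly
synced_mode = int(gathered_modes.min().item())
assert synced_mode == 0  # any NONE rank forces NONE everywhere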
@@ -104,10 +116,13 @@ def _synchronize_dp_ranks(
     When running microbatched or if should_attempt_dp_padding is True, all
     ranks will be padded out so that the run with the same number of tokens

+    3. Synchronizes cudagraph_mode across ranks by taking the minimum.
+
     Returns: tuple[
         should_ubatch: Are all DP ranks going to microbatch
         num_tokens_after_padding: A tensor containing the total number of
         tokens per-microbatch for each DP rank including any DP padding.
+        synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
     ]

     """
@@ -121,6 +136,7 @@ def _synchronize_dp_ranks(
         should_dp_pad=should_attempt_dp_padding,
         orig_num_tokens_per_ubatch=num_tokens_unpadded,
         padded_num_tokens_per_ubatch=num_tokens_padded,
+        cudagraph_mode=cudagraph_mode,
         parallel_config=parallel_config,
     )

@@ -148,7 +164,10 @@ def _synchronize_dp_ranks(
         should_dp_pad,
     )

-    return should_ubatch, num_tokens_after_padding
+    # Synchronize cudagraph_mode across ranks (take min)
+    synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
+
+    return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode


 def coordinate_batch_across_dp(
@@ -159,7 +178,8 @@ def coordinate_batch_across_dp(
     num_tokens_padded: int | None = None,
     uniform_decode: bool | None = None,
     num_scheduled_tokens_per_request: np.ndarray | None = None,
-) -> tuple[bool, torch.Tensor | None]:
+    cudagraph_mode: int = 0,
+) -> tuple[bool, torch.Tensor | None, int]:
     """
     Coordinates amongst all DP ranks to determine if and how the full batch
     should be split into microbatches.
@@ -175,6 +195,7 @@ def coordinate_batch_across_dp(
         only contains single token decodes
         num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
         number of tokens per request.
+        cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL)

     Returns: tuple[
         ubatch_slices: if this is set then all DP ranks have agreed to
@@ -183,12 +204,13 @@ def coordinate_batch_across_dp(
         tokens per-microbatch for each DP rank including padding. Will be
         padded up to the max value across all DP ranks when allow_dp_padding
         is True.
+        synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
     ]

     """
     if parallel_config.data_parallel_size == 1:
         # Early exit.
-        return False, None
+        return False, None, cudagraph_mode

     # If the caller has explicitly enabled microbatching.
     should_attempt_ubatching = False
@@ -204,12 +226,15 @@ def coordinate_batch_across_dp(
     if num_tokens_padded is None:
         num_tokens_padded = num_tokens_unpadded

-    (should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks(
-        num_tokens_unpadded,
-        num_tokens_padded,
-        should_attempt_ubatching,
-        allow_dp_padding,
-        parallel_config,
+    (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode) = (
+        _synchronize_dp_ranks(
+            num_tokens_unpadded,
+            num_tokens_padded,
+            should_attempt_ubatching,
+            allow_dp_padding,
+            cudagraph_mode,
+            parallel_config,
+        )
     )

-    return (should_ubatch, num_tokens_after_padding)
+    return (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode)
@@ -2788,17 +2788,19 @@ class GPUModelRunner(
         )

         dispatch_cudagraph = (
-            lambda num_tokens: self.cudagraph_dispatcher.dispatch(
+            lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch(
                 num_tokens=num_tokens,
                 has_lora=has_lora,
-                use_cascade_attn=use_cascade_attn,
                 uniform_decode=uniform_decode,
+                disable_full=disable_full,
             )
             if not force_eager
             else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
         )

-        cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded)
+        cudagraph_mode, batch_descriptor = dispatch_cudagraph(
+            num_tokens_padded, use_cascade_attn
+        )
         num_tokens_padded = batch_descriptor.num_tokens

         # Extra coordination when running data-parallel since we need to coordinate
@@ -2813,23 +2815,28 @@ class GPUModelRunner(
             self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
         )

-        should_ubatch, num_tokens_across_dp = coordinate_batch_across_dp(
-            num_tokens_unpadded=num_tokens,
-            parallel_config=self.parallel_config,
-            allow_microbatching=allow_microbatching,
-            allow_dp_padding=allow_dp_padding,
-            num_tokens_padded=num_tokens_padded,
-            uniform_decode=uniform_decode,
-            num_scheduled_tokens_per_request=num_scheduled_tokens_np,
+        should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
+            coordinate_batch_across_dp(
+                num_tokens_unpadded=num_tokens,
+                parallel_config=self.parallel_config,
+                allow_microbatching=allow_microbatching,
+                allow_dp_padding=allow_dp_padding,
+                num_tokens_padded=num_tokens_padded,
+                uniform_decode=uniform_decode,
+                num_scheduled_tokens_per_request=num_scheduled_tokens_np,
+                cudagraph_mode=cudagraph_mode.value,
+            )
         )

-        # Extract DP padding if there is any
+        # Extract DP-synced values
         if num_tokens_across_dp is not None:
             dp_rank = self.parallel_config.data_parallel_rank
             num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())

-            # Re-dispatch with DP padding
-            cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded)
+            # Re-dispatch with DP padding so we have the correct batch_descriptor
+            cudagraph_mode, batch_descriptor = dispatch_cudagraph(
+                num_tokens_padded,
+                disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value,
+            )
             # Assert to make sure the agreed upon token count is correct otherwise
             # num_tokens_across_dp will no-longer be valid
             assert batch_descriptor.num_tokens == num_tokens_padded
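Taken together, this is the fix named in the commit title: each rank dispatches locally, coordinate_batch_across_dp then agrees on the padded token counts and on the minimum cudagraph mode, and the re-dispatch disables full cudagraphs whenever that synced mode is PIECEWISE or NONE, so every rank re-dispatches under the same agreement and the assert on batch_descriptor.num_tokens can hold. A small sketch of just the disable_full condition, using a local stand-in enum (values mirror 0=NONE, 1=PIECEWISE, 2=FULL), not vLLM's actual CUDAGraphMode:

from enum import IntEnum


class Mode(IntEnum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


def disable_full_on_redispatch(synced_mode: int) -> bool:
    # Full cudagraphs stay enabled only when every DP rank agreed on FULL.
    return synced_mode <= Mode.PIECEWISE.value


assert disable_full_on_redispatch(min(Mode.FULL, Mode.PIECEWISE))  # mixed ranks -> no FULL
assert not disable_full_on_redispatch(min(Mode.FULL, Mode.FULL))   # all FULL -> keep FULL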