mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-28 10:57:16 +08:00
[Bugfix] Make DP padding optional in coordinate_batch_across_dp (#26375)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
0e67102d93
commit
ae9d0e7da5
@ -12,6 +12,7 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
|
||||
from vllm.v1.worker.ubatch_utils import UBatchSlices
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -278,7 +279,19 @@ def set_forward_context(
|
||||
if vllm_config.parallel_config.data_parallel_size > 1 and (
|
||||
attn_metadata is not None or num_tokens is not None
|
||||
):
|
||||
assert num_tokens_across_dp is not None
|
||||
# If num_tokens_across_dp hasn't already been initialized, then
|
||||
# initialize it here. Both DP padding and Microbatching will be
|
||||
# disabled.
|
||||
if num_tokens_across_dp is None:
|
||||
assert ubatch_slices is None
|
||||
assert num_tokens is not None
|
||||
_, num_tokens_across_dp = coordinate_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens,
|
||||
parallel_config=vllm_config.parallel_config,
|
||||
allow_microbatching=False,
|
||||
allow_dp_padding=False,
|
||||
)
|
||||
assert num_tokens_across_dp is not None
|
||||
dp_metadata = DPMetadata.make(
|
||||
vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
|
||||
)
|
||||
|
||||
@ -7,7 +7,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.worker.ubatch_utils import (
|
||||
@ -37,6 +37,7 @@ def _get_device_and_group(parallel_config: ParallelConfig):
|
||||
|
||||
def _run_ar(
|
||||
should_ubatch: bool,
|
||||
should_dp_pad: bool,
|
||||
orig_num_tokens_per_ubatch: int,
|
||||
padded_num_tokens_per_ubatch: int,
|
||||
parallel_config: ParallelConfig,
|
||||
@ -44,10 +45,11 @@ def _run_ar(
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
dp_rank = parallel_config.data_parallel_rank
|
||||
device, group = _get_device_and_group(parallel_config)
|
||||
tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
|
||||
tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
|
||||
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
|
||||
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
|
||||
tensor[2][dp_rank] = 1 if should_ubatch else 0
|
||||
tensor[3][dp_rank] = 1 if should_dp_pad else 0
|
||||
dist.all_reduce(tensor, group=group)
|
||||
return tensor
|
||||
|
||||
@ -72,10 +74,26 @@ def _post_process_ubatch(tensor: torch.Tensor) -> bool:
|
||||
return should_ubatch
|
||||
|
||||
|
||||
def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch.Tensor:
|
||||
num_tokens_across_dp = tensor[1, :]
|
||||
if should_dp_pad:
|
||||
# If DP padding is enabled, ensure that each rank is processing the same number
|
||||
# of tokens
|
||||
max_num_tokens = int(num_tokens_across_dp.max().item())
|
||||
return torch.tensor(
|
||||
[max_num_tokens] * len(num_tokens_across_dp),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
else:
|
||||
return num_tokens_across_dp.cpu()
|
||||
|
||||
|
||||
def _synchronize_dp_ranks(
|
||||
num_tokens_unpadded: int,
|
||||
num_tokens_padded: int,
|
||||
should_attempt_ubatching: bool,
|
||||
should_attempt_dp_padding: bool,
|
||||
parallel_config: ParallelConfig,
|
||||
) -> tuple[bool, Optional[torch.Tensor]]:
|
||||
"""
|
||||
@ -83,57 +101,88 @@ def _synchronize_dp_ranks(
|
||||
run with microbatching or none of them do.
|
||||
|
||||
2. Determines the total number of tokens that each rank will run.
|
||||
All ranks will be padded out so that the run with the same number
|
||||
of tokens
|
||||
When running microbatched or if should_attempt_dp_padding is True, all
|
||||
ranks will be padded out so that the run with the same number of tokens
|
||||
|
||||
Returns: tuple[
|
||||
should_ubatch: Are all DP ranks going to microbatch
|
||||
num_tokens_after_padding: A tensor containing the total number of
|
||||
tokens per-microbatch for each DP rank including padding.
|
||||
tokens per-microbatch for each DP rank including any DP padding.
|
||||
]
|
||||
|
||||
"""
|
||||
assert num_tokens_padded >= num_tokens_unpadded
|
||||
|
||||
# First we coordinate between the DP ranks via an All Reduce
|
||||
# Coordinate between the DP ranks via an All Reduce
|
||||
# to determine the total number of tokens that each rank
|
||||
# will run and if we are using ubatching or not.
|
||||
tensor = _run_ar(
|
||||
should_ubatch=should_attempt_ubatching,
|
||||
should_dp_pad=should_attempt_dp_padding,
|
||||
orig_num_tokens_per_ubatch=num_tokens_unpadded,
|
||||
padded_num_tokens_per_ubatch=num_tokens_padded,
|
||||
parallel_config=parallel_config,
|
||||
)
|
||||
|
||||
# Ensure that each rank is processing the same nuber of tokens
|
||||
num_tokens_across_dp = tensor[1, :]
|
||||
max_num_tokens = int(num_tokens_across_dp.max().item())
|
||||
num_tokens_after_padding = torch.tensor(
|
||||
[max_num_tokens] * len(num_tokens_across_dp), device="cpu", dtype=torch.int32
|
||||
)
|
||||
should_dp_pad = bool(torch.all(tensor[3] == 1).item())
|
||||
|
||||
# DP ranks should all have the same value for should_attempt_dp_padding.
|
||||
assert should_attempt_dp_padding == should_dp_pad
|
||||
|
||||
# Check conditions for microbatching
|
||||
should_ubatch = _post_process_ubatch(tensor)
|
||||
|
||||
if should_ubatch and not should_dp_pad:
|
||||
if is_global_first_rank():
|
||||
logger.debug(
|
||||
"Microbatching has been triggered and requires DP padding. "
|
||||
"Enabling DP padding even though it has been explicitly "
|
||||
"disabled."
|
||||
)
|
||||
should_dp_pad = True
|
||||
|
||||
# Pad all DP ranks up to the maximum token count across ranks if
|
||||
# should_dp_pad is True
|
||||
num_tokens_after_padding = _post_process_dp_padding(
|
||||
tensor,
|
||||
should_dp_pad,
|
||||
)
|
||||
|
||||
return should_ubatch, num_tokens_after_padding
|
||||
|
||||
|
||||
def coordinate_batch_across_dp(
|
||||
num_scheduled_tokens_per_request: np.ndarray,
|
||||
num_tokens_unpadded: int,
|
||||
num_tokens_padded: int,
|
||||
parallel_config: ParallelConfig,
|
||||
allow_microbatching: bool,
|
||||
uniform_decode: bool,
|
||||
allow_dp_padding: bool,
|
||||
parallel_config: ParallelConfig,
|
||||
num_tokens_padded: Optional[int] = None,
|
||||
uniform_decode: Optional[bool] = None,
|
||||
num_scheduled_tokens_per_request: Optional[np.ndarray] = None,
|
||||
) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Coordinates amongst all DP ranks to determine if and how the full batch
|
||||
should be split into microbatches.
|
||||
|
||||
Args:
|
||||
num_tokens_unpadded: Number of tokens without accounting for padding
|
||||
allow_microbatching: If microbatching should be attempted
|
||||
allow_dp_padding: If all DP ranks should be padded up to the same value
|
||||
parallel_config: The parallel config
|
||||
num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs,
|
||||
TP, etc)
|
||||
uniform_decode: Only used if allow_microbatching is True. True if the batch
|
||||
only contains single token decodes
|
||||
num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
|
||||
number of tokens per request.
|
||||
|
||||
Returns: tuple[
|
||||
ubatch_slices: if this is set then all DP ranks have agreed to
|
||||
microbatch
|
||||
num_tokens_after_padding: A tensor containing the total number of
|
||||
tokens per-microbatch for each DP rank including padding.
|
||||
tokens per-microbatch for each DP rank including padding. Will be
|
||||
padded up to the max value across all DP ranks when allow_dp_padding
|
||||
is True.
|
||||
]
|
||||
|
||||
"""
|
||||
@ -141,21 +190,25 @@ def coordinate_batch_across_dp(
|
||||
# Early exit.
|
||||
return None, None
|
||||
|
||||
# Check preconditions for microbatching
|
||||
should_attempt_ubatching = check_ubatch_thresholds(
|
||||
parallel_config,
|
||||
num_tokens_unpadded,
|
||||
uniform_decode=uniform_decode,
|
||||
)
|
||||
# If the caller has explicitly enabled microbatching.
|
||||
should_attempt_ubatching = False
|
||||
if allow_microbatching:
|
||||
# Check preconditions for microbatching
|
||||
assert uniform_decode is not None
|
||||
should_attempt_ubatching = check_ubatch_thresholds(
|
||||
parallel_config,
|
||||
num_tokens_unpadded,
|
||||
uniform_decode=uniform_decode,
|
||||
)
|
||||
|
||||
# If the caller has explicitly disabled microbatching.
|
||||
if not allow_microbatching:
|
||||
should_attempt_ubatching = False
|
||||
if num_tokens_padded is None:
|
||||
num_tokens_padded = num_tokens_unpadded
|
||||
|
||||
(should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks(
|
||||
num_tokens_unpadded,
|
||||
num_tokens_padded,
|
||||
should_attempt_ubatching,
|
||||
allow_dp_padding,
|
||||
parallel_config,
|
||||
)
|
||||
|
||||
@ -170,6 +223,7 @@ def coordinate_batch_across_dp(
|
||||
assert num_tokens_after_padding is not None
|
||||
token_split_point = int(num_tokens_after_padding[0].item()) // 2
|
||||
|
||||
assert num_scheduled_tokens_per_request is not None
|
||||
ubatch_slices = create_ubatch_slices(
|
||||
num_scheduled_tokens_per_request, token_split_point
|
||||
)
|
||||
|
||||
@ -1178,13 +1178,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
uniform_decode = (
|
||||
max_num_scheduled_tokens == self.uniform_decode_query_len
|
||||
) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
|
||||
|
||||
# Disable DP padding when running eager to avoid excessive padding when
|
||||
# running prefills. This lets us set enforce_eager on the prefiller in
|
||||
# a P/D setup and still use CUDA graphs (enabled by this padding) on the
|
||||
# decoder.
|
||||
allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
|
||||
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
|
||||
num_scheduled_tokens,
|
||||
num_tokens_unpadded,
|
||||
num_tokens_padded,
|
||||
self.parallel_config,
|
||||
True,
|
||||
uniform_decode,
|
||||
num_tokens_unpadded=num_tokens_unpadded,
|
||||
parallel_config=self.parallel_config,
|
||||
allow_microbatching=True,
|
||||
allow_dp_padding=allow_dp_padding,
|
||||
num_tokens_padded=num_tokens_padded,
|
||||
uniform_decode=uniform_decode,
|
||||
num_scheduled_tokens_per_request=num_scheduled_tokens,
|
||||
)
|
||||
|
||||
self.seq_lens.np[:num_reqs] = (
|
||||
@ -2436,12 +2444,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
use_cascade_attn,
|
||||
) = self._prepare_inputs(scheduler_output)
|
||||
|
||||
dp_rank = self.parallel_config.data_parallel_rank
|
||||
if ubatch_slices:
|
||||
assert num_tokens_across_dp is not None
|
||||
num_input_tokens = int(num_tokens_across_dp[0].item())
|
||||
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
|
||||
self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
|
||||
elif num_tokens_across_dp is not None:
|
||||
num_input_tokens = int(num_tokens_across_dp[0].item())
|
||||
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
|
||||
else:
|
||||
num_input_tokens = self._get_num_input_tokens(
|
||||
scheduler_output.total_num_scheduled_tokens
|
||||
@ -3256,19 +3265,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
|
||||
total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
|
||||
|
||||
# Disable DP padding when running eager
|
||||
allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
|
||||
# We currently only microbatch if the number of tokens is
|
||||
# over a certain threshold.
|
||||
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
|
||||
num_scheduled_tokens,
|
||||
total_num_scheduled_tokens,
|
||||
total_num_scheduled_tokens,
|
||||
self.vllm_config.parallel_config,
|
||||
allow_microbatching,
|
||||
uniform_decode,
|
||||
num_tokens_unpadded=total_num_scheduled_tokens,
|
||||
parallel_config=self.vllm_config.parallel_config,
|
||||
allow_microbatching=allow_microbatching,
|
||||
allow_dp_padding=allow_dp_padding,
|
||||
num_tokens_padded=total_num_scheduled_tokens,
|
||||
uniform_decode=uniform_decode,
|
||||
num_scheduled_tokens_per_request=num_scheduled_tokens,
|
||||
)
|
||||
num_tokens_after_padding = num_tokens
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_after_padding = int(num_tokens_across_dp[0])
|
||||
dp_rank = self.parallel_config.data_parallel_rank
|
||||
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
|
||||
|
||||
attn_metadata: Optional[PerLayerAttnMetadata] = None
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user