vllm/vllm/v1/worker/ubatch_utils.py
Lucas Wilkinson c8ab988b15
[BugFix] Fix DBO assert assert B_block_table == B_q (#29933)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
2025-12-04 14:48:54 -05:00

110 lines
3.5 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TypeAlias
import numpy as np
from vllm.config import ParallelConfig
@dataclass
class UBatchSlice:
request_slice: slice
token_slice: slice
def is_empty(self) -> bool:
return (
self.request_slice.start == self.request_slice.stop
or self.token_slice.start == self.token_slice.stop
)
@property
def num_tokens(self) -> int:
return self.token_slice.stop - self.token_slice.start
UBatchSlices: TypeAlias = list[UBatchSlice]
def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool:
return (padded_num_tokens // 2) >= orig_num_tokens
def check_ubatch_thresholds(
config: ParallelConfig, num_tokens: int, uniform_decode: bool
) -> bool:
if not config.enable_dbo:
return False
if uniform_decode:
return num_tokens >= config.dbo_decode_token_threshold
else:
return num_tokens >= config.dbo_prefill_token_threshold
# This just pads the second ubatch slice out to the total number of tokens
# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
def _pad_out_ubatch_slices(
ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
) -> UBatchSlices:
# TODO(lucas): handle empty second ubatch
padded_second_request_slice = slice(
ubatch_slices[1].request_slice.start, num_reqs_padded
)
padded_second_token_slice = slice(
ubatch_slices[1].token_slice.start, num_total_tokens
)
return [
ubatch_slices[0],
UBatchSlice(padded_second_request_slice, padded_second_token_slice),
]
def maybe_create_ubatch_slices(
should_ubatch: bool,
num_scheduled_tokens: np.ndarray,
num_tokens_padded: int,
num_reqs_padded: int,
split_point: int | None = None,
) -> tuple[UBatchSlices | None, UBatchSlices | None]:
if not should_ubatch:
return None, None
if split_point is None:
split_point = int(num_tokens_padded) // 2
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
# in cu_num_tokens directly (i.e. query_start_loc)
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
first_ubatch_token_slice = slice(0, split_point)
second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])
# Determine request slices using exclusive stop semantics
# First ubatch includes requests whose tokens overlap [0, split_point)
first_ubatch_req_stop = int(
np.searchsorted(cu_num_tokens, split_point, side="left")
)
first_ubatch_req_slice = slice(0, first_ubatch_req_stop)
# Second ubatch starts at the request that contains the split_point
# or the request starting exactly at split_point (if on boundary)
second_ubatch_req_start = int(
np.searchsorted(cu_num_tokens, split_point, side="right") - 1
)
second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)
ubatch_slices = [
UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
]
ubatch_slices_padded = _pad_out_ubatch_slices(
ubatch_slices, num_tokens_padded, num_reqs_padded
)
assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
return ubatch_slices, ubatch_slices_padded