[feature] extend DBO to XBO (#30120)

Signed-off-by: jiangkuaixue123 <jiangxiaozhou111@163.com>
Co-authored-by: root <root@hk01dgx028.cm.cluster>

parent c881db364e
commit b9ff4f2a8d
@@ -323,6 +323,7 @@ def test_prefill_split_across_ubatches(
         num_tokens,
         batch_spec.batch_size,
         split_point=split_point,
+        num_ubatches=2,
     )
     assert ubatch_slices is not None and len(ubatch_slices) == 2
@@ -156,6 +156,8 @@ class ParallelConfig:
 
     enable_dbo: bool = False
     """Enable dual batch overlap for the model executor."""
+    ubatch_size: int = 0
+    """Number of micro-batches to split each batch into (0 disables XBO)."""
 
     dbo_decode_token_threshold: int = 32
     """The threshold for dual batch overlap for batches only containing decodes.

@@ -325,6 +327,14 @@ class ParallelConfig:
         including data parallelism."""
         return self.world_size * self.data_parallel_size
 
+    @property
+    def use_ubatching(self) -> bool:
+        return self.enable_dbo or self.ubatch_size > 1
+
+    @property
+    def num_ubatches(self) -> int:
+        return 2 if self.enable_dbo else self.ubatch_size
+
     def get_next_dp_init_port(self) -> int:
         """
         We might need to initialize process groups in multiple
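For orientation, here is a minimal standalone sketch (not the actual ParallelConfig class) of how the two knobs introduced above are intended to resolve: --enable-dbo keeps the legacy two-ubatch behaviour, while ubatch_size > 1 generalizes it to N micro-batches.

from dataclasses import dataclass


# Toy stand-in for the ParallelConfig fields and properties added in this hunk.
@dataclass
class UbatchKnobs:
    enable_dbo: bool = False  # legacy dual batch overlap (always 2 ubatches)
    ubatch_size: int = 0      # XBO: number of micro-batches, 0 = disabled

    @property
    def use_ubatching(self) -> bool:
        # Micro-batching is active for either the legacy flag or XBO > 1.
        return self.enable_dbo or self.ubatch_size > 1

    @property
    def num_ubatches(self) -> int:
        # DBO always means exactly two micro-batches; otherwise use ubatch_size.
        return 2 if self.enable_dbo else self.ubatch_size


assert UbatchKnobs(enable_dbo=True).num_ubatches == 2
assert UbatchKnobs(ubatch_size=4).num_ubatches == 4
assert not UbatchKnobs().use_ubatching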
@@ -870,9 +870,12 @@ class VllmConfig:
                     f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
                 )
 
-        if self.parallel_config.enable_dbo:
+        if self.parallel_config.use_ubatching:
             a2a_backend = self.parallel_config.all2all_backend
-            assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], (
+            assert a2a_backend in [
+                "deepep_low_latency",
+                "deepep_high_throughput",
+            ], (
                 "Microbatching currently only supports the deepep_low_latency and "
                 f"deepep_high_throughput all2all backend. {a2a_backend} is not "
                 "supported. To fix use --all2all-backend=deepep_low_latency or "
@@ -408,6 +408,7 @@ class EngineArgs:
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     all2all_backend: str | None = ParallelConfig.all2all_backend
     enable_dbo: bool = ParallelConfig.enable_dbo
+    ubatch_size: int = ParallelConfig.ubatch_size
     dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
     dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
     disable_nccl_for_dp_synchronization: bool = (

@@ -841,6 +842,10 @@ class EngineArgs:
             "--all2all-backend", **parallel_kwargs["all2all_backend"]
         )
         parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
+        parallel_group.add_argument(
+            "--ubatch-size",
+            **parallel_kwargs["ubatch_size"],
+        )
         parallel_group.add_argument(
             "--dbo-decode-token-threshold",
             **parallel_kwargs["dbo_decode_token_threshold"],

@@ -1557,6 +1562,7 @@ class EngineArgs:
             enable_expert_parallel=self.enable_expert_parallel,
             all2all_backend=self.all2all_backend,
             enable_dbo=self.enable_dbo,
+            ubatch_size=self.ubatch_size,
             dbo_decode_token_threshold=self.dbo_decode_token_threshold,
             dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
             disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
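Assuming the flag is wired through as shown, enabling XBO from the command line would look roughly like this (the model and the ubatch count are placeholders; the deepep backend requirement comes from the VllmConfig assertion above):

# hypothetical invocation
vllm serve <model> \
    --enable-expert-parallel \
    --all2all-backend deepep_low_latency \
    --ubatch-size 4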
@@ -201,10 +201,11 @@ def _make_metadata_with_slice(
     )
     # NOTE: last token can be outside of the last request if we have CG padding.
 
-    # If the "middle" request has tokens in both ubatches, we have to split it.
-    # If ubatch_slice is the first ubatch then we will be splitting the last
-    # request. If it's the second microbatch, then we will be splitting the
-    # first request
+    # If the request is split across ubatches, we have to adjust the metadata.
+    # splits_first_request: The first request in this slice is the continuation of
+    # a request that started in a previous slice.
+    # splits_last_request: The last request in this slice continues into the
+    # next slice.
     splits_first_request = first_tok > start_locs[first_req]
     splits_last_request = last_tok < start_locs[last_req + 1] - 1
 

@@ -225,7 +226,10 @@ def _make_metadata_with_slice(
     seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
 
     if splits_last_request:
-        tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop
+        # NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
+        # the tokens skipped because query_start_loc_cpu might have been modified
+        # if splits_first_request is True.
+        tokens_skipped = start_locs[last_req + 1] - token_slice.stop
         query_start_loc[-1] -= tokens_skipped
         query_start_loc_cpu[-1] -= tokens_skipped
 
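A small worked example (made-up numbers) of why the skipped-token count now uses the original cumulative starts (start_locs) instead of query_start_loc_cpu, which may already have been shifted when the first request of the slice was itself split:

import numpy as np

# Three requests with 4, 6, and 5 tokens; cumulative starts = [0, 4, 10, 15].
start_locs = np.array([0, 4, 10, 15])

# Suppose this ubatch covers tokens [4, 12): it splits the last request
# (tokens 10..14) and keeps only tokens 10 and 11 of it.
token_slice = slice(4, 12)
last_req = 2

# Tokens of the last request that fall beyond this ubatch's token slice:
tokens_skipped = start_locs[last_req + 1] - token_slice.stop
assert tokens_skipped == 3  # tokens 12, 13, 14 spill into the next ubatch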
@@ -11,7 +11,7 @@ from vllm.distributed.parallel_state import get_dp_group
 from vllm.logger import init_logger
 from vllm.v1.worker.ubatch_utils import (
     check_ubatch_thresholds,
-    is_second_ubatch_empty,
+    is_last_ubatch_empty,
 )
 
 logger = init_logger(__name__)

@@ -56,7 +56,7 @@ def _run_ar(
     return tensor
 
 
-def _post_process_ubatch(tensor: torch.Tensor) -> bool:
+def _post_process_ubatch(tensor: torch.Tensor, num_ubatches: int) -> bool:
     orig_num_tokens_tensor = tensor[0, :]
     padded_num_tokens_tensor = tensor[1, :]
 

@@ -68,7 +68,7 @@ def _post_process_ubatch(tensor: torch.Tensor) -> bool:
     # there are no "empty" second ubatches
     orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
     padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
-    if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
+    if is_last_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens, num_ubatches):
         logger.debug(
             "Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
         )

@@ -146,7 +146,7 @@ def _synchronize_dp_ranks(
     assert should_attempt_dp_padding == should_dp_pad
 
     # Check conditions for microbatching
-    should_ubatch = _post_process_ubatch(tensor)
+    should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches)
 
     if should_ubatch and not should_dp_pad:
         logger.debug_once(
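To see what the generalized abort condition guards against, the formula from is_last_ubatch_empty can be checked in isolation (same arithmetic, recomputed here standalone for illustration):

def last_ubatch_would_be_empty(orig_tokens: int, padded_tokens: int, n: int) -> bool:
    # The first n-1 ubatches each receive padded_tokens // n tokens; if that
    # already covers every real token, the last ubatch would hold only padding.
    return (padded_tokens // n) * (n - 1) >= orig_tokens


# 40 real tokens padded to 64, split 2 ways: 32 < 40, so ubatching proceeds.
assert not last_ubatch_would_be_empty(40, 64, 2)
# Same tokens split 4 ways: the first three ubatches cover 48 >= 40, so abort.
assert last_ubatch_would_be_empty(40, 64, 4)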
@@ -2987,7 +2987,7 @@ class GPUModelRunner(
 
         cascade_attn_prefix_lens = None
         # Disable cascade attention when using microbatching (DBO)
-        if self.cascade_attn_enabled and not self.parallel_config.enable_dbo:
+        if self.cascade_attn_enabled and not self.parallel_config.use_ubatching:
             # Pre-compute cascade attention prefix lengths
             cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
                 num_scheduled_tokens_np,

@@ -3028,6 +3028,13 @@ class GPUModelRunner(
             num_scheduled_tokens_np,
             num_tokens_padded,
             num_reqs_padded,
+            self.parallel_config.num_ubatches,
+        )
+
+        logger.debug(
+            "ubatch_slices: %s, ubatch_slices_padded: %s",
+            ubatch_slices,
+            ubatch_slices_padded,
         )
 
         pad_attn = cudagraph_mode == CUDAGraphMode.FULL

@@ -3710,11 +3717,14 @@ class GPUModelRunner(
         # wrap the model with full cudagraph wrapper if needed.
         cudagraph_mode = self.compilation_config.cudagraph_mode
         assert cudagraph_mode is not None
-        if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo:
+        if (
+            cudagraph_mode.has_full_cudagraphs()
+            and not self.parallel_config.use_ubatching
+        ):
             self.model = CUDAGraphWrapper(
                 self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
             )
-        elif self.parallel_config.enable_dbo:
+        elif self.parallel_config.use_ubatching:
             if cudagraph_mode.has_full_cudagraphs():
                 self.model = UBatchWrapper(
                     self.model, self.vllm_config, CUDAGraphMode.FULL, self.device

@@ -4095,7 +4105,16 @@ class GPUModelRunner(
             batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
         )
         ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
-            should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded
+            should_ubatch,
+            num_scheduled_tokens,
+            num_tokens_padded,
+            num_reqs_padded,
+            self.vllm_config.parallel_config.num_ubatches,
+        )
+        logger.debug(
+            "ubatch_slices: %s, ubatch_slices_padded: %s",
+            ubatch_slices,
+            ubatch_slices_padded,
         )
 
         attn_metadata: PerLayerAttnMetadata | None = None

@@ -4644,7 +4663,7 @@ class GPUModelRunner(
         # is above the threshold. Otherwise we just capture a non-ubatched
         # version of the graph
         allow_microbatching = (
-            self.parallel_config.enable_dbo
+            self.parallel_config.use_ubatching
            and cudagraph_runtime_mode == CUDAGraphMode.FULL
            and uniform_decode
            and check_ubatch_thresholds(

@@ -4779,8 +4798,8 @@ class GPUModelRunner(
                 if kv_cache_group_id < len(kernel_block_sizes)
                 else None,
                 num_metadata_builders=1
-                if not self.parallel_config.enable_dbo
-                else 2,
+                if not self.parallel_config.use_ubatching
+                else self.parallel_config.num_ubatches,
             )
             # Calculate reorder batch threshold (if needed)
             # Note (tdoublep): do this *after* constructing builders,
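The last hunk above replaces the hard-coded "else 2" with the configured ubatch count; in isolation the rule for attention metadata builders is simply (a sketch, not the runner code):

def metadata_builders_needed(use_ubatching: bool, num_ubatches: int) -> int:
    # One builder on the non-ubatched path, otherwise one per micro-batch.
    return 1 if not use_ubatching else num_ubatches


assert metadata_builders_needed(False, 0) == 1
assert metadata_builders_needed(True, 2) == 2  # legacy DBO
assert metadata_builders_needed(True, 4) == 4  # XBO with four ubatches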
@@ -103,8 +103,10 @@ class UBatchWrapper:
         self.vllm_config = vllm_config
         self.compilation_config = vllm_config.compilation_config
         self.comm_stream = torch.cuda.Stream(device=device)
-        # Two ubatch threads plus the main thread
-        self.ready_barrier = threading.Barrier(3)
+        # Ubatch threads plus the main thread
+        self.ready_barrier = threading.Barrier(
+            self.vllm_config.parallel_config.num_ubatches + 1
+        )
 
         self.cudagraphs: dict[int, CUDAGraphMetaData] = {}
 

@@ -309,7 +311,7 @@ class UBatchWrapper:
                 create_forward_context(
                     attn_metadata[i] if attn_metadata is not None else None,
                     self.vllm_config,
-                    dp_metadata=dp_metadata,
+                    dp_metadata=dp_metadata[i],
                     batch_descriptor=batch_descriptor,
                     cudagraph_runtime_mode=cudagraph_runtime_mode,
                 )

@@ -417,18 +419,19 @@ class UBatchWrapper:
 
         # We shouldn't be here unless we are running with multiple DP ranks
         assert dp_metadata is not None
-        num_tokens_per_ubatch = (
-            ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
-        )
-        dp_size = self.vllm_config.parallel_config.data_parallel_size
-        ubatch_num_tokens_across_dp = torch.tensor(
-            [num_tokens_per_ubatch] * dp_size, device="cpu", dtype=torch.int32
-        )
-        ubatch_dp_metadata = DPMetadata.make(
-            self.vllm_config.parallel_config,
-            num_tokens_per_ubatch,
-            ubatch_num_tokens_across_dp,
-        )
+        ubatch_dp_metadata = []
+        for ubatch_slice in ubatch_slices:
+            dp_size = self.vllm_config.parallel_config.data_parallel_size
+            ubatch_num_tokens_across_dp = torch.tensor(
+                [ubatch_slice.num_tokens] * dp_size, device="cpu", dtype=torch.int32
+            )
+            ubatch_dp_metadata.append(
+                DPMetadata.make(
+                    self.vllm_config.parallel_config,
+                    ubatch_slice.num_tokens,
+                    ubatch_num_tokens_across_dp,
+                )
+            )
 
         if (
             num_tokens not in self.cudagraphs

@@ -464,7 +467,7 @@ class UBatchWrapper:
             intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
             compute_stream=compute_stream,
-            dp_metadata=dp_metadata,
+            dp_metadata=ubatch_dp_metadata,
             batch_descriptor=batch_descriptor,
             cudagraph_runtime_mode=CUDAGraphMode.NONE,
         )
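The barrier resize matters because every ubatch worker thread plus the coordinating main thread must rendezvous before the first wait. A minimal standalone sketch of that pattern (not the UBatchWrapper itself, with the GPU work elided):

import threading

NUM_UBATCHES = 3  # illustrative; the real value comes from parallel_config

# One party per ubatch thread plus one for the main thread, mirroring
# threading.Barrier(num_ubatches + 1) in the hunk above.
ready_barrier = threading.Barrier(NUM_UBATCHES + 1)


def ubatch_worker(idx: int) -> None:
    # ... per-ubatch context setup would happen here ...
    ready_barrier.wait()  # signal readiness and wait for all other parties
    # ... the micro-batch forward pass would run here ...


threads = [
    threading.Thread(target=ubatch_worker, args=(i,)) for i in range(NUM_UBATCHES)
]
for t in threads:
    t.start()
ready_barrier.wait()  # the main thread joins the rendezvous, releasing all workers
for t in threads:
    t.join()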
@@ -27,14 +27,16 @@ class UBatchSlice:
 UBatchSlices: TypeAlias = list[UBatchSlice]
 
 
-def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool:
-    return (padded_num_tokens // 2) >= orig_num_tokens
+def is_last_ubatch_empty(
+    orig_num_tokens: int, padded_num_tokens: int, num_ubatches: int
+) -> bool:
+    return (padded_num_tokens // num_ubatches) * (num_ubatches - 1) >= orig_num_tokens
 
 
 def check_ubatch_thresholds(
     config: ParallelConfig, num_tokens: int, uniform_decode: bool
 ) -> bool:
-    if not config.enable_dbo:
+    if not config.use_ubatching:
         return False
     if uniform_decode:
         return num_tokens >= config.dbo_decode_token_threshold

@@ -42,21 +44,17 @@ def check_ubatch_thresholds(
     return num_tokens >= config.dbo_prefill_token_threshold
 
 
-# This just pads the second ubatch slice out to the total number of tokens
+# This pads the last ubatch slice out to the total number of tokens
 # (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
 def _pad_out_ubatch_slices(
     ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
 ) -> UBatchSlices:
-    # TODO(lucas): handle empty second ubatch
-    padded_second_request_slice = slice(
-        ubatch_slices[1].request_slice.start, num_reqs_padded
-    )
-    padded_second_token_slice = slice(
-        ubatch_slices[1].token_slice.start, num_total_tokens
-    )
-    return [
-        ubatch_slices[0],
-        UBatchSlice(padded_second_request_slice, padded_second_token_slice),
-    ]
+    last_slice = ubatch_slices[-1]
+    padded_last_request_slice = slice(last_slice.request_slice.start, num_reqs_padded)
+    padded_last_token_slice = slice(last_slice.token_slice.start, num_total_tokens)
+
+    return ubatch_slices[:-1] + [
+        UBatchSlice(padded_last_request_slice, padded_last_token_slice)
+    ]
 
 

@@ -65,40 +63,45 @@ def maybe_create_ubatch_slices(
     num_scheduled_tokens: np.ndarray,
     num_tokens_padded: int,
     num_reqs_padded: int,
-    split_point: int | None = None,
+    num_ubatches: int,
+    split_point: list[int] | int | None = None,
 ) -> tuple[UBatchSlices | None, UBatchSlices | None]:
     if not should_ubatch:
         return None, None
 
     if split_point is None:
-        split_point = int(num_tokens_padded) // 2
+        split_point = int(num_tokens_padded) // num_ubatches
+
+    token_split_points = [split_point * i for i in range(1, num_ubatches)]
 
     # TODO(lucas): Refactor the gpu_model_runner.py so we can pass
     # in cu_num_tokens directly (i.e. query_start_loc)
     cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
     np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
 
-    first_ubatch_token_slice = slice(0, split_point)
-    second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])
-
-    # Determine request slices using exclusive stop semantics
-    # First ubatch includes requests whose tokens overlap [0, split_point)
-    first_ubatch_req_stop = int(
-        np.searchsorted(cu_num_tokens, split_point, side="left")
-    )
-    first_ubatch_req_slice = slice(0, first_ubatch_req_stop)
-
-    # Second ubatch starts at the request that contains the split_point
-    # or the request starting exactly at split_point (if on boundary)
-    second_ubatch_req_start = int(
-        np.searchsorted(cu_num_tokens, split_point, side="right") - 1
-    )
-    second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)
-
-    ubatch_slices = [
-        UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
-        UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
-    ]
+    ubatch_slices = []
+    start_token = 0
+
+    # Add the end point to the split points to make iteration easier
+    all_points = token_split_points + [cu_num_tokens[-1]]
+
+    for end_token in all_points:
+        token_slice = slice(start_token, end_token)
+
+        # Determine request slices using exclusive stop semantics
+        # Ubatch includes requests whose tokens overlap [start_token, end_token)
+
+        # Start at the request that contains the start_token
+        # or the request starting exactly at start_token (if on boundary)
+        req_start = int(np.searchsorted(cu_num_tokens, start_token, side="right") - 1)
+
+        # Stop at the request that starts at or after end_token
+        req_stop = int(np.searchsorted(cu_num_tokens, end_token, side="left"))
+
+        req_slice = slice(req_start, req_stop)
+        ubatch_slices.append(UBatchSlice(req_slice, token_slice))
+
+        start_token = end_token
 
     ubatch_slices_padded = _pad_out_ubatch_slices(
         ubatch_slices, num_tokens_padded, num_reqs_padded
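To make the new slicing loop concrete, here is a self-contained sketch of the same idea (even token split points, request boundaries found with searchsorted). It mirrors, but is not, the function above; the request sizes are made up.

import numpy as np


def split_into_ubatches(num_scheduled_tokens: np.ndarray, num_ubatches: int):
    """Return one (request_slice, token_slice) pair per micro-batch."""
    cu = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
    np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu[1:])
    total = int(cu[-1])
    split = total // num_ubatches
    points = [split * i for i in range(1, num_ubatches)] + [total]

    slices, start = [], 0
    for end in points:
        # Requests overlapping [start, end): begin at the request containing
        # `start`, stop at the first request that begins at or after `end`.
        req_start = int(np.searchsorted(cu, start, side="right") - 1)
        req_stop = int(np.searchsorted(cu, end, side="left"))
        slices.append((slice(req_start, req_stop), slice(start, end)))
        start = end
    return slices


# Four requests of 8 tokens each, split into 4 ubatches of 8 tokens:
# each ubatch maps onto exactly one request.
print(split_into_ubatches(np.array([8, 8, 8, 8]), 4))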
@@ -7,10 +7,15 @@ import torch
 
 from vllm import forward_context
 from vllm.forward_context import ForwardContext
+from vllm.logger import init_logger
 from vllm.utils.torch_utils import current_stream
 
+logger = init_logger(__name__)
+
 _THREAD_ID_TO_CONTEXT: dict = {}
-_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [None, None]
+# The number of microbatches defaults to 2; make_ubatch_contexts overrides it.
+_NUM_UBATCHES: int = 2
+_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = []
 
 
 class UBatchContext:

@@ -48,6 +53,7 @@ class UBatchContext:
         global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
         _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
         _CURRENT_CONTEXTS[self.id] = self
+        # _NUM_UBATCHES is set in make_ubatch_contexts
         self.ready_barrier.wait()
 
         self.cpu_wait_event.wait()

@@ -181,7 +187,7 @@ dbo_switch_to_compute_sync = _register_ubatch_function(
 def dbo_register_recv_hook(recv_hook):
     if len(_THREAD_ID_TO_CONTEXT) > 0:
         ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
-        next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2]
+        next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % _NUM_UBATCHES]
         next_ctx.recv_hook = recv_hook
 
 

@@ -202,7 +208,14 @@ def make_ubatch_contexts(
     ready_barrier: threading.Barrier,
     schedule: str = "default",
 ) -> list[UBatchContext]:
-    assert num_micro_batches == 2, "only been tested with 2 micro-batches"
+    global _NUM_UBATCHES, _CURRENT_CONTEXTS
+    assert num_micro_batches > 1, "num_micro_batches must be greater than 1"
+
+    _NUM_UBATCHES = num_micro_batches
+    # Ensure the global context list is large enough
+    if len(_CURRENT_CONTEXTS) < num_micro_batches:
+        _CURRENT_CONTEXTS.extend([None] * (num_micro_batches - len(_CURRENT_CONTEXTS)))
+
     """
     Create a context manager for micro-batching synchronization.
     """

@@ -210,8 +223,6 @@ def make_ubatch_contexts(
     gpu_comm_done_events = [torch.Event() for _ in range(num_micro_batches)]
     gpu_compute_done_events = [torch.Event() for _ in range(num_micro_batches)]
 
-    assert len(forward_contexts) == 2
-
     ctxs = []
     for i in range(num_micro_batches):
         ctx = UBatchContext(
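The % _NUM_UBATCHES change turns the old two-context ping-pong into a ring: each ubatch thread hands its receive hook to the next context in cyclic order. A toy illustration of the indexing (contexts stubbed out as dicts):

NUM_UBATCHES = 3
contexts = [{"id": i, "recv_hook": None} for i in range(NUM_UBATCHES)]


def register_recv_hook(current_idx: int, hook) -> None:
    # Same indexing as dbo_register_recv_hook above: give the hook to the
    # next micro-batch context in the ring.
    contexts[(current_idx + 1) % NUM_UBATCHES]["recv_hook"] = hook


register_recv_hook(2, lambda: None)
assert contexts[0]["recv_hook"] is not None  # wrapped around to ubatch 0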