[Core] Simplify the Dp padding/should ubatch coordination logic (#25768)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent c50901f3b9
commit 2111b4643c
@@ -13,7 +13,7 @@ from vllm.v1.attention.backends.utils import (
split_attn_metadata,
split_decodes_and_prefills,
)
from vllm.v1.worker.ubatch_splitting import create_ubatch_slices
from vllm.v1.worker.ubatch_utils import create_ubatch_slices


@pytest.fixture
@@ -152,6 +152,10 @@ class ParallelConfig:
threshold, microbatching will be used. Otherwise, the request will be
processed in a single batch."""

disable_nccl_for_dp_synchronization: bool = False
"""Forces the dp synchronization logic in vllm/v1/worker/dp_utils.py
to use Gloo instead of NCCL for its all reduce"""

ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
@@ -365,6 +365,9 @@ class EngineArgs:
enable_dbo: bool = ParallelConfig.enable_dbo
dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
disable_nccl_for_dp_synchronization: bool = (
ParallelConfig.disable_nccl_for_dp_synchronization
)
eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
enable_eplb: bool = ParallelConfig.enable_eplb
expert_placement_strategy: ExpertPlacementStrategy = (
@@ -760,6 +763,10 @@ class EngineArgs:
"--dbo-prefill-token-threshold",
**parallel_kwargs["dbo_prefill_token_threshold"],
)
parallel_group.add_argument(
"--disable-nccl-for-dp-synchronization",
**parallel_kwargs["disable_nccl_for_dp_synchronization"],
)
parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"])
parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"])
parallel_group.add_argument(
@@ -1437,6 +1444,7 @@ class EngineArgs:
enable_dbo=self.enable_dbo,
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
enable_eplb=self.enable_eplb,
eplb_config=self.eplb_config,
expert_placement_strategy=self.expert_placement_strategy,
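Editor's note: the three hunks above wire the new disable_nccl_for_dp_synchronization knob from ParallelConfig through EngineArgs and the CLI. A minimal sketch of toggling it programmatically follows; the EngineArgs field names come from the hunks, while the placeholder model, data_parallel_size, and the create_engine_config() call are assumptions for illustration, not part of this diff.

from vllm.engine.arg_utils import EngineArgs

# Hypothetical offline configuration; mirrors passing
# --disable-nccl-for-dp-synchronization on the CLI.
args = EngineArgs(
    model="facebook/opt-125m",                  # placeholder model
    data_parallel_size=2,                       # assumed existing DP knob
    enable_dbo=True,                            # field shown in the hunks above
    disable_nccl_for_dp_synchronization=True,   # new flag added by this commit
)
config = args.create_engine_config()            # assumed EngineArgs helper
assert config.parallel_config.disable_nccl_for_dp_synchronization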
@@ -95,7 +95,6 @@ if TYPE_CHECKING:
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: bool = False
VLLM_DISABLE_PYNCCL: bool = False
VLLM_USE_V1: bool = True
VLLM_ROCM_USE_AITER: bool = False
@@ -830,12 +829,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DISABLED_KERNELS": lambda: []
if "VLLM_DISABLED_KERNELS" not in os.environ
else os.environ["VLLM_DISABLED_KERNELS"].split(","),
# Swaps the all reduce backend that we use to coordinate the DP padding
# information from NCCL to gloo.
"VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION": lambda: (
os.getenv("VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION", "False").lower()
in ("true", "1")
),
# Disable pynccl (using torch.distributed instead)
"VLLM_DISABLE_PYNCCL": lambda: (
os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1")
@@ -8,13 +8,11 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union

import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty
from vllm.v1.worker.ubatch_utils import UBatchSlices

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -87,129 +85,22 @@ class DPMetadata:
# NOTE: local_sizes should only be set by the chunked_sizes context manager
local_sizes: Optional[list[int]] = None

@staticmethod
def num_tokens_across_dp(
num_tokens: int, dp_size: int, dp_rank: int
) -> torch.Tensor:
"""
Gather the num_tokens across all DP ranks and return results in a
CPU tensor of size dp_size.
"""
from vllm.distributed.parallel_state import get_dp_group

device = current_platform.device_type
group = get_dp_group().device_group

# Transfering this tensor from GPU to CPU will introduce a GPU sync
# point that could adversely affect performance of vllm with asynch
# scheduling. This environment variable exists to quickly disable
# this optimization if we run into this case.
if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION:
logger.info_once(
"Using CPU all reduce to syncronize DP padding between ranks."
)
device = "cpu"
group = get_dp_group().cpu_group
num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = num_tokens
num_tokens_tensor = torch.tensor(
num_tokens_across_dp, device=device, dtype=torch.int32
)
dist.all_reduce(num_tokens_tensor, group=group)
return num_tokens_tensor.cpu()

# Get the cumulative tokens across sequence parallel ranks.
# In this case the input to the MoEs will be distributed w.r.t both
# DP and TP rank.
# When sp_size==1, this is just the cummulative num tokens across DP.
def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
num_tokens_across_sp_cpu = (
self.num_tokens_across_dp_cpu - 1 + sp_size
) // sp_size
num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
return torch.cumsum(num_tokens_across_sp_cpu, dim=0)

@staticmethod
def should_ubatch_across_dp(
should_ubatch: bool,
orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int,
dp_size: int,
dp_rank: int,
) -> tuple[bool, Optional[torch.Tensor]]:
"""
1. Decides if each DP rank is going to microbatch. Either all ranks
run with microbatching or none of them do. If this function decides
not to run with microbatching. It will "abort" meaning that no padding
information will be returned to the caller. It will return (False, None)

2. Determines the total number of tokens that each rank will run.
All ranks will be padded out so that the run with the same number
of tokens

Returns: tuple[
should_ubatch: Are all DP ranks going to microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding. Will be
None if should_ubatch if False
]
"""

device = current_platform.device_type
tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0

from vllm.distributed.parallel_state import get_dp_group

dist.all_reduce(tensor, group=get_dp_group().device_group)

result: bool = bool(torch.all(tensor[2] == 1).item())
if not result:
return result, None

orig_num_tokens_tensor = tensor[0, :]
padded_num_tokens_tensor = tensor[1, :]

orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
logger.debug(
"Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
)
return False, None
return result, padded_num_tokens_tensor.cpu()

@staticmethod
def make(
parallel_config: ParallelConfig,
attn_metadata: Any,
num_tokens: int,
num_tokens_across_dp_cpu: Optional[torch.Tensor] = None,
num_tokens_across_dp_cpu: torch.Tensor,
) -> "DPMetadata":
assert num_tokens_across_dp_cpu is not None
assert parallel_config.data_parallel_size > 1
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
if attn_metadata is not None and hasattr(attn_metadata, "num_prefill_tokens"):
# for v0 attention backends
batchsize = (
attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
)
else:
# for v1 attention backends or no attn_metadata
batchsize = num_tokens
batchsize = num_tokens

# If num_tokens_across_dp is None, it will be computed by all_reduce
# Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
assert (
num_tokens_across_dp_cpu is None
or num_tokens_across_dp_cpu[dp_rank] == batchsize
), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
if num_tokens_across_dp_cpu is None:
num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
batchsize, dp_size, dp_rank
)
assert num_tokens_across_dp_cpu[dp_rank] == batchsize, (
f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)

@@ -376,11 +267,9 @@ def set_forward_context(
if vllm_config.parallel_config.data_parallel_size > 1 and (
attn_metadata is not None or num_tokens is not None
):
assert num_tokens_across_dp is not None
dp_metadata = DPMetadata.make(
vllm_config.parallel_config,
attn_metadata,
num_tokens or 0,
num_tokens_across_dp,
vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
)

forward_context = create_forward_context(
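Editor's note: the DPMetadata.num_tokens_across_dp helper removed above (its replacement lives in the new dp_utils.py below) gathers per-rank token counts with a plain sum all-reduce over a vector that each rank fills only at its own index. A single-process sketch of that trick, with made-up token counts and no torch.distributed involved:

import torch

dp_size = 4
per_rank_counts = [7, 12, 3, 9]          # hypothetical local num_tokens on each rank

# What each rank would hold locally before the all_reduce: zeros except its own slot.
contributions = [torch.zeros(dp_size, dtype=torch.int32) for _ in range(dp_size)]
for rank, n in enumerate(per_rank_counts):
    contributions[rank][rank] = n

# A SUM all_reduce makes every rank see the full vector of counts.
all_reduced = torch.stack(contributions).sum(dim=0)
assert all_reduced.tolist() == per_rank_counts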
vllm/v1/worker/dp_utils.py (new file, 177 lines)
@@ -0,0 +1,177 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist

from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.ubatch_utils import (
UBatchSlices,
check_ubatch_thresholds,
create_ubatch_slices,
is_second_ubatch_empty,
)

logger = init_logger(__name__)


def _get_device_and_group(parallel_config: ParallelConfig):
device = current_platform.device_type
group = get_dp_group().device_group

# Transfering this tensor from GPU to CPU will introduce a GPU sync
# point that could adversely affect performance of vllm with asynch
# scheduling. This environment variable exists to quickly disable
# this optimization if we run into this case.
if parallel_config.disable_nccl_for_dp_synchronization:
logger.info_once("Using CPU all reduce to syncronize DP padding between ranks.")
device = "cpu"
group = get_dp_group().cpu_group
return device, group


def _run_ar(
should_ubatch: bool,
orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int,
parallel_config: ParallelConfig,
) -> torch.Tensor:
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
device, group = _get_device_and_group(parallel_config)
tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0
dist.all_reduce(tensor, group=group)
return tensor


def _post_process_ubatch(tensor: torch.Tensor) -> bool:
orig_num_tokens_tensor = tensor[0, :]
padded_num_tokens_tensor = tensor[1, :]

# First determine if we are going to be ubatching.
should_ubatch: bool = bool(torch.all(tensor[2] == 1).item())
if not should_ubatch:
return False
# If the DP ranks are planning to ubatch, make sure that
# there are no "empty" second ubatches
orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
logger.debug(
"Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
)
should_ubatch = False
return should_ubatch


def _synchronize_dp_ranks(
num_tokens_unpadded: int,
num_tokens_padded: int,
should_attempt_ubatching: bool,
parallel_config: ParallelConfig,
) -> tuple[bool, Optional[torch.Tensor]]:
"""
1. Decides if each DP rank is going to microbatch. Either all ranks
run with microbatching or none of them do.

2. Determines the total number of tokens that each rank will run.
All ranks will be padded out so that the run with the same number
of tokens

Returns: tuple[
should_ubatch: Are all DP ranks going to microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding.
]

"""
assert num_tokens_padded >= num_tokens_unpadded

# First we coordinate between the DP ranks via an All Reduce
# to determine the total number of tokens that each rank
# will run and if we are using ubatching or not.
tensor = _run_ar(
should_ubatch=should_attempt_ubatching,
orig_num_tokens_per_ubatch=num_tokens_unpadded,
padded_num_tokens_per_ubatch=num_tokens_padded,
parallel_config=parallel_config,
)

# Ensure that each rank is processing the same nuber of tokens
num_tokens_across_dp = tensor[1, :]
max_num_tokens = int(num_tokens_across_dp.max().item())
num_tokens_after_padding = torch.tensor(
[max_num_tokens] * len(num_tokens_across_dp), device="cpu", dtype=torch.int32
)

should_ubatch = _post_process_ubatch(tensor)

return should_ubatch, num_tokens_after_padding


def coordinate_batch_across_dp(
num_scheduled_tokens_per_request: np.ndarray,
num_tokens_unpadded: int,
num_tokens_padded: int,
parallel_config: ParallelConfig,
allow_microbatching: bool,
uniform_decode: bool,
) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
"""
Coordinates amongst all DP ranks to determine if and how the full batch
should be split into microbatches.

Returns: tuple[
ubatch_slices: if this is set then all DP ranks have agreed to
microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding.
]

"""
if parallel_config.data_parallel_size == 1:
# Early exit.
return None, None

# Check preconditions for microbatching
should_attempt_ubatching = check_ubatch_thresholds(
parallel_config,
num_tokens_unpadded,
uniform_decode=uniform_decode,
)

# If the caller has explicitly disabled microbatching.
if not allow_microbatching:
should_attempt_ubatching = False

(should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks(
num_tokens_unpadded,
num_tokens_padded,
should_attempt_ubatching,
parallel_config,
)

# Don't microbatch unless every other DP worker is also microbatching
if not should_ubatch:
return (None, num_tokens_after_padding)

# This doesn't actually pad the ubatch slices. It just initializes the
# split point to the padded value so that padding can be applied
# to the second ubatch in pad_out_ubatch_slice after attention
# metadata creation
assert num_tokens_after_padding is not None
token_split_point = int(num_tokens_after_padding[0].item()) // 2

ubatch_slices = create_ubatch_slices(
num_scheduled_tokens_per_request, token_split_point
)

return (ubatch_slices, num_tokens_after_padding)
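Editor's note: a single-process sketch of the 3 x dp_size coordination tensor that _run_ar produces and _post_process_ubatch / _synchronize_dp_ranks consume above. The numbers are hypothetical (dp_size = 2); the decision logic mirrors the new file.

import torch

def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool:
    # Mirrors the helper in vllm/v1/worker/ubatch_utils.py (diffed further below).
    return (padded_num_tokens // 2) >= orig_num_tokens

# Row 0: unpadded token counts per rank, row 1: padded counts,
# row 2: whether that rank wants to microbatch.
tensor = torch.tensor([[9, 30],
                       [16, 32],
                       [1, 1]], dtype=torch.int32)

should_ubatch = bool(torch.all(tensor[2] == 1).item())   # both ranks opted in
orig_min = int(tensor[0].min().item())                   # 9
padded_max = int(tensor[1].max().item())                 # 32
if is_second_ubatch_empty(orig_min, padded_max):         # 32 // 2 >= 9 -> abort ubatching
    should_ubatch = False

# Padding is applied either way: every rank runs the max padded count.
num_tokens_after_padding = torch.full((tensor.shape[1],), padded_max, dtype=torch.int32)
print(should_ubatch, num_tokens_after_padding.tolist())  # False [32, 32]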
@@ -41,7 +41,7 @@ from vllm.distributed.parallel_state import (
is_global_first_rank,
prepare_communication_buffer_for_model,
)
from vllm.forward_context import BatchDescriptor, DPMetadata, set_forward_context
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
@@ -131,12 +131,16 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.ubatch_splitting import check_ubatch_thresholds, ubatch_split
from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices
from vllm.v1.worker.ubatch_utils import (
UBatchSlice,
UBatchSlices,
check_ubatch_thresholds,
)
from vllm.v1.worker.utils import is_residual_scattered_for_sp

from .utils import (
@@ -1161,18 +1165,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]

num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
num_tokens_padded = num_tokens_unpadded + self.get_local_padding(
num_tokens_unpadded
)
num_tokens_padded = self._get_num_input_tokens(num_tokens_unpadded)
uniform_decode = (
max_num_scheduled_tokens == self.uniform_decode_query_len
) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
ubatch_slices, num_tokens_after_padding = ubatch_split(
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
num_scheduled_tokens,
num_tokens_unpadded,
num_tokens_padded,
uniform_decode=uniform_decode,
vllm_config=self.vllm_config,
self.parallel_config,
True,
uniform_decode,
)

self.seq_lens.np[:num_reqs] = (
@@ -1405,7 +1408,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
spec_decode_common_attn_metadata,
max_num_scheduled_tokens,
ubatch_slices,
num_tokens_after_padding,
num_tokens_across_dp,
use_cascade_attn,
)

@@ -1986,65 +1989,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
log_stats=self.parallel_config.eplb_config.log_balancedness,
)

def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
"""
Determines the total number of tokens that each rank will run.
All ranks will be padded out so that they run with the same number
of tokens

Returns: tuple[
num_pad_tokens: The number of tokens that will be added to the batch
num_tokens_after_padding: A tensor containing the total number of
tokens for each DP rank including padding.
]
"""
dp_size = self.vllm_config.parallel_config.data_parallel_size
dp_rank = self.vllm_config.parallel_config.data_parallel_rank

# For DP: Don't pad when setting enforce_eager.
# This lets us set enforce_eager on the prefiller in a P/D setup and
# still use CUDA graphs (enabled by this padding) on the decoder.
#
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.

if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
# Early exit.
return 0, None

num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
num_tokens, dp_size, dp_rank
)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
num_tokens_after_padding = torch.tensor(
[max_tokens_across_dp_cpu] * dp_size, device="cpu", dtype=torch.int32
)
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding

def get_local_padding(self, num_tokens_unpadded: int) -> int:
num_tokens_padded = num_tokens_unpadded

if (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]
):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded)
else:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if (
self.vllm_config.compilation_config.pass_config.enable_sequence_parallelism
and tp_size > 1
):
num_tokens_padded = round_up(num_tokens_unpadded, tp_size)

num_pad_tokens = num_tokens_padded - num_tokens_unpadded
return num_pad_tokens

# This is where the second ubatch is adjusted to account for the padding.
# Should be called after attention metadata creation. This just pads
# the second ubatch slice out to the total number of tokens
@@ -2127,13 +2071,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _preprocess(
self,
scheduler_output: "SchedulerOutput",
num_input_tokens: int, # Padded
intermediate_tensors: Optional[IntermediateTensors] = None,
ubatch_slices: Optional[UBatchSlices] = None,
num_tokens_after_padding: Optional[torch.Tensor] = None,
) -> tuple[
int,
int,
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
torch.Tensor,
@@ -2141,14 +2082,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dict[str, Any],
]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if ubatch_slices:
assert num_tokens_after_padding is not None
num_input_tokens = int(num_tokens_after_padding[0].item() * 2)
self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
elif ubatch_slices is None:
num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens)
num_pad, num_tokens_after_padding = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad

# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
@@ -2235,8 +2168,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

return (
num_scheduled_tokens,
num_input_tokens,
num_tokens_after_padding,
input_ids,
inputs_embeds,
positions,
@@ -2506,24 +2437,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
spec_decode_common_attn_metadata,
max_query_len,
ubatch_slices,
num_tokens_after_padding,
num_tokens_across_dp,
use_cascade_attn,
) = self._prepare_inputs(scheduler_output)

if ubatch_slices:
assert num_tokens_across_dp is not None
num_input_tokens = int(num_tokens_across_dp[0].item())
self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
elif num_tokens_across_dp is not None:
num_input_tokens = int(num_tokens_across_dp[0].item())
else:
num_input_tokens = self._get_num_input_tokens(
scheduler_output.total_num_scheduled_tokens
)

(
num_scheduled_tokens,
num_input_tokens,
num_tokens_across_dp,
input_ids,
inputs_embeds,
positions,
intermediate_tensors,
model_kwargs,
) = self._preprocess(
scheduler_output,
intermediate_tensors,
ubatch_slices,
num_tokens_after_padding,
scheduler_output, num_input_tokens, intermediate_tensors
)

uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
@@ -2548,11 +2485,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
):
cudagraph_runtime_mode = CUDAGraphMode.NONE

# This is currently to get around the assert in the DPMetadata
# where it wants `num_tokens_across_dp` to align with `num_tokens`
if ubatch_slices is not None:
num_input_tokens = ubatch_slices[0].num_tokens

# Run the model.
# Use persistent buffers for CUDA graphs.
with (
@@ -3329,36 +3261,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
total_num_scheduled_tokens = int(num_scheduled_tokens.sum())

ubatch_slices = None
num_tokens_after_padding = None

# We currently only microbatch if the number of tokens is
# over a certain threshold.
if self.parallel_config.enable_dbo and allow_microbatching:
ubatch_slices, ubatch_num_tokens_after_padding = ubatch_split(
num_scheduled_tokens,
total_num_scheduled_tokens,
total_num_scheduled_tokens,
uniform_decode=uniform_decode,
vllm_config=self.vllm_config,
)
# Currently when DBO is enabled `ubatch_split` returns
# the num_tokens_after_padding for a single ubatch, but we have 2
# TODO(sage,lucas): this is cruft that should be addressed in the
# padding refactor.
if ubatch_num_tokens_after_padding is not None:
num_tokens_after_padding = ubatch_num_tokens_after_padding * 2

# If we failed to microbatch, currently need to resynchronize
# TODO(lucas,sage): we should be able to avoid this second sync by
# refactoring `get_dp_padding_ubatch` and `get_dp_padding` into
# a single `coordinate_batch_across_dp` function.
if num_tokens_after_padding is None:
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens_after_padding = num_tokens + num_pad
else:
num_tokens_across_dp = num_tokens_after_padding
num_tokens_after_padding = int(num_tokens_after_padding[0].item())
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
num_scheduled_tokens,
total_num_scheduled_tokens,
total_num_scheduled_tokens,
self.vllm_config.parallel_config,
allow_microbatching,
uniform_decode,
)
num_tokens_after_padding = num_tokens
if num_tokens_across_dp is not None:
num_tokens_after_padding = int(num_tokens_across_dp[0])

attn_metadata: Optional[PerLayerAttnMetadata] = None
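Editor's note: the execute_model hunks above now derive num_input_tokens from the coordinated num_tokens_across_dp tensor rather than from a separate get_dp_padding call. A toy walk-through of that branch structure with hypothetical numbers (DP size 2, both ranks padded to 32 tokens); the fallback value is a stand-in, not the real _get_num_input_tokens:

import torch

num_tokens_across_dp = torch.tensor([32, 32], dtype=torch.int32)
ubatch_slices = None                          # pretend the ranks declined to microbatch

if ubatch_slices is not None:
    # Microbatching: pad_out_ubatch_slice stretches the second ubatch to this total.
    num_input_tokens = int(num_tokens_across_dp[0].item())
elif num_tokens_across_dp is not None:
    # DP without microbatching: every rank runs the shared padded total.
    num_input_tokens = int(num_tokens_across_dp[0].item())
else:
    # Single-rank / no coordination: fall back to the locally padded size.
    num_input_tokens = 24                     # stand-in for _get_num_input_tokens(...)

assert num_input_tokens == 32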
@@ -13,6 +13,7 @@ from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import (
DPMetadata,
create_forward_context,
get_forward_context,
override_forward_context,
@@ -409,6 +410,18 @@ class UBatchWrapper:

# We shouldn't be here unless we are running with multiple DP ranks
assert dp_metadata is not None
num_tokens_per_ubatch = (
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
)
dp_size = self.vllm_config.parallel_config.data_parallel_size
ubatch_num_tokens_across_dp = torch.tensor(
[num_tokens_per_ubatch] * dp_size, device="cpu", dtype=torch.int32
)
ubatch_dp_metadata = DPMetadata.make(
self.vllm_config.parallel_config,
num_tokens_per_ubatch,
ubatch_num_tokens_across_dp,
)

if (
num_tokens not in self.cudagraphs
@@ -422,7 +435,7 @@ class UBatchWrapper:
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
compute_stream=compute_stream,
dp_metadata=dp_metadata,
dp_metadata=ubatch_dp_metadata,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
)
vllm/v1/worker/ubatch_splitting.py (deleted, 207 lines)
@@ -1,207 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import numpy as np
import torch

from vllm.config import ParallelConfig, VllmConfig
from vllm.forward_context import DPMetadata
from vllm.logger import init_logger
from vllm.utils import round_up
from vllm.v1.worker.ubatch_utils import (
UBatchSlice,
UBatchSlices,
is_second_ubatch_empty,
)

logger = init_logger(__name__)


def should_ubatch_with_num_tokens(
should_ubatch: bool,
orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int,
vllm_config: VllmConfig,
) -> tuple[bool, Optional[torch.Tensor]]:
dp_size = vllm_config.parallel_config.data_parallel_size
dp_rank = vllm_config.parallel_config.data_parallel_rank
return DPMetadata.should_ubatch_across_dp(
should_ubatch,
orig_num_tokens_per_ubatch,
padded_num_tokens_per_ubatch,
dp_size,
dp_rank,
)


def check_ubatch_thresholds(
config: ParallelConfig, num_tokens: int, uniform_decode: bool
) -> bool:
if not config.enable_dbo:
return False
if uniform_decode:
return num_tokens >= config.dbo_decode_token_threshold
else:
return num_tokens >= config.dbo_prefill_token_threshold


def get_dp_padding_ubatch(
num_tokens_unpadded: int,
num_tokens_padded: int,
should_attempt_ubatching: bool,
vllm_config: VllmConfig,
) -> tuple[bool, Optional[torch.Tensor]]:
"""
1. Decides if each DP rank is going to microbatch. Either all ranks
run with microbatching or none of them do. If this function decides
not to run with microbatching. It will "abort" meaning that no padding
information will be returned to the caller. It will return (False, None)

2. Determines the total number of tokens that each rank will run.
All ranks will be padded out so that the run with the same number
of tokens

Returns: tuple[
should_ubatch: Are all DP ranks going to microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding. Will be
None if should_ubatch if False
]

"""
assert num_tokens_padded >= num_tokens_unpadded
dp_size = vllm_config.parallel_config.data_parallel_size
if dp_size == 1:
# Early exit.
return False, None

# If this DP rank doesn't want to attempt microbatching
if not should_attempt_ubatching:
(should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
False, 0, 0, vllm_config
)
assert should_ubatch is False
assert num_tokens_across_dp is None
return should_ubatch, num_tokens_across_dp

# Round up to the next multiple of two for even divisibility
num_tokens_padded = round_up(num_tokens_padded, 2)
num_tokens_per_ubatch = num_tokens_padded // 2
should_ubatch = True

# Sanity Check that the existing padding isn't giving us an empty second
# ubatch. Abort if so
if is_second_ubatch_empty(num_tokens_unpadded, num_tokens_padded):
logger.debug(
"Empty second µbatch detected: unpadded tokens: %s, padded tokens: %s",
num_tokens_unpadded,
num_tokens_padded,
)
should_ubatch = False

# Note that we compute the number of padded tokens per ubatch
(should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch, vllm_config
)
if not should_ubatch:
assert num_tokens_across_dp is None
return should_ubatch, num_tokens_across_dp

assert num_tokens_across_dp is not None

max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item())
num_tokens_after_padding = torch.tensor(
[max_tokens_across_dp_cpu] * dp_size, device="cpu", dtype=torch.int32
)
return should_ubatch, num_tokens_after_padding


def create_ubatch_slices(
num_scheduled_tokens: np.ndarray, split_point: int
) -> UBatchSlices:
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
# in cu_num_tokens directly (i.e. query_start_loc)
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])

first_ubatch_token_slice = slice(0, split_point)
second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])

# Determine request slices using exclusive stop semantics
# First ubatch includes requests whose tokens overlap [0, split_point)
first_ubatch_req_stop = int(
np.searchsorted(cu_num_tokens, split_point, side="left")
)
first_ubatch_req_slice = slice(0, first_ubatch_req_stop)

# Second ubatch starts at the request that contains the split_point
# or the request starting exactly at split_point (if on boundary)
second_ubatch_req_start = int(
np.searchsorted(cu_num_tokens, split_point, side="right") - 1
)
second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)

return [
UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
]


def ubatch_split(
num_scheduled_tokens_per_request: np.ndarray,
num_tokens_unpadded: int,
num_tokens_padded: int,
uniform_decode: bool,
vllm_config: VllmConfig,
) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
"""
Coordinates amongst all DP ranks to determine if and how the full batch
should be split into microbatches.

Returns: tuple[
ubatch_slices: if this is set then all DP ranks have agreed to
microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding. Will be
None if ubatch_slices is None
]

"""
parallel_config = vllm_config.parallel_config
# Don't bother with the should_ubatch handshaking unless microbatching
# is enabled
if not parallel_config.enable_dbo:
return (None, None)

# Check preconditions for microbatching
should_attempt_ubatching = check_ubatch_thresholds(
parallel_config,
num_tokens_unpadded,
uniform_decode=uniform_decode,
)

# Don't microbatch unless every other DP worker is also microbatching
should_ubatch, num_tokens_after_padding = get_dp_padding_ubatch(
num_tokens_unpadded,
num_tokens_padded,
should_attempt_ubatching,
vllm_config,
)

if not should_ubatch:
return (None, None)

# This doesn't actually pad the ubatch slices. It just initializes the
# split point to the padded value so that padding can be applied
# to the second ubatch in pad_out_ubatch_slice after attention
# metadata creation
assert num_tokens_after_padding is not None
token_split_point = int(num_tokens_after_padding[0].item())

ubatch_slices = create_ubatch_slices(
num_scheduled_tokens_per_request, token_split_point
)

return (ubatch_slices, num_tokens_after_padding)
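Editor's note: one semantic shift between the deleted ubatch_split above and the new coordinate_batch_across_dp is worth calling out. The old helper returned per-ubatch padded counts (so the model runner doubled the tensor and used the value directly as the split point), while the new helper returns total padded counts and halves them to get the split point. A quick arithmetic check with hypothetical numbers, not taken from the diff:

# Old path: per-ubatch counts, doubled by the runner, split at the per-ubatch value.
old_per_ubatch_padded = 16
old_total = old_per_ubatch_padded * 2          # 32 tokens reported as num_tokens_after_padding
old_split_point = old_per_ubatch_padded        # 16

# New path: total counts, split point is half of the padded total.
new_total_padded = 32
new_split_point = new_total_padded // 2        # 16

assert (old_total, old_split_point) == (new_total_padded, new_split_point)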
@@ -2,8 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import numpy as np
from typing_extensions import TypeAlias

from vllm.config import ParallelConfig


@dataclass
class UBatchSlice:
@@ -24,7 +27,47 @@ class UBatchSlice:
UBatchSlices: TypeAlias = list[UBatchSlice]


def is_second_ubatch_empty(
orig_num_tokens_per_ubatch: int, padded_num_tokens_per_ubatch: int
def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool:
return (padded_num_tokens // 2) >= orig_num_tokens


def check_ubatch_thresholds(
config: ParallelConfig, num_tokens: int, uniform_decode: bool
) -> bool:
return padded_num_tokens_per_ubatch >= 2 * orig_num_tokens_per_ubatch
if not config.enable_dbo:
return False
if uniform_decode:
return num_tokens >= config.dbo_decode_token_threshold
else:
return num_tokens >= config.dbo_prefill_token_threshold


def create_ubatch_slices(
num_scheduled_tokens: np.ndarray, split_point: int
) -> UBatchSlices:
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
# in cu_num_tokens directly (i.e. query_start_loc)
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])

first_ubatch_token_slice = slice(0, split_point)
second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])

# Determine request slices using exclusive stop semantics
# First ubatch includes requests whose tokens overlap [0, split_point)
first_ubatch_req_stop = int(
np.searchsorted(cu_num_tokens, split_point, side="left")
)
first_ubatch_req_slice = slice(0, first_ubatch_req_stop)

# Second ubatch starts at the request that contains the split_point
# or the request starting exactly at split_point (if on boundary)
second_ubatch_req_start = int(
np.searchsorted(cu_num_tokens, split_point, side="right") - 1
)
second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)

return [
UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
]
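Editor's note: create_ubatch_slices (now hosted in ubatch_utils.py, per the hunk above) maps a token-level split point back to request-level slices; a request that straddles the split appears in both ubatches. A small self-contained check with made-up per-request token counts, mirroring the searchsorted logic in the diff:

import numpy as np

num_scheduled_tokens = np.array([4, 3, 5], dtype=np.int32)   # three hypothetical requests
split_point = 6                                              # falls inside request 1

cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])   # [0, 4, 7, 12]

first_req_stop = int(np.searchsorted(cu_num_tokens, split_point, side="left"))        # 2
second_req_start = int(np.searchsorted(cu_num_tokens, split_point, side="right") - 1)  # 1

# First ubatch: requests 0-1 cover tokens [0, 6); second ubatch: requests 1-2 cover [6, 12).
assert (first_req_stop, second_req_start) == (2, 1)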