[BugFix] Fix DBO assert assert B_block_table == B_q (#29933)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent 48a5fff66e
commit c8ab988b15
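At a glance, the refactor replaces the `UBatchSlices` return value of `coordinate_batch_across_dp` with a plain `should_ubatch` flag and moves slice construction into a new `maybe_create_ubatch_slices` helper, which returns both the unpadded and the padded slices; padding now extends the second ubatch's request slice as well as its token slice. The snippet below is a self-contained sketch of that padding step using a stand-in `UBatchSlice` dataclass (not vLLM's own class) and made-up numbers, purely to illustrate the behavior the diff introduces.

```python
# Self-contained sketch of the padded-second-ubatch behavior added to
# ubatch_utils.py below. UBatchSlice here is a stand-in dataclass, not the
# vLLM class; the numbers are illustrative only.
from dataclasses import dataclass


@dataclass
class UBatchSlice:
    request_slice: slice
    token_slice: slice

    @property
    def num_tokens(self) -> int:
        return self.token_slice.stop - self.token_slice.start


def pad_out_ubatch_slices(
    ubatch_slices: list[UBatchSlice], num_total_tokens: int, num_reqs_padded: int
) -> list[UBatchSlice]:
    # Extend the second ubatch to the padded token *and* request counts so the
    # two dimensions stay consistent after DP / CUDA-graph padding.
    return [
        ubatch_slices[0],
        UBatchSlice(
            slice(ubatch_slices[1].request_slice.start, num_reqs_padded),
            slice(ubatch_slices[1].token_slice.start, num_total_tokens),
        ),
    ]


# Three requests with 6/6/4 scheduled tokens, split at token 8, then padded to
# 20 tokens and 4 requests (e.g. for a captured CUDA graph size).
ubatch_slices = [
    UBatchSlice(slice(0, 2), slice(0, 8)),   # first ubatch: reqs 0-1, tokens 0-7
    UBatchSlice(slice(1, 3), slice(8, 16)),  # second ubatch: reqs 1-2, tokens 8-15
]
padded = pad_out_ubatch_slices(ubatch_slices, num_total_tokens=20, num_reqs_padded=4)
assert sum(s.num_tokens for s in padded) == 20
```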
@@ -13,7 +13,7 @@ from vllm.v1.attention.backends.utils import (
     split_attn_metadata,
     split_decodes_and_prefills,
 )
-from vllm.v1.worker.ubatch_utils import create_ubatch_slices
+from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices


 @pytest.fixture
@@ -294,8 +294,14 @@ def test_prefill_split_across_ubatches(
     qsl_np = common.query_start_loc_cpu.numpy()
     num_tokens = common.num_actual_tokens

-    ubatch_slices = create_ubatch_slices(num_scheduled_tokens, split_point)
-    assert len(ubatch_slices) == 2
+    ubatch_slices, _ = maybe_create_ubatch_slices(
+        True,
+        num_scheduled_tokens,
+        num_tokens,
+        batch_spec.batch_size,
+        split_point=split_point,
+    )
+    assert ubatch_slices is not None and len(ubatch_slices) == 2

     first_meta = _make_metadata_with_slice(ubatch_slices[0], common)
     second_meta = _make_metadata_with_slice(ubatch_slices[1], common)
@@ -1258,7 +1258,7 @@ class EagleProposer:
         num_tokens_padded: int,
     ) -> tuple[int, torch.Tensor]:
         # TODO(Flechman): support DBO ubatching
-        ubatch_slices, num_toks_across_dp = coordinate_batch_across_dp(
+        should_ubatch, num_toks_across_dp = coordinate_batch_across_dp(
             num_tokens_unpadded=num_tokens_unpadded,
             parallel_config=self.vllm_config.parallel_config,
             allow_microbatching=False,
@@ -1267,7 +1267,7 @@ class EagleProposer:
             uniform_decode=None,
             num_scheduled_tokens_per_request=None,
         )
-        assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE"
+        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"

         num_tokens_dp_padded = num_tokens_padded
         if num_toks_across_dp is not None:
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -9,10 +10,7 @@ from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.logger import init_logger
 from vllm.v1.worker.ubatch_utils import (
-    UBatchSlice,
-    UBatchSlices,
     check_ubatch_thresholds,
-    create_ubatch_slices,
     is_second_ubatch_empty,
 )

@@ -91,20 +89,6 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
     return num_tokens_across_dp.cpu()


-# This just pads the second ubatch slice out to the total number of tokens
-# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
-def _pad_out_ubatch_slice(
-    ubatch_slices: UBatchSlices, num_total_tokens: int
-) -> UBatchSlices:
-    padded_second_token_slice = slice(
-        ubatch_slices[1].token_slice.start, num_total_tokens
-    )
-    ubatch_slices[1] = UBatchSlice(
-        ubatch_slices[1].request_slice, padded_second_token_slice
-    )
-    return ubatch_slices
-
-
 def _synchronize_dp_ranks(
     num_tokens_unpadded: int,
     num_tokens_padded: int,
@@ -175,7 +159,7 @@ def coordinate_batch_across_dp(
     num_tokens_padded: int | None = None,
     uniform_decode: bool | None = None,
     num_scheduled_tokens_per_request: np.ndarray | None = None,
-) -> tuple[UBatchSlices | None, torch.Tensor | None]:
+) -> tuple[bool, torch.Tensor | None]:
     """
     Coordinates amongst all DP ranks to determine if and how the full batch
     should be split into microbatches.
@@ -204,7 +188,7 @@ def coordinate_batch_across_dp(
     """
     if parallel_config.data_parallel_size == 1:
         # Early exit.
-        return None, None
+        return False, None

     # If the caller has explicitly enabled microbatching.
     should_attempt_ubatching = False
@@ -228,23 +212,4 @@ def coordinate_batch_across_dp(
         parallel_config,
     )

-    # Don't microbatch unless every other DP worker is also microbatching
-    if not should_ubatch:
-        return (None, num_tokens_after_padding)
-
-    # This doesn't actually pad the ubatch slices. It just initializes the
-    # split point to the padded value so that padding can be applied
-    # to the second ubatch in pad_out_ubatch_slice after attention
-    # metadata creation
-    assert num_tokens_after_padding is not None
-    num_tokens_padded = int(num_tokens_after_padding[0].item())
-    token_split_point = int(num_tokens_padded) // 2
-
-    assert num_scheduled_tokens_per_request is not None
-    ubatch_slices = create_ubatch_slices(
-        num_scheduled_tokens_per_request, token_split_point
-    )
-    ubatch_slices = _pad_out_ubatch_slice(ubatch_slices, num_tokens_padded)
-    assert sum(s.num_tokens for s in ubatch_slices) == num_tokens_padded
-
-    return (ubatch_slices, num_tokens_after_padding)
+    return (should_ubatch, num_tokens_after_padding)
@@ -153,6 +153,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.ubatch_utils import (
     UBatchSlices,
     check_ubatch_thresholds,
+    maybe_create_ubatch_slices,
 )
 from vllm.v1.worker.utils import is_residual_scattered_for_sp

@@ -2743,7 +2744,7 @@ class GPUModelRunner(
     ) -> tuple[
         CUDAGraphMode,
         BatchDescriptor,
-        UBatchSlices | None,
+        bool,
         torch.Tensor | None,
         CUDAGraphStat | None,
     ]:
@@ -2779,7 +2780,7 @@ class GPUModelRunner(

         # Extra coordination when running data-parallel since we need to coordinate
         # across ranks
-        ubatch_slices, num_tokens_across_dp = None, None
+        should_ubatch, num_tokens_across_dp = False, None
         if self.vllm_config.parallel_config.data_parallel_size > 1:
             # Disable DP padding when running eager to avoid excessive padding when
             # running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
@@ -2789,8 +2790,8 @@ class GPUModelRunner(
                 self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
             )

-            ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
-                num_tokens_unpadded=num_tokens_padded,
+            should_ubatch, num_tokens_across_dp = coordinate_batch_across_dp(
+                num_tokens_unpadded=num_tokens,
                 parallel_config=self.parallel_config,
                 allow_microbatching=allow_microbatching,
                 allow_dp_padding=allow_dp_padding,
@@ -2822,7 +2823,7 @@ class GPUModelRunner(
         return (
             cudagraph_mode,
             batch_descriptor,
-            ubatch_slices,
+            should_ubatch,
             num_tokens_across_dp,
             cudagraph_stats,
         )
@@ -2921,7 +2922,7 @@ class GPUModelRunner(
         (
             cudagraph_mode,
             batch_desc,
-            ubatch_slices,
+            should_ubatch,
             num_tokens_across_dp,
             cudagraph_stats,
         ) = self._determine_batch_execution_and_padding(
@@ -2934,10 +2935,10 @@ class GPUModelRunner(

         logger.debug(
             "Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
-            "ubatch_slices: %s, num_tokens_across_dp: %s",
+            "should_ubatch: %s, num_tokens_across_dp: %s",
             cudagraph_mode,
             batch_desc,
-            ubatch_slices,
+            should_ubatch,
             num_tokens_across_dp,
         )

@@ -2945,9 +2946,17 @@ class GPUModelRunner(
         num_reqs_padded = (
             batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
         )
+        ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
+            should_ubatch,
+            num_scheduled_tokens_np,
+            num_tokens_padded,
+            num_reqs_padded,
+        )

+        pad_attn = cudagraph_mode == CUDAGraphMode.FULL
+
         use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
-        pad_attn = cudagraph_mode == CUDAGraphMode.FULL
+        ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices

         (attn_metadata, spec_decode_common_attn_metadata) = (
             self._build_attention_metadata(
@@ -2956,7 +2965,7 @@ class GPUModelRunner(
                 num_reqs=num_reqs,
                 num_reqs_padded=num_reqs_padded if pad_attn else None,
                 max_query_len=max_num_scheduled_tokens,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_attn,
                 logits_indices=logits_indices,
                 use_spec_decode=use_spec_decode,
                 num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
@@ -2993,7 +3002,7 @@ class GPUModelRunner(
                 num_tokens_across_dp=num_tokens_across_dp,
                 cudagraph_runtime_mode=cudagraph_mode,
                 batch_descriptor=batch_desc,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_padded,
             ),
             record_function_or_nullcontext("gpu_model_runner: forward"),
             self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
@@ -3945,7 +3954,7 @@ class GPUModelRunner(

         num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)

-        _cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp, _ = (
+        _cudagraph_mode, batch_desc, should_ubatch, num_tokens_across_dp, _ = (
             self._determine_batch_execution_and_padding(
                 num_tokens=num_tokens_unpadded,
                 num_reqs=num_reqs,
@@ -3979,6 +3988,9 @@ class GPUModelRunner(
         num_reqs_padded = (
             batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
         )
+        ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
+            should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded
+        )

         attn_metadata: PerLayerAttnMetadata | None = None

@@ -4000,11 +4012,12 @@ class GPUModelRunner(
         self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
         self.query_start_loc.copy_to_gpu()

+        pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
         attn_metadata, _ = self._build_attention_metadata(
             num_tokens=num_tokens_unpadded,
             num_reqs=num_reqs_padded,
             max_query_len=max_query_len,
-            ubatch_slices=ubatch_slices,
+            ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices,
             for_cudagraph_capture=is_graph_capturing,
         )

@@ -4056,11 +4069,11 @@ class GPUModelRunner(
                 num_tokens_padded, None, False
             )

-        if ubatch_slices is not None:
+        if ubatch_slices_padded is not None:
             # Adjust values to reflect a single ubatch.
             # TODO(sage,lucas): this is cruft that should be addressed in
             # the padding refactor.
-            num_tokens_padded = ubatch_slices[0].num_tokens
+            num_tokens_padded = ubatch_slices_padded[0].num_tokens
             if num_tokens_across_dp is not None:
                 num_tokens_across_dp[:] = num_tokens_padded

@@ -4073,7 +4086,7 @@ class GPUModelRunner(
                 num_tokens_across_dp=num_tokens_across_dp,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
                 batch_descriptor=batch_desc,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_padded,
             ),
         ):
             outputs = self.model(
@@ -42,9 +42,37 @@ def check_ubatch_thresholds(
     return num_tokens >= config.dbo_prefill_token_threshold


-def create_ubatch_slices(
-    num_scheduled_tokens: np.ndarray, split_point: int
+# This just pads the second ubatch slice out to the total number of tokens
+# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
+def _pad_out_ubatch_slices(
+    ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
 ) -> UBatchSlices:
+    # TODO(lucas): handle empty second ubatch
+    padded_second_request_slice = slice(
+        ubatch_slices[1].request_slice.start, num_reqs_padded
+    )
+    padded_second_token_slice = slice(
+        ubatch_slices[1].token_slice.start, num_total_tokens
+    )
+    return [
+        ubatch_slices[0],
+        UBatchSlice(padded_second_request_slice, padded_second_token_slice),
+    ]
+
+
+def maybe_create_ubatch_slices(
+    should_ubatch: bool,
+    num_scheduled_tokens: np.ndarray,
+    num_tokens_padded: int,
+    num_reqs_padded: int,
+    split_point: int | None = None,
+) -> tuple[UBatchSlices | None, UBatchSlices | None]:
+    if not should_ubatch:
+        return None, None
+
+    if split_point is None:
+        split_point = int(num_tokens_padded) // 2
+
     # TODO(lucas): Refactor the gpu_model_runner.py so we can pass
     # in cu_num_tokens directly (i.e. query_start_loc)
     cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
@@ -67,7 +95,15 @@ def create_ubatch_slices(
     )
     second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)

-    return [
+    ubatch_slices = [
         UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
         UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
     ]
+
+    ubatch_slices_padded = _pad_out_ubatch_slices(
+        ubatch_slices, num_tokens_padded, num_reqs_padded
+    )
+
+    assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
+
+    return ubatch_slices, ubatch_slices_padded
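For completeness, the updated unit test in the first hunk shows how a caller exercises the new helper with an explicit split point; a condensed version of that call (illustrative values, assuming a vLLM checkout that includes this commit) looks like:

```python
# Condensed from the updated test above; the concrete numbers are illustrative,
# and the import assumes this commit's vllm.v1.worker.ubatch_utils.
import numpy as np

from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices

num_scheduled_tokens = np.array([8, 8, 8, 8], dtype=np.int32)  # 4 requests
num_tokens = int(num_scheduled_tokens.sum())                   # 32 tokens, no extra padding here

ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
    True,                        # should_ubatch
    num_scheduled_tokens,
    num_tokens,                  # num_tokens_padded
    len(num_scheduled_tokens),   # num_reqs_padded
    split_point=16,              # force the split mid-batch, as the test does
)
assert ubatch_slices is not None and len(ubatch_slices) == 2
```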