Fix bucketing

This commit is contained in:
Woosuk Kwon 2024-04-26 07:05:27 +00:00
parent b15db234ba
commit 57690a9c09
2 changed files with 11 additions and 11 deletions

View File

@ -598,6 +598,13 @@ class Scheduler:
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue:
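# FIXME(woosuk): The TPU backend only supports up to 4 sequence
# groups in a single batch.
# NOTE(review): the cap enforced below is MAX_BATCH_SIZE = 1, not 4 —
# confirm which limit is intended and align the comment and constant.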
MAX_BATCH_SIZE = 1
if len(seq_groups) == MAX_BATCH_SIZE:
break
assert len(seq_groups) < MAX_BATCH_SIZE
seq_group = waiting_queue[0]
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
@ -666,10 +673,6 @@ class Scheduler:
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
budget.add_num_seqs(seq_group.request_id, num_new_seqs)
# FIXME(woosuk): For TPUs, we want to schedule only one prompt
# per scheduling step.
break
# Queue requests that couldn't be scheduled.
waiting_queue.extendleft(leftover_waiting_sequences)
if len(seq_groups) > 0:

View File

@ -12,11 +12,6 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import pad_to_max_length
# DELETE
from jax_smi import initialise_tracking
initialise_tracking()
logger = init_logger(__name__)
_PAD_SLOT_ID = -1
@ -91,7 +86,7 @@ class TPUModelRunner:
# Decode
start = time.time()
for batch_size in [1, 2, 4] + [8 * i for i in range(1, 17)]:
for batch_size in [1, 2, 4, 8] + [16 * i for i in range(1, 17)]:
seq_len = 1
token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
@ -325,8 +320,10 @@ def _get_padded_batch_size(batch_size: int) -> int:
return batch_size
elif batch_size <= 4:
return 4
elif batch_size <= 8:
return 8
else:
return ((batch_size + 7) // 8) * 8
return ((batch_size + 15) // 16) * 16
import functools