diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 78cd8a179076a..0e162d2b921d5 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -598,6 +598,13 @@ class Scheduler: leftover_waiting_sequences: Deque[SequenceGroup] = deque() while self._passed_delay(time.time()) and waiting_queue: + # FIXME(woosuk): The TPU backend only supports up to 4 sequence + # groups in a single batch. + MAX_BATCH_SIZE = 1 + if len(seq_groups) == MAX_BATCH_SIZE: + break + assert len(seq_groups) < MAX_BATCH_SIZE + seq_group = waiting_queue[0] waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) @@ -666,10 +673,6 @@ class Scheduler: budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) budget.add_num_seqs(seq_group.request_id, num_new_seqs) - # FIXME(woosuk): For TPUs, we want to schedule only one prompt - # per scheduling step. - break - # Queue requests that couldn't be scheduled. waiting_queue.extendleft(leftover_waiting_sequences) if len(seq_groups) > 0: diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 2bd64005de5cb..73376070b4b3f 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -12,11 +12,6 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import pad_to_max_length -# DELETE -from jax_smi import initialise_tracking - -initialise_tracking() - logger = init_logger(__name__) _PAD_SLOT_ID = -1 @@ -91,7 +86,7 @@ class TPUModelRunner: # Decode start = time.time() - for batch_size in [1, 2, 4] + [8 * i for i in range(1, 17)]: + for batch_size in [1, 2, 4, 8] + [16 * i for i in range(1, 17)]: seq_len = 1 token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32) position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32) @@ -325,8 +320,10 @@ def _get_padded_batch_size(batch_size: int) -> int: return batch_size elif batch_size <= 4: return 4 + elif batch_size <= 8: + return 8 else: - return ((batch_size + 7) // 8) * 8 + return ((batch_size + 15) // 16) * 16 import functools