mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 16:57:19 +08:00
Fix bucketing
This commit is contained in:
parent
b15db234ba
commit
57690a9c09
@ -598,6 +598,13 @@ class Scheduler:
|
||||
|
||||
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
|
||||
while self._passed_delay(time.time()) and waiting_queue:
|
||||
# FIXME(woosuk): The TPU backend only supports up to 4 sequence
|
||||
# groups in a single batch.
|
||||
MAX_BATCH_SIZE = 1
|
||||
if len(seq_groups) == MAX_BATCH_SIZE:
|
||||
break
|
||||
assert len(seq_groups) < MAX_BATCH_SIZE
|
||||
|
||||
seq_group = waiting_queue[0]
|
||||
|
||||
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
|
||||
@ -666,10 +673,6 @@ class Scheduler:
|
||||
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
|
||||
budget.add_num_seqs(seq_group.request_id, num_new_seqs)
|
||||
|
||||
# FIXME(woosuk): For TPUs, we want to schedule only one prompt
|
||||
# per scheduling step.
|
||||
break
|
||||
|
||||
# Queue requests that couldn't be scheduled.
|
||||
waiting_queue.extendleft(leftover_waiting_sequences)
|
||||
if len(seq_groups) > 0:
|
||||
|
||||
@ -12,11 +12,6 @@ from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.utils import pad_to_max_length
|
||||
|
||||
# DELETE
|
||||
from jax_smi import initialise_tracking
|
||||
|
||||
initialise_tracking()
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_PAD_SLOT_ID = -1
|
||||
@ -91,7 +86,7 @@ class TPUModelRunner:
|
||||
|
||||
# Decode
|
||||
start = time.time()
|
||||
for batch_size in [1, 2, 4] + [8 * i for i in range(1, 17)]:
|
||||
for batch_size in [1, 2, 4, 8] + [16 * i for i in range(1, 17)]:
|
||||
seq_len = 1
|
||||
token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
||||
position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
||||
@ -325,8 +320,10 @@ def _get_padded_batch_size(batch_size: int) -> int:
|
||||
return batch_size
|
||||
elif batch_size <= 4:
|
||||
return 4
|
||||
elif batch_size <= 8:
|
||||
return 8
|
||||
else:
|
||||
return ((batch_size + 7) // 8) * 8
|
||||
return ((batch_size + 15) // 16) * 16
|
||||
|
||||
|
||||
import functools
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user