mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-03 06:17:52 +08:00
Fix bucketing
This commit is contained in:
parent
b15db234ba
commit
57690a9c09
@ -598,6 +598,13 @@ class Scheduler:
|
|||||||
|
|
||||||
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
|
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
|
||||||
while self._passed_delay(time.time()) and waiting_queue:
|
while self._passed_delay(time.time()) and waiting_queue:
|
||||||
|
# FIXME(woosuk): The TPU backend only supports up to 4 sequence
|
||||||
|
# groups in a single batch.
|
||||||
|
MAX_BATCH_SIZE = 1
|
||||||
|
if len(seq_groups) == MAX_BATCH_SIZE:
|
||||||
|
break
|
||||||
|
assert len(seq_groups) < MAX_BATCH_SIZE
|
||||||
|
|
||||||
seq_group = waiting_queue[0]
|
seq_group = waiting_queue[0]
|
||||||
|
|
||||||
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
|
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
|
||||||
@ -666,10 +673,6 @@ class Scheduler:
|
|||||||
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
|
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
|
||||||
budget.add_num_seqs(seq_group.request_id, num_new_seqs)
|
budget.add_num_seqs(seq_group.request_id, num_new_seqs)
|
||||||
|
|
||||||
# FIXME(woosuk): For TPUs, we want to schedule only one prompt
|
|
||||||
# per scheduling step.
|
|
||||||
break
|
|
||||||
|
|
||||||
# Queue requests that couldn't be scheduled.
|
# Queue requests that couldn't be scheduled.
|
||||||
waiting_queue.extendleft(leftover_waiting_sequences)
|
waiting_queue.extendleft(leftover_waiting_sequences)
|
||||||
if len(seq_groups) > 0:
|
if len(seq_groups) > 0:
|
||||||
|
|||||||
@ -12,11 +12,6 @@ from vllm.sampling_params import SamplingParams
|
|||||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||||
from vllm.utils import pad_to_max_length
|
from vllm.utils import pad_to_max_length
|
||||||
|
|
||||||
# DELETE
|
|
||||||
from jax_smi import initialise_tracking
|
|
||||||
|
|
||||||
initialise_tracking()
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
_PAD_SLOT_ID = -1
|
_PAD_SLOT_ID = -1
|
||||||
@ -91,7 +86,7 @@ class TPUModelRunner:
|
|||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
start = time.time()
|
start = time.time()
|
||||||
for batch_size in [1, 2, 4] + [8 * i for i in range(1, 17)]:
|
for batch_size in [1, 2, 4, 8] + [16 * i for i in range(1, 17)]:
|
||||||
seq_len = 1
|
seq_len = 1
|
||||||
token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
||||||
position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
||||||
@ -325,8 +320,10 @@ def _get_padded_batch_size(batch_size: int) -> int:
|
|||||||
return batch_size
|
return batch_size
|
||||||
elif batch_size <= 4:
|
elif batch_size <= 4:
|
||||||
return 4
|
return 4
|
||||||
|
elif batch_size <= 8:
|
||||||
|
return 8
|
||||||
else:
|
else:
|
||||||
return ((batch_size + 7) // 8) * 8
|
return ((batch_size + 15) // 16) * 16
|
||||||
|
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user