[AsyncScheduling] Don't schedule past request max_tokens (#27922)

Signed-off-by: Nick Hill <nhill@redhat.com>

Parent: c9f66da8fd
Commit: 938a81692e
@@ -34,15 +34,20 @@ def test_stop_by_max_tokens(max_tokens: int):
     requests = create_requests(num_requests=2, max_tokens=max_tokens)
     req0, req1 = requests

+    expected_total_num_scheduled_tokens = 0
     sched_outputs: deque[SchedulerOutput] = deque()
     scheduler.add_request(req0)
     sched_outputs.append(scheduler.schedule())
+    expected_total_num_scheduled_tokens += req0.num_prompt_tokens + max_tokens - 1

     scheduler.add_request(req1)
     sched_outputs.append(scheduler.schedule())
+    expected_total_num_scheduled_tokens += req1.num_prompt_tokens + max_tokens - 1

+    total_num_scheduled_tokens = 0
     while sched_outputs:
         sched_output = sched_outputs.popleft()
+        total_num_scheduled_tokens += sched_output.total_num_scheduled_tokens
         model_runner_output = _make_model_runner_output(sched_output)
         scheduler.update_from_output(sched_output, model_runner_output)

@@ -53,6 +58,8 @@ def test_stop_by_max_tokens(max_tokens: int):
     assert scheduler.get_num_unfinished_requests() == 0
     assert req0.num_output_tokens == max_tokens
     assert req1.num_output_tokens == max_tokens
+    # Ensure we aren't scheduling more tokens than necessary.
+    assert total_num_scheduled_tokens == expected_total_num_scheduled_tokens


 def test_abort():
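As a rough mental model of the accounting this test now asserts (the helper below is an illustrative sketch, not part of vLLM): the prefill pass schedules every prompt token and yields the first output token, and each of the remaining max_tokens - 1 outputs costs one more scheduled decode token, so a request should account for num_prompt_tokens + max_tokens - 1 scheduled tokens in total.

# Illustrative sketch only (not vLLM code): expected scheduled-token count per
# request under the accounting used in the test above.
def expected_scheduled_tokens(num_prompt_tokens: int, max_tokens: int) -> int:
    # Prefill schedules all prompt tokens and produces the first output token;
    # each of the remaining max_tokens - 1 outputs needs one more decode token.
    return num_prompt_tokens + max_tokens - 1


if __name__ == "__main__":
    # e.g. a 16-token prompt with max_tokens=4 should schedule 16 + 3 = 19 tokens.
    assert expected_scheduled_tokens(16, 4) == 19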
@@ -155,7 +155,6 @@ def test_suffix_decoding_acceptance(
     )

     # Run several times and check that the accepted tokens increase.
-    spec_llm.chat(test_prompts, sampling_config)
     num_draft = []
     num_accept = []
     for i in range(10):  # Run multiple times to warm up the cache.
@@ -217,10 +217,14 @@ class Scheduler(SchedulerInterface):
                 num_new_tokens = self.scheduler_config.long_prefill_token_threshold
             num_new_tokens = min(num_new_tokens, token_budget)

-            # Make sure the input position does not exceed the max model len.
-            # This is necessary when using spec decoding.
+            # Make sure the input position does not exceed the max model len or
+            # request's max_tokens.
+            # This is necessary when using spec decoding and/or async scheduling.
+            max_total_tokens = min(
+                request.num_prompt_tokens + request.max_tokens, self.max_model_len
+            )
             num_new_tokens = min(
-                num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
+                num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens
             )

             # Schedule encoder inputs.
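A minimal standalone sketch of the new clamp (the function and parameter names below are illustrative, not the Scheduler's API): the number of new tokens scheduled for a request is now bounded by both the model's context length and the request's own prompt-plus-max_tokens budget, so async scheduling cannot queue work past the point where the request will stop.

# Illustrative sketch only: clamp the number of new tokens so the input position
# never passes max_model_len or the request's max_tokens budget.
def clamp_num_new_tokens(
    num_new_tokens: int,
    num_prompt_tokens: int,
    num_computed_tokens: int,
    max_tokens: int,
    max_model_len: int,
) -> int:
    # Total tokens the request may ever occupy: its prompt plus its output
    # budget, but never more than the model's context length.
    max_total_tokens = min(num_prompt_tokens + max_tokens, max_model_len)
    # Leave room for the final sampled token (hence the -1), mirroring the diff above.
    return min(num_new_tokens, max_total_tokens - 1 - num_computed_tokens)


# Example: a request with a 16-token prompt, max_tokens=4, and 18 tokens already
# computed can be scheduled at most (16 + 4) - 1 - 18 = 1 more token, even if
# the step's token budget would otherwise allow more.
assert clamp_num_new_tokens(8, 16, 18, 4, 2048) == 1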