# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import deque

import pytest

from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import RequestStatus

from .utils import create_requests, create_scheduler


def _make_model_runner_output(
    scheduler_output: SchedulerOutput,
) -> ModelRunnerOutput:
    req_ids = list(scheduler_output.num_scheduled_tokens.keys())
    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index={
            req_id: i
            for i, req_id in enumerate(req_ids)
        },
        sampled_token_ids=[[i] for i in range(len(req_ids))],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )


@pytest.mark.parametrize("max_tokens", [1, 2, 3, 5])
def test_stop_by_max_tokens(max_tokens: int):
    scheduler = create_scheduler(async_scheduling=True)
    requests = create_requests(num_requests=2, max_tokens=max_tokens)
    req0, req1 = requests

    # With async scheduling, the next step is scheduled before the output of
    # the previous step is applied, so two steps are kept in flight.
    sched_outputs: deque[SchedulerOutput] = deque()
    scheduler.add_request(req0)
    sched_outputs.append(scheduler.schedule())

    scheduler.add_request(req1)
    sched_outputs.append(scheduler.schedule())

    while sched_outputs:
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)

        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    assert scheduler.get_num_unfinished_requests() == 0
    assert req0.num_output_tokens == max_tokens
    assert req1.num_output_tokens == max_tokens


def test_abort():
    scheduler = create_scheduler(async_scheduling=True)
    requests = create_requests(num_requests=10, max_tokens=20)

    for req in requests:
        scheduler.add_request(req)
    # Keep two scheduler steps in flight, as the async scheduler does.
    sched_outputs: deque[SchedulerOutput] = deque()
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())

    abort_order = [0, 8, 3, 1, 6, 4, 2, 5, 7, 9]
    abort_order_copy = abort_order.copy()

    def abort_request():
        if not abort_order:
            return
        req = requests[abort_order.pop(0)]
        scheduler.finish_requests(req.request_id,
                                  RequestStatus.FINISHED_ABORTED)

    while sched_outputs:
        # Abort a scheduled request.
        abort_request()
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)
        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    for i, req in enumerate(requests):
        assert req.status == RequestStatus.FINISHED_ABORTED
        assert req.num_output_tokens == abort_order_copy.index(i)


def test_preempt():
    scheduler = create_scheduler(async_scheduling=True)
    requests = create_requests(num_requests=10, max_tokens=20)

    for req in requests:
        scheduler.add_request(req)
    sched_outputs: deque[SchedulerOutput] = deque()
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())

    abort_order = [0, 8, 3, 1, 6, 4, 2, 5, 7, 9]
    abort_order_copy = abort_order.copy()

    def abort_request():
        if not abort_order:
            return
        req = requests[abort_order.pop(0)]
        scheduler.finish_requests(req.request_id,
                                  RequestStatus.FINISHED_ABORTED)

    while sched_outputs:
        # Abort a scheduled request.
        abort_request()
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)
        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    for i, req in enumerate(requests):
        assert req.status == RequestStatus.FINISHED_ABORTED
        assert req.num_output_tokens == abort_order_copy.index(i)


def test_prefix_caching_for_prefill_dedup():
    CHUNK_SIZE = 1000
    BLOCK_SIZE = 16
    num_prompt_tokens = 100
    scheduler = create_scheduler(async_scheduling=True,
                                 max_num_batched_tokens=CHUNK_SIZE,
                                 enable_prefix_caching=True,
                                 block_size=BLOCK_SIZE)
    requests = create_requests(num_requests=5,
                               num_tokens=num_prompt_tokens,
                               max_tokens=3,
                               same_prompt=True)
    requests_copy = requests.copy()

    # Two requests with the same prompt.
    req0 = requests.pop(0)
    req1 = requests.pop(0)
    scheduler.add_request(req0)
    scheduler.add_request(req1)
    sched_outputs: deque[SchedulerOutput] = deque()
    sched_output = scheduler.schedule()
    sched_outputs.append(sched_output)

    # Make sure prefix caching de-duplicates the prompts in the same step,
    # so all the blocks except the last are shared between the two requests.
    assert len(sched_output.num_scheduled_tokens) == 2
    num_blocks = num_prompt_tokens // BLOCK_SIZE
    assert req0.num_cached_tokens == 0
    assert req1.num_cached_tokens >= num_blocks * BLOCK_SIZE

    sched_outputs.append(scheduler.schedule())
    while sched_outputs:
        if requests:
            scheduler.add_request(requests.pop(0))
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)
        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    # Other requests scheduled after the two requests should also get
    # prefix cache hit.
    assert scheduler.get_num_unfinished_requests() == 0
    for req in requests_copy[1:]:
        assert req.num_cached_tokens >= num_blocks * BLOCK_SIZE


def test_prefix_caching_for_multi_turn():
    CHUNK_SIZE = 1000
    BLOCK_SIZE = 16
    num_prompt_tokens = 100
    num_output_tokens = 200
    scheduler = create_scheduler(async_scheduling=True,
                                 max_num_batched_tokens=CHUNK_SIZE,
                                 enable_prefix_caching=True,
                                 block_size=BLOCK_SIZE)
    requests = create_requests(num_requests=5,
                               num_tokens=num_prompt_tokens,
                               max_tokens=num_output_tokens)

    for req in requests:
        scheduler.add_request(req)
    sched_outputs: deque[SchedulerOutput] = deque()
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())

    # Process the requests.
    while sched_outputs:
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)
        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)
    assert scheduler.get_num_unfinished_requests() == 0

    # Create next-turn requests whose prompts are the full output of the
    # previous turn.
    next_turn_requests = create_requests(
        num_requests=5,
        num_tokens=num_prompt_tokens + num_output_tokens,
        max_tokens=num_output_tokens,
    )
    for i, req in enumerate(next_turn_requests):
        req.prompt_token_ids = (requests[i].prompt_token_ids +
                                list(requests[i].output_token_ids))
    # Schedule the next-turn requests.
    for req in next_turn_requests:
        scheduler.add_request(req)
    sched_outputs.append(scheduler.schedule())

    # Make sure the next-turn requests get prefix cache hit by the previous
    # requests.
    for req in next_turn_requests:
        assert (req.num_cached_tokens ==
                req.num_prompt_tokens // BLOCK_SIZE * BLOCK_SIZE)