diff --git a/tests/v1/core/test_priority_scheduler_random.py b/tests/v1/core/test_priority_scheduler_random.py
new file mode 100644
index 000000000000..b4805be80272
--- /dev/null
+++ b/tests/v1/core/test_priority_scheduler_random.py
@@ -0,0 +1,252 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import random
+import uuid
+
+import pytest
+
+from vllm.config import VllmConfig
+from vllm.multimodal.inputs import (
+    MultiModalFeatureSpec,
+    MultiModalKwargsItem,
+    PlaceholderRange,
+)
+from vllm.sampling_params import SamplingParams
+from vllm.utils.hashing import get_hash_fn_by_name
+from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
+from vllm.v1.request import Request
+
+from .test_scheduler import create_scheduler_with_priority
+from .utils import EOS_TOKEN_ID
+
+pytestmark = pytest.mark.cpu_test
+
+
+def _create_random_request(
+    max_tokens_range: tuple[int, int],
+    num_tokens_range: tuple[int, int],
+    arrival_time_range: tuple[float, float],
+    priority_range: tuple[int, int],
+    num_mm_item_range: tuple[int, int],
+    vllm_config: VllmConfig,
+):
+    max_tokens = random.randint(*max_tokens_range)
+    num_tokens = random.randint(*num_tokens_range)
+    priority = random.randint(*priority_range)
+    arrival_time = random.uniform(*arrival_time_range)
+    num_mm_item = random.randint(*num_mm_item_range)
+
+    mm_positions: list[PlaceholderRange] = []
+    for mm_start in sorted(
+        random.sample(range(num_tokens), min(num_mm_item, num_tokens))
+    ):
+        if mm_start + 10 > num_tokens:
+            continue
+        mm_positions.append(PlaceholderRange(offset=mm_start, length=10))
+
+    request_id = uuid.uuid4().hex
+
+    sampling_params = SamplingParams(
+        ignore_eos=False,
+        max_tokens=max_tokens,
+    )
+    mm_features = []
+    for j, position in enumerate(mm_positions):
+        identifier = f"{request_id}_hash_{j}"
+        mm_feature = MultiModalFeatureSpec(
+            data=MultiModalKwargsItem.dummy("dummy_m"),
+            mm_position=position,
+            identifier=identifier,
+            modality="image",
+        )
+        mm_features.append(mm_feature)
+
+    prompt_token_ids = random.choices(range(100), k=num_tokens)
+
+    caching_hash_fn = get_hash_fn_by_name(
+        vllm_config.cache_config.prefix_caching_hash_algo
+    )
+    init_none_hash(caching_hash_fn)
+    block_hasher = get_request_block_hasher(
+        vllm_config.cache_config.block_size, caching_hash_fn
+    )
+
+    request = Request(
+        request_id=request_id,
+        prompt_token_ids=prompt_token_ids,
+        sampling_params=sampling_params,
+        pooling_params=None,
+        mm_features=mm_features if mm_features else None,
+        eos_token_id=EOS_TOKEN_ID,
+        arrival_time=arrival_time,
+        priority=priority,
+        block_hasher=block_hasher,
+    )
+    return request
+
+
+def _mock_execute_model(
+    scheduler_output: SchedulerOutput, num_output_tokens_range: tuple[int, int]
+) -> ModelRunnerOutput:
+    request_ids: list[str] = []
+    request_ids.extend(req.req_id for req in scheduler_output.scheduled_new_reqs)
+    request_ids.extend(scheduler_output.scheduled_cached_reqs.req_ids)
+    random.shuffle(request_ids)
+
+    num_output_tokens = [
+        random.randint(*num_output_tokens_range) for _ in range(len(request_ids))
+    ]
+    sampled_token_ids = [
+        [random.randint(0, 100) for _ in range(num_tokens)]
+        for num_tokens in num_output_tokens
+    ]
+
+    return ModelRunnerOutput(
+        req_ids=request_ids,
+        req_id_to_index={req_id: i for i, req_id in enumerate(request_ids)},
+        sampled_token_ids=sampled_token_ids,
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+
+
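+# Drafter stand-in: propose random draft token IDs, but only for requests
+# whose prompt has already been fully computed (tracked per request in
+# seen_request_prompt_length).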
+def _mock_draft_token_ids(
+    scheduler_output: SchedulerOutput,
+    num_output_tokens_range: tuple[int, int],
+    seen_request_prompt_length: dict[str, int],
+) -> DraftTokenIds:
+    request_ids: list[str] = []
+    sampled_token_ids: list[list[int]] = []
+    for request in scheduler_output.scheduled_new_reqs:
+        assert request.req_id not in seen_request_prompt_length
+        seen_request_prompt_length[request.req_id] = len(request.prompt_token_ids or [])
+        if request.num_computed_tokens >= seen_request_prompt_length[request.req_id]:
+            num_tokens = random.randint(*num_output_tokens_range)
+            request_ids.append(request.req_id)
+            sampled_token_ids.append(
+                [random.randint(0, 100) for _ in range(num_tokens)]
+            )
+    for req_id, num_computed_tokens in zip(
+        scheduler_output.scheduled_cached_reqs.req_ids,
+        scheduler_output.scheduled_cached_reqs.num_computed_tokens,
+    ):
+        if num_computed_tokens >= seen_request_prompt_length[req_id]:
+            num_tokens = random.randint(*num_output_tokens_range)
+            request_ids.append(req_id)
+            sampled_token_ids.append(
+                [random.randint(0, 100) for _ in range(num_tokens)]
+            )
+    return DraftTokenIds(req_ids=request_ids, draft_token_ids=sampled_token_ids)
+
+
+def _check_valid_scheduler_output(
+    scheduler_output: SchedulerOutput,
+    seen_request_ids: set[str],
+    seen_mm_hashes: set[str],
+):
+    for req in scheduler_output.scheduled_new_reqs:
+        assert req.req_id not in seen_request_ids
+        seen_request_ids.add(req.req_id)
+    for req_id in scheduler_output.scheduled_cached_reqs.req_ids:
+        assert req_id in seen_request_ids
+
+    req_ids = set[str]()
+    req_ids.update(req.req_id for req in scheduler_output.scheduled_new_reqs)
+    req_ids.update(scheduler_output.scheduled_cached_reqs.req_ids)
+
+    assert set(scheduler_output.num_scheduled_tokens.keys()) == req_ids
+    assert (
+        sum(scheduler_output.num_scheduled_tokens.values())
+        == scheduler_output.total_num_scheduled_tokens
+    )
+
+    assert set(scheduler_output.scheduled_spec_decode_tokens.keys()) <= req_ids
+    assert set(scheduler_output.scheduled_encoder_inputs.keys()) <= req_ids
+
+    for req in scheduler_output.scheduled_new_reqs:
+        for mm_feature in req.mm_features:
+            seen_mm_hashes.add(mm_feature.identifier)
+    for mm_hash in scheduler_output.free_encoder_mm_hashes:
+        assert mm_hash in seen_mm_hashes
+
+    assert scheduler_output.finished_req_ids <= seen_request_ids
+
+
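+# Randomized stress test: feed the priority scheduler a stream of random,
+# optionally multimodal requests, drive it with mocked model outputs (and
+# draft tokens when speculative decoding is enabled), and validate every
+# SchedulerOutput for internal consistency.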
+@pytest.mark.parametrize("enable_prefix_caching", [True, False])
+@pytest.mark.parametrize("num_speculative_tokens", [None, 1, 5])
+@pytest.mark.parametrize(
+    ("max_input_tokens", "max_output_tokens", "max_num_seqs", "num_blocks"),
+    [
+        # Standard profile
+        (5000, 500, 256, 10000),
+        # Generation heavy + high max_num_seqs + low num_blocks -> Many preemptions
+        (500, 5000, 1024, 1000),
+    ],
+    ids=["standard", "preemption"],
+)
+def test_priority_scheduling_blast(
+    enable_prefix_caching: bool,
+    num_speculative_tokens: int | None,
+    max_input_tokens: int,
+    max_output_tokens: int,
+    max_num_seqs: int,
+    num_blocks: int,
+):
+    random.seed(42)
+    seen_request_prompt_length = dict[str, int]()
+    seen_request_ids = set[str]()
+    seen_mm_hashes = set[str]()
+
+    scheduler = create_scheduler_with_priority(
+        model="Qwen/Qwen2.5-VL-3B-Instruct",
+        max_num_seqs=max_num_seqs,
+        enable_prefix_caching=enable_prefix_caching,
+        num_blocks=num_blocks,
+        num_speculative_tokens=num_speculative_tokens,
+    )
+
+    num_initial_requests = 10
+    for _ in range(num_initial_requests):
+        req = _create_random_request(
+            max_tokens_range=(1, max_output_tokens),
+            num_tokens_range=(1, max_input_tokens),
+            arrival_time_range=(0, 1),
+            priority_range=(-3, 3),
+            num_mm_item_range=(0, 2),
+            vllm_config=scheduler.vllm_config,
+        )
+        scheduler.add_request(req)
+
+    for _ in range(20000):
+        if len(scheduler.waiting) == 0:
+            num_new_requests = random.randint(0, 2)
+            for _ in range(num_new_requests):
+                req = _create_random_request(
+                    max_tokens_range=(1, max_output_tokens),
+                    num_tokens_range=(1, max_input_tokens),
+                    arrival_time_range=(0, 1),
+                    priority_range=(-3, 3),
+                    num_mm_item_range=(0, 2),
+                    vllm_config=scheduler.vllm_config,
+                )
+                scheduler.add_request(req)
+        scheduler_output = scheduler.schedule()
+        _check_valid_scheduler_output(
+            scheduler_output, seen_request_ids, seen_mm_hashes
+        )
+        model_output = _mock_execute_model(
+            scheduler_output,
+            num_output_tokens_range=(1, 1 + (num_speculative_tokens or 0)),
+        )
+        scheduler.update_from_output(scheduler_output, model_output)
+        if num_speculative_tokens is not None:
+            scheduler.update_draft_token_ids(
+                _mock_draft_token_ids(
+                    scheduler_output,
+                    (0, num_speculative_tokens),
+                    seen_request_prompt_length,
+                )
+            )
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 8455746cd56d..4fcc7955df19 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -300,6 +300,20 @@ class Scheduler(SchedulerInterface):
                         ]
                         req_to_new_blocks.pop(preempted_req.request_id)
                         num_scheduled_tokens.pop(preempted_req.request_id)
+                        scheduled_spec_decode_tokens.pop(
+                            preempted_req.request_id, None
+                        )
+                        preempted_encoder_inputs = scheduled_encoder_inputs.pop(
+                            preempted_req.request_id, None
+                        )
+                        if preempted_encoder_inputs:
+                            # Restore encoder compute budget if the preempted
+                            # request had encoder inputs scheduled in this step.
+                            num_tokens_to_restore = sum(
+                                preempted_req.get_num_encoder_tokens(i)
+                                for i in preempted_encoder_inputs
+                            )
+                            encoder_compute_budget += num_tokens_to_restore
                         req_index -= 1
                 else:
                     preempted_req = self.running.pop()