diff --git a/requirements/common.txt b/requirements/common.txt
index f97fe35d28b30..526ed514ac03a 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -7,7 +7,7 @@ requests >= 2.26.0
 tqdm
 blake3
 py-cpuinfo
-transformers >= 4.51.1
+transformers >= 4.53.2
 huggingface-hub[hf_xet] >= 0.33.0  # Required for Xet downloads.
 tokenizers >= 0.21.1  # Required for fast incremental detokenization.
 protobuf # Required by LlamaTokenizer.
diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 02d2c83ab1584..2d3657b334ba8 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -451,6 +451,7 @@ def test_stop_via_update_from_output():
     req.num_computed_tokens = req.num_tokens
     scheduler.requests[req.request_id] = req
     scheduler.running.append(req)
+    req.status = RequestStatus.RUNNING
 
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
@@ -504,6 +505,7 @@ def test_stop_via_update_from_output():
     req.num_computed_tokens = req.num_tokens
     scheduler.requests[req.request_id] = req
     scheduler.running.append(req)
+    req.status = RequestStatus.RUNNING
 
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
@@ -556,6 +558,7 @@ def test_stop_via_update_from_output():
     req.num_computed_tokens = req.num_tokens
     scheduler.requests[req.request_id] = req
     scheduler.running.append(req)
+    req.status = RequestStatus.RUNNING
 
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
@@ -703,6 +706,65 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
     scheduler.update_from_output(scheduler_output1, model_runner_output)
 
 
+def test_preempt_during_execution():
+    # NOTE(woosuk): The actual number of available blocks is 10 instead of 11
+    # because block 0 is reserved as the null block.
+    scheduler = create_scheduler(max_num_batched_tokens=100,
+                                 block_size=16,
+                                 num_blocks=11,
+                                 enable_prefix_caching=False)
+    requests = create_requests(num_requests=2, num_tokens=80)
+
+    # Schedule the first request.
+    scheduler.add_request(requests[0])
+    scheduler_output0 = scheduler.schedule()
+    assert len(scheduler_output0.num_scheduled_tokens) == 1
+    assert len(scheduler_output0.scheduled_new_reqs[0].block_ids[0]) == 5
+
+    # Schedule the second request while the first request is still running.
+    # This scenario can occur in certain cases, when max_concurrent_batches > 1
+    # (e.g., when pipeline parallelism is used).
+    scheduler.add_request(requests[1])
+    scheduler_output1 = scheduler.schedule()
+    assert len(scheduler_output1.num_scheduled_tokens) == 1
+    assert len(scheduler_output1.scheduled_new_reqs[0].block_ids[0]) == 5
+
+    # Get the output of the first request.
+    model_runner_output0 = ModelRunnerOutput(
+        req_ids=[requests[0].request_id],
+        req_id_to_index={requests[0].request_id: 0},
+        sampled_token_ids=[[0]],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    scheduler.update_from_output(scheduler_output0, model_runner_output0)
+
+    # Schedule the first request again. This will cause the preemption
+    # of the second request because the KV cache is full.
+    _ = scheduler.schedule()
+    assert len(scheduler.running) == 1
+    assert scheduler.running[0] == requests[0]
+    assert requests[1].status == RequestStatus.PREEMPTED
+
+    model_runner_output1 = ModelRunnerOutput(
+        req_ids=[requests[1].request_id],
+        req_id_to_index={requests[1].request_id: 0},
+        sampled_token_ids=[[42]],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    scheduler.update_from_output(scheduler_output1, model_runner_output1)
+
+    # The second request (that is preempted) should be updated with the
+    # sampled token id.
+    assert len(requests[1].output_token_ids) == 1
+    assert requests[1].output_token_ids[0] == 42
+
+
 # Note - these test cases mirror some of those in test_rejection_sampler.py
 @pytest.mark.parametrize(
     "spec_tokens,output_tokens,expected",
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index b2d90614c294c..f81bb9fc13a4c 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -747,19 +747,21 @@ class Scheduler(SchedulerInterface):
         pooler_outputs = model_runner_output.pooler_output
         num_nans_in_logits = model_runner_output.num_nans_in_logits
 
-        new_running: list[Request] = []
         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
         spec_decoding_stats: Optional[SpecDecodingStats] = None
 
-        # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
-        # loop can be a performance bottleneck. We should do our best to avoid
-        # expensive operations inside the loop.
-        for request in self.running:
-            req_id = request.request_id
-            num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
-            if num_tokens_scheduled == 0:
-                # The request was not scheduled in this step.
-                new_running.append(request)
+        # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
+        # the below loop can be a performance bottleneck. We should do our best
+        # to avoid expensive operations inside the loop.
+        stopped_running_reqs: set[Request] = set()
+        stopped_preempted_reqs: set[Request] = set()
+        for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
+            assert num_tokens_scheduled > 0
+            request = self.requests.get(req_id)
+            if request is None:
+                # The request is already finished. This can happen if the
+                # request is aborted while the model is executing it (e.g.,
+                # in pipeline parallelism).
                 continue
 
             req_index = model_runner_output.req_id_to_index[req_id]
@@ -792,6 +794,7 @@ class Scheduler(SchedulerInterface):
             new_logprobs = None
             new_token_ids = generated_token_ids
             kv_transfer_params = None
+            status_before_stop = request.status
 
             # Append generated tokens and check for stop. Note that if
             # a request is still being prefilled, we expect the model runner
@@ -803,17 +806,22 @@ class Scheduler(SchedulerInterface):
                 # This must be called before we make the EngineCoreOutput.
                 stopped = check_stop(request, self.max_model_len)
                 if stopped:
-                    kv_transfer_params = self._free_request(request)
                     del new_token_ids[num_new:]  # Trim new tokens if needed.
                     break
 
+            # Stop checking for pooler models.
             pooler_output = None
             if pooler_outputs:
                 pooler_output = pooler_outputs[req_index]
                 stopped = check_stop(request, self.max_model_len,
                                      pooler_output)
-                if stopped:
-                    kv_transfer_params = self._free_request(request)
+
+            if stopped:
+                kv_transfer_params = self._free_request(request)
+                if status_before_stop == RequestStatus.RUNNING:
+                    stopped_running_reqs.add(request)
+                else:
+                    stopped_preempted_reqs.add(request)
 
             # Extract sample logprobs if needed.
             if request.sampling_params is not None \
@@ -868,9 +876,14 @@ class Scheduler(SchedulerInterface):
                 # Invariant: EngineCore returns no partial prefill outputs.
                 assert not prompt_logprobs_tensors
 
-            if not stopped:
-                new_running.append(request)
-        self.running = new_running
+        # Remove the stopped requests from the running and waiting queues.
+        if stopped_running_reqs:
+            self.running = [
+                req for req in self.running if req not in stopped_running_reqs
+            ]
+        if stopped_preempted_reqs:
+            # This is a rare case and unlikely to impact performance.
+            self.waiting.remove_requests(stopped_preempted_reqs)
 
         # KV Connector: update state for finished KV Transfers.
         self._update_from_kv_xfer_finished(model_runner_output)
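
For reviewers, a minimal standalone sketch (not vLLM code) of the bookkeeping this patch introduces in `update_from_output`: each stopped request is classified by the status it held *before* `check_stop` marked it finished, and the two resulting sets are pruned from the running list and the waiting queue in one pass instead of rebuilding `self.running` on every iteration. The `Req` class, `Status` enum, and `prune_stopped` helper below are simplified stand-ins for illustration only, not vLLM's `Request`/`RequestQueue` classes.

```python
# Illustrative sketch only: Req, Status, and the plain lists here are toy
# stand-ins, not vLLM's Request / RequestQueue implementations.
from enum import Enum, auto


class Status(Enum):
    RUNNING = auto()
    PREEMPTED = auto()


class Req:
    def __init__(self, req_id: str, status: Status) -> None:
        self.req_id = req_id
        self.status = status


def prune_stopped(running: list[Req], waiting: list[Req],
                  stopped: list[tuple[Req, Status]]) -> None:
    """Drop stopped requests from whichever queue still holds them,
    keyed on the status each request had before it was stopped."""
    stopped_running = {r for r, prev in stopped if prev is Status.RUNNING}
    stopped_preempted = {r for r, prev in stopped if prev is Status.PREEMPTED}
    # A request that was RUNNING when it stopped still sits in `running`.
    running[:] = [r for r in running if r not in stopped_running]
    # A request that had been preempted sits in the waiting queue instead.
    waiting[:] = [r for r in waiting if r not in stopped_preempted]


# Example: one request finishes while running, another finishes after it was
# preempted (e.g., its last sampled token arrived from an in-flight batch).
a, b = Req("a", Status.RUNNING), Req("b", Status.PREEMPTED)
running, waiting = [a], [b]
prune_stopped(running, waiting, [(a, Status.RUNNING), (b, Status.PREEMPTED)])
assert running == [] and waiting == []
```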