[BugFix] Handle unscheduled requests properly when async scheduling (#27756)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Parent: b5bae42f91
Commit: 2ce5c5d3d6
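
This change replaces the parallel resumed_from_preemption bool list (and resumed_req_token_ids) on CachedRequestData with a resumed_req_ids set plus an all_token_ids dict, and has the scheduler remember which requests it scheduled in the previous step so that requests left unscheduled under async scheduling still get their token ids propagated. A minimal sketch of the consumer-side migration, using a simplified stand-in class rather than the real CachedRequestData:

# Illustrative stand-in only: shows the old index-based check next to the
# new set-membership check that the diff below switches consumers to.
from dataclasses import dataclass


@dataclass
class CachedRequestDataSketch:
    req_ids: list[str]
    resumed_req_ids: set[str]              # replaces resumed_from_preemption: list[bool]
    all_token_ids: dict[str, list[int]]    # replaces resumed_req_token_ids


def consume(cached: CachedRequestDataSketch) -> None:
    for req_id in cached.req_ids:
        # Old style: resumed = cached.resumed_from_preemption[i]
        resumed = req_id in cached.resumed_req_ids
        print(req_id, "resumed" if resumed else "running", cached.all_token_ids.get(req_id))


consume(
    CachedRequestDataSketch(
        req_ids=["a", "b"],
        resumed_req_ids={"b"},
        all_token_ids={"b": [1, 2, 3]},
    )
)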
@@ -212,10 +212,12 @@ def test_update_states_request_resumed(model_runner):
     # resume req
     cached_req_data = CachedRequestData(
         req_ids=[req_id],
-        resumed_from_preemption=[False],
+        resumed_req_ids={req_id},
         new_token_ids=[[]],
+        all_token_ids={req_id: scheduler_output.scheduled_new_reqs[0].prompt_token_ids},
         new_block_ids=[([],)],
         num_computed_tokens=[0],
+        num_output_tokens=[0],
     )

     scheduler_output = SchedulerOutput(
@@ -259,10 +259,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
     # resume req
     cached_req_data = CachedRequestData(
         req_ids=[req_id],
-        resumed_from_preemption=[False],
+        resumed_req_ids=set(),
         new_token_ids=[[]],
-        resumed_req_token_ids=[None],
-        new_block_ids=([[0]],),
+        all_token_ids={},
+        new_block_ids=[([0],)],
         num_computed_tokens=[0],
         num_output_tokens=[0],
     )
@@ -494,5 +494,5 @@ def yield_req_data(
     yield from zip(
         cached_reqs.req_ids,
         cached_reqs.new_block_ids,
-        cached_reqs.resumed_from_preemption,
+        (req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
     )
@@ -415,10 +415,10 @@ class P2pNcclConnector(KVConnectorBase_V1):
         for i, req_id in enumerate(cached_reqs.req_ids):
             num_computed_tokens = cached_reqs.num_computed_tokens[i]
             new_block_ids = cached_reqs.new_block_ids[i]
-            resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
+            resumed_from_preemption = req_id in cached_reqs.resumed_req_ids

             if self.is_producer:
-                num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[req_id]
+                num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
                 num_tokens = num_scheduled_tokens + num_computed_tokens
                 assert req_id in self.chunked_prefill
                 assert new_block_ids is not None
@@ -336,7 +336,7 @@ class SharedStorageConnector(KVConnectorBase_V1):

         cached_reqs = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(cached_reqs.req_ids):
-            resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
+            resumed_from_preemption = req_id in cached_reqs.resumed_req_ids
             if not resumed_from_preemption or req_id not in self._requests_need_load:
                 continue

@@ -2,8 +2,11 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from dataclasses import dataclass
+from functools import cached_property
 from typing import TYPE_CHECKING

+from typing_extensions import deprecated
+
 from vllm._bc_linter import bc_linter_include

 if TYPE_CHECKING:
@@ -96,16 +99,16 @@ class NewRequestData:
 @dataclass
 class CachedRequestData:
     req_ids: list[str]
-    # If resumed_from_preemption is False, new_block_ids will be appended to
-    # the request's block IDs. If True, new_block_ids will be used as the
+    # For request ids not in resumed_req_ids, new_block_ids will be appended to
+    # the request's block IDs. For those in the set, new_block_ids will be used as the
     # request's block IDs instead of appending to the existing block IDs.
-    resumed_from_preemption: list[bool]
+    resumed_req_ids: set[str]
     # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
     # When PP is not used, new_token_ids will be empty.
     new_token_ids: list[list[int]]
-    # If resumed_from_preemption is True, propogate the token ids to the
-    # connector, otherwise will be empty.
-    resumed_req_token_ids: list[list[int] | None]
+    # For requests not scheduled in the last step, propagate the token ids to the
+    # connector. Won't contain requests that were scheduled in the prior step.
+    all_token_ids: dict[str, list[int]]
     new_block_ids: list[tuple[list[int], ...] | None]
     num_computed_tokens: list[int]
     num_output_tokens: list[int]
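
The comments above describe the new block-ID contract: for request ids not in resumed_req_ids, the new_block_ids are appended to the request's existing block IDs, while for ids in the set they replace them. A small sketch of that rule with plain lists standing in for the per-KV-cache-group tuples (names here are illustrative, not vLLM APIs):

# Sketch of the append-vs-replace rule, assuming a single flat list of
# block IDs per request instead of the real tuple-of-lists layout.
def apply_new_blocks(
    block_ids: dict[str, list[int]],
    req_id: str,
    new_block_ids: list[int] | None,
    resumed_req_ids: set[str],
) -> None:
    if new_block_ids is None:
        return
    if req_id in resumed_req_ids:
        # Resumed after preemption: the new IDs replace the old ones.
        block_ids[req_id] = list(new_block_ids)
    else:
        # Still running: the new IDs extend the existing ones.
        block_ids.setdefault(req_id, []).extend(new_block_ids)


blocks = {"a": [0, 1]}
apply_new_blocks(blocks, "a", [2], resumed_req_ids=set())     # -> {"a": [0, 1, 2]}
apply_new_blocks(blocks, "b", [7, 8], resumed_req_ids={"b"})  # -> adds "b": [7, 8]
print(blocks)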
@@ -114,13 +117,26 @@ class CachedRequestData:
     def num_reqs(self) -> int:
         return len(self.req_ids)

+    @cached_property
+    @deprecated("use resumed_req_ids field")
+    def resumed_from_preemption(self) -> list[bool]:
+        return [req_id in self.resumed_req_ids for req_id in self.req_ids]
+
+    @cached_property
+    @deprecated("use all_token_ids field")
+    def resumed_req_token_ids(self) -> list[list[int] | None]:
+        return [
+            self.all_token_ids[req_id] if req_id in self.resumed_req_ids else None
+            for req_id in self.req_ids
+        ]
+
     @classmethod
     def make_empty(cls) -> "CachedRequestData":
         return cls(
             req_ids=[],
-            resumed_from_preemption=[],
+            resumed_req_ids=set(),
             new_token_ids=[],
-            resumed_req_token_ids=[],
+            all_token_ids={},
             new_block_ids=[],
             num_computed_tokens=[],
             num_output_tokens=[],
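
The two @cached_property shims above keep the removed fields readable for existing consumers: accessing .resumed_from_preemption or .resumed_req_token_ids now rebuilds the old list views from resumed_req_ids and all_token_ids. A usage sketch with a local stand-in class (the @deprecated marker is omitted here for brevity):

# Stand-in illustrating the backward-compat properties; not the real class.
from dataclasses import dataclass
from functools import cached_property


@dataclass
class CompatSketch:
    req_ids: list[str]
    resumed_req_ids: set[str]
    all_token_ids: dict[str, list[int]]

    @cached_property
    def resumed_from_preemption(self) -> list[bool]:
        return [req_id in self.resumed_req_ids for req_id in self.req_ids]

    @cached_property
    def resumed_req_token_ids(self) -> list[list[int] | None]:
        return [
            self.all_token_ids[req_id] if req_id in self.resumed_req_ids else None
            for req_id in self.req_ids
        ]


data = CompatSketch(req_ids=["a", "b"], resumed_req_ids={"b"}, all_token_ids={"b": [1, 2]})
print(data.resumed_from_preemption)  # [False, True]
print(data.resumed_req_token_ids)    # [None, [1, 2]]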
@@ -71,6 +71,7 @@ class Scheduler(SchedulerInterface):
         self.finished_req_ids_dict: dict[int, set[str]] | None = (
             defaultdict(set) if include_finished_set else None
         )
+        self.prev_step_scheduled_req_ids: set[str] = set()

         # Scheduling constraints.
         self.max_num_running_reqs = self.scheduler_config.max_num_seqs
@@ -444,14 +445,9 @@ class Scheduler(SchedulerInterface):
                 # `request.num_prompt_tokens` to consider the resumed
                 # requests, which have output tokens.
                 num_new_tokens = request.num_tokens - num_computed_tokens
-                if (
-                    0
-                    < self.scheduler_config.long_prefill_token_threshold
-                    < num_new_tokens
-                ):
-                    num_new_tokens = (
-                        self.scheduler_config.long_prefill_token_threshold
-                    )
+                threshold = self.scheduler_config.long_prefill_token_threshold
+                if 0 < threshold < num_new_tokens:
+                    num_new_tokens = threshold

                 # chunked prefill has to be enabled explicitly to allow
                 # pooling requests to be chunked
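
The refactor above only hoists the long_prefill_token_threshold lookup into a local; the clamp itself is unchanged. For illustration, the chained comparison treats a threshold of 0 as disabled and otherwise caps num_new_tokens (clamp_new_tokens is a made-up helper name):

# Illustration of the clamp semantics; clamp_new_tokens is hypothetical.
def clamp_new_tokens(num_new_tokens: int, threshold: int) -> int:
    if 0 < threshold < num_new_tokens:
        num_new_tokens = threshold
    return num_new_tokens


assert clamp_new_tokens(500, 0) == 500    # threshold 0 disables clamping
assert clamp_new_tokens(500, 128) == 128  # long prefill is capped
assert clamp_new_tokens(64, 128) == 64    # short prefill is untouched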
@@ -620,6 +616,11 @@ class Scheduler(SchedulerInterface):
         structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
             num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
         )
+
+        # Record the request ids that were scheduled in this step.
+        self.prev_step_scheduled_req_ids.clear()
+        self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
+
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=new_reqs_data,
             scheduled_cached_reqs=cached_reqs_data,
@@ -691,14 +692,12 @@ class Scheduler(SchedulerInterface):
         req_ids: list[str] = []
         new_token_ids: list[list[int]] = []
         new_block_ids: list[tuple[list[int], ...] | None] = []
-        resumed_req_token_ids: list[list[int] | None] = []
+        all_token_ids: dict[str, list[int]] = {}
         num_computed_tokens: list[int] = []
         num_output_tokens: list[int] = []
+        resumed_req_ids = set()

-        # Because resumed_reqs is usually empty, it is more efficient to do
-        # in-place appending so that we don't need to allocate a new list.
-        resumed_from_preemption = [False] * len(running_reqs)
-        resumed_from_preemption += [True] * len(resumed_reqs)
+        num_running_reqs = len(running_reqs)
         for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
             req_id = req.request_id
             req_ids.append(req_id)
@@ -715,12 +714,14 @@ class Scheduler(SchedulerInterface):
                     req.num_computed_tokens : req.num_computed_tokens + num_tokens
                 ]
                 new_token_ids.append(token_ids)
-            resumed_token_ids = None
-            if resumed_from_preemption[idx]:
-                resumed_token_ids = req.all_token_ids[
+            scheduled_in_prev_step = req_id in self.prev_step_scheduled_req_ids
+            if idx >= num_running_reqs:
+                assert not scheduled_in_prev_step
+                resumed_req_ids.add(req_id)
+            if not scheduled_in_prev_step:
+                all_token_ids[req_id] = req.all_token_ids[
                     : req.num_computed_tokens + num_tokens
                 ]
-            resumed_req_token_ids.append(resumed_token_ids)
             new_block_ids.append(
                 req_to_new_blocks[req_id].get_block_ids(allow_none=True)
             )
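
Taken together, the scheduler changes above record prev_step_scheduled_req_ids after each step and then, while building CachedRequestData, mark only the preempted-and-resumed requests in resumed_req_ids while shipping full token ids for any request that was not scheduled in the previous step (which, under async scheduling, can include requests that were never preempted). A self-contained sketch of that bookkeeping with toy request objects (names are illustrative):

# Toy reconstruction of the bookkeeping: requests missing from the previous
# step's schedule get their token ids propagated via all_token_ids, and only
# truly resumed (preempted) requests land in resumed_req_ids.
import itertools
from dataclasses import dataclass


@dataclass
class ToyRequest:
    request_id: str
    all_token_ids: list[int]
    num_computed_tokens: int


def build_cached_metadata(
    running_reqs: list[ToyRequest],
    resumed_reqs: list[ToyRequest],
    prev_step_scheduled_req_ids: set[str],
    num_scheduled: dict[str, int],
) -> tuple[set[str], dict[str, list[int]]]:
    resumed_req_ids: set[str] = set()
    all_token_ids: dict[str, list[int]] = {}
    num_running_reqs = len(running_reqs)
    for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
        req_id = req.request_id
        num_tokens = num_scheduled[req_id]
        scheduled_in_prev_step = req_id in prev_step_scheduled_req_ids
        if idx >= num_running_reqs:
            # Entries past the running list are the preempted-and-resumed ones.
            resumed_req_ids.add(req_id)
        if not scheduled_in_prev_step:
            # Anything not scheduled last step needs its token ids propagated.
            all_token_ids[req_id] = req.all_token_ids[: req.num_computed_tokens + num_tokens]
    return resumed_req_ids, all_token_ids


running = [ToyRequest("a", [1, 2, 3, 4], 3), ToyRequest("b", [5, 6], 1)]
resumed = [ToyRequest("c", [7, 8, 9], 0)]
ids, tokens = build_cached_metadata(
    running, resumed, prev_step_scheduled_req_ids={"a"}, num_scheduled={"a": 1, "b": 1, "c": 2}
)
print(ids)     # {'c'}
print(tokens)  # {'b': [5, 6], 'c': [7, 8]}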
@@ -731,9 +732,9 @@ class Scheduler(SchedulerInterface):

         return CachedRequestData(
             req_ids=req_ids,
-            resumed_from_preemption=resumed_from_preemption,
+            resumed_req_ids=resumed_req_ids,
             new_token_ids=new_token_ids,
-            resumed_req_token_ids=resumed_req_token_ids,
+            all_token_ids=all_token_ids,
             new_block_ids=new_block_ids,
             num_computed_tokens=num_computed_tokens,
             num_output_tokens=num_output_tokens,
@@ -706,7 +706,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
-            resumed_from_preemption = req_data.resumed_from_preemption[i]
+            resumed_from_preemption = req_id in req_data.resumed_req_ids
            num_output_tokens = req_data.num_output_tokens[i]

             # Update the cached states.
@@ -754,16 +754,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 # Replace the existing block IDs with the new ones.
                 req_state.block_ids = new_block_ids

-            if self.use_async_scheduling and num_output_tokens > 0:
-                # We must recover the output token ids for resumed requests in the
-                # async scheduling case, so that correct input_ids are obtained.
-                resumed_token_ids = req_data.resumed_req_token_ids[i]
-                assert resumed_token_ids is not None
-                req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]
             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
                 # scheduled in the previous step and needs to be added again.
+
+                if self.use_async_scheduling and num_output_tokens > 0:
+                    # We must recover the output token ids for resumed requests in the
+                    # async scheduling case, so that correct input_ids are obtained.
+                    resumed_token_ids = req_data.all_token_ids[req_id]
+                    req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]
+
                 reqs_to_add.append(req_state)
                 continue

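
The relocated async-scheduling branch above now pulls the propagated tokens from all_token_ids (keyed by request id) instead of the per-index resumed_req_token_ids, and only when the request has to be re-added to the persistent batch. A minimal sketch of the recovery step (recover_output_tokens is a hypothetical helper, with simplified inputs):

# Made-up helper showing the slice used above: the last num_output_tokens
# entries of the propagated token list become the request's
# output_token_ids again.
def recover_output_tokens(
    all_token_ids: dict[str, list[int]],
    req_id: str,
    num_output_tokens: int,
    use_async_scheduling: bool,
) -> list[int]:
    if use_async_scheduling and num_output_tokens > 0:
        return all_token_ids[req_id][-num_output_tokens:]
    return []


tokens = {"req-0": [10, 11, 12, 13, 14]}  # prompt tokens followed by generated tokens
print(recover_output_tokens(tokens, "req-0", num_output_tokens=2, use_async_scheduling=True))
# -> [13, 14]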
@@ -483,7 +483,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
-            resumed_from_preemption = req_data.resumed_from_preemption[i]
+            resumed_from_preemption = req_id in req_data.resumed_req_ids

             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens