mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 21:55:50 +08:00
[BugFix] Fix spec decoding max_tokens scheduling perf issue (#29542)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
953d9c820b
commit
8e7a891602
@ -43,7 +43,7 @@ class TestLogprobsLists(TestCase):
|
||||
cu_num_generated_tokens=None,
|
||||
)
|
||||
|
||||
sliced = logprobsLists.slice(1, 3)
|
||||
sliced = logprobsLists.slice_request(1, num_positions=2)
|
||||
assert sliced.logprob_token_ids == [[2], [3]]
|
||||
assert sliced.logprobs == [[0.2], [0.3]]
|
||||
assert sliced.sampled_token_ranks == [2, 3]
|
||||
@ -51,7 +51,7 @@ class TestLogprobsLists(TestCase):
|
||||
|
||||
def test_slice_from_start(self):
|
||||
"""Test slicing from the start position"""
|
||||
sliced = self.logprobsLists.slice(0, 2)
|
||||
sliced = self.logprobsLists.slice_request(0, num_positions=5)
|
||||
assert len(sliced.logprob_token_ids) == 5
|
||||
assert sliced.logprob_token_ids == [
|
||||
[1, 2],
|
||||
@ -60,11 +60,11 @@ class TestLogprobsLists(TestCase):
|
||||
[7, 8],
|
||||
[9, 10],
|
||||
]
|
||||
assert sliced.cu_num_generated_tokens == [0, 2, 5]
|
||||
assert sliced.cu_num_generated_tokens is None
|
||||
|
||||
def test_slice_from_middle(self):
|
||||
"""Test slicing from the middle position"""
|
||||
sliced = self.logprobsLists.slice(1, 3)
|
||||
sliced = self.logprobsLists.slice_request(1, num_positions=7)
|
||||
assert len(sliced.logprob_token_ids) == 7
|
||||
assert sliced.logprob_token_ids == [
|
||||
[5, 6],
|
||||
@ -75,27 +75,25 @@ class TestLogprobsLists(TestCase):
|
||||
[15, 16],
|
||||
[17, 18],
|
||||
]
|
||||
assert sliced.cu_num_generated_tokens == [0, 3, 7]
|
||||
assert sliced.cu_num_generated_tokens is None
|
||||
|
||||
def test_slice_single_request(self):
|
||||
"""Test slicing a single request"""
|
||||
sliced = self.logprobsLists.slice(1, 2)
|
||||
sliced = self.logprobsLists.slice_request(1, num_positions=3)
|
||||
assert len(sliced.logprob_token_ids) == 3
|
||||
assert sliced.logprob_token_ids == [[5, 6], [7, 8], [9, 10]]
|
||||
assert sliced.cu_num_generated_tokens == [0, 3]
|
||||
assert sliced.cu_num_generated_tokens is None
|
||||
|
||||
def test_slice_last_request(self):
|
||||
"""Test slicing the last request"""
|
||||
sliced = self.logprobsLists.slice(2, 3)
|
||||
sliced = self.logprobsLists.slice_request(2, num_positions=4)
|
||||
assert len(sliced.logprob_token_ids) == 4
|
||||
assert sliced.logprob_token_ids == [[11, 12], [13, 14], [15, 16], [17, 18]]
|
||||
assert sliced.cu_num_generated_tokens == [0, 4]
|
||||
assert sliced.cu_num_generated_tokens is None
|
||||
|
||||
def test_slice_all_requests(self):
|
||||
"""Test slicing all requests (full slice)"""
|
||||
sliced = self.logprobsLists.slice(0, 3)
|
||||
sliced = self.logprobsLists.slice_request(0, num_positions=9)
|
||||
assert len(sliced.logprob_token_ids) == 9 # All tokens
|
||||
assert sliced.logprob_token_ids == self.logprobsLists.logprob_token_ids
|
||||
assert (
|
||||
sliced.cu_num_generated_tokens == self.logprobsLists.cu_num_generated_tokens
|
||||
)
|
||||
assert sliced.cu_num_generated_tokens is None
|
||||
|
||||
@ -234,11 +234,15 @@ class Scheduler(SchedulerInterface):
|
||||
num_new_tokens = self.scheduler_config.long_prefill_token_threshold
|
||||
num_new_tokens = min(num_new_tokens, token_budget)
|
||||
|
||||
# Make sure the input position does not exceed the max model len or
|
||||
# request's max_tokens.
|
||||
# This is necessary when using spec decoding and/or async scheduling.
|
||||
num_spec_placeholders = max(0, request.num_output_placeholders - 1)
|
||||
max_total_tokens = min(
|
||||
request.num_prompt_tokens + request.max_tokens, self.max_model_len
|
||||
# Avoid scheduling tokens that we're sure won't will be needed based on
|
||||
# request.max_tokens. For this calculation we assume placeholder
|
||||
# speculated output tokens are rejected.
|
||||
request.num_prompt_tokens + request.max_tokens + num_spec_placeholders,
|
||||
# Make sure the input position does not exceed the max model len.
|
||||
# This is necessary when using spec decoding.
|
||||
self.max_model_len,
|
||||
)
|
||||
num_new_tokens = min(
|
||||
num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens
|
||||
@ -1089,7 +1093,7 @@ class Scheduler(SchedulerInterface):
|
||||
and request.sampling_params.logprobs is not None
|
||||
and logprobs
|
||||
):
|
||||
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
||||
new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))
|
||||
|
||||
if new_token_ids and self.structured_output_manager.should_advance(request):
|
||||
struct_output_request = request.structured_output_request
|
||||
|
||||
@ -29,27 +29,15 @@ class LogprobsLists(NamedTuple):
|
||||
# different for each request.
|
||||
cu_num_generated_tokens: list[int] | None = None
|
||||
|
||||
def slice(self, start_req_idx: int, end_req_idx: int):
|
||||
if self.cu_num_generated_tokens:
|
||||
start = self.cu_num_generated_tokens[start_req_idx]
|
||||
end = self.cu_num_generated_tokens[end_req_idx]
|
||||
# Recompute cumulative array starting from 0
|
||||
cu_num_offset = self.cu_num_generated_tokens[start_req_idx]
|
||||
sliced_cu_num_generated_tokens = [
|
||||
cu_num - cu_num_offset
|
||||
for cu_num in self.cu_num_generated_tokens[
|
||||
start_req_idx : end_req_idx + 1
|
||||
]
|
||||
]
|
||||
else:
|
||||
start = start_req_idx
|
||||
end = end_req_idx
|
||||
sliced_cu_num_generated_tokens = None
|
||||
def slice_request(self, req_idx: int, num_positions: int):
|
||||
if self.cu_num_generated_tokens is not None:
|
||||
req_idx = self.cu_num_generated_tokens[req_idx]
|
||||
end_idx = req_idx + num_positions
|
||||
return LogprobsLists(
|
||||
self.logprob_token_ids[start:end],
|
||||
self.logprobs[start:end],
|
||||
self.sampled_token_ranks[start:end],
|
||||
sliced_cu_num_generated_tokens,
|
||||
self.logprob_token_ids[req_idx:end_idx],
|
||||
self.logprobs[req_idx:end_idx],
|
||||
self.sampled_token_ranks[req_idx:end_idx],
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user