[BugFix] Fix spec decoding max_tokens scheduling perf issue (#29542)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill committed 2025-11-28 04:52:23 -08:00 (via GitHub)
parent 953d9c820b
commit 8e7a891602
3 changed files with 28 additions and 38 deletions


@@ -43,7 +43,7 @@ class TestLogprobsLists(TestCase):
             cu_num_generated_tokens=None,
         )
-        sliced = logprobsLists.slice(1, 3)
+        sliced = logprobsLists.slice_request(1, num_positions=2)
         assert sliced.logprob_token_ids == [[2], [3]]
         assert sliced.logprobs == [[0.2], [0.3]]
         assert sliced.sampled_token_ranks == [2, 3]
@@ -51,7 +51,7 @@ class TestLogprobsLists(TestCase):
 
     def test_slice_from_start(self):
         """Test slicing from the start position"""
-        sliced = self.logprobsLists.slice(0, 2)
+        sliced = self.logprobsLists.slice_request(0, num_positions=5)
         assert len(sliced.logprob_token_ids) == 5
         assert sliced.logprob_token_ids == [
             [1, 2],
@@ -60,11 +60,11 @@ class TestLogprobsLists(TestCase):
             [7, 8],
             [9, 10],
         ]
-        assert sliced.cu_num_generated_tokens == [0, 2, 5]
+        assert sliced.cu_num_generated_tokens is None
 
     def test_slice_from_middle(self):
         """Test slicing from the middle position"""
-        sliced = self.logprobsLists.slice(1, 3)
+        sliced = self.logprobsLists.slice_request(1, num_positions=7)
         assert len(sliced.logprob_token_ids) == 7
         assert sliced.logprob_token_ids == [
             [5, 6],
@@ -75,27 +75,25 @@ class TestLogprobsLists(TestCase):
             [15, 16],
             [17, 18],
         ]
-        assert sliced.cu_num_generated_tokens == [0, 3, 7]
+        assert sliced.cu_num_generated_tokens is None
 
     def test_slice_single_request(self):
         """Test slicing a single request"""
-        sliced = self.logprobsLists.slice(1, 2)
+        sliced = self.logprobsLists.slice_request(1, num_positions=3)
         assert len(sliced.logprob_token_ids) == 3
         assert sliced.logprob_token_ids == [[5, 6], [7, 8], [9, 10]]
-        assert sliced.cu_num_generated_tokens == [0, 3]
+        assert sliced.cu_num_generated_tokens is None
 
     def test_slice_last_request(self):
         """Test slicing the last request"""
-        sliced = self.logprobsLists.slice(2, 3)
+        sliced = self.logprobsLists.slice_request(2, num_positions=4)
         assert len(sliced.logprob_token_ids) == 4
         assert sliced.logprob_token_ids == [[11, 12], [13, 14], [15, 16], [17, 18]]
-        assert sliced.cu_num_generated_tokens == [0, 4]
+        assert sliced.cu_num_generated_tokens is None
 
     def test_slice_all_requests(self):
         """Test slicing all requests (full slice)"""
-        sliced = self.logprobsLists.slice(0, 3)
+        sliced = self.logprobsLists.slice_request(0, num_positions=9)
         assert len(sliced.logprob_token_ids) == 9  # All tokens
         assert sliced.logprob_token_ids == self.logprobsLists.logprob_token_ids
-        assert (
-            sliced.cu_num_generated_tokens == self.logprobsLists.cu_num_generated_tokens
-        )
+        assert sliced.cu_num_generated_tokens is None
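Taken together, the test updates capture the interface change: slice(start_req_idx, end_req_idx) took a half-open range of request indices and rebuilt a zero-based cumulative array for the slice, while slice_request(req_idx, num_positions) takes one starting request plus an explicit count of flat token positions and returns no cumulative array at all. A minimal sketch of the new call pattern, assuming the setUp fixture holds 3 requests with 2, 3, and 4 generated tokens (the logprob/rank fillers below are placeholders inferred from the asserts, not the real fixture):

from vllm.v1.outputs import LogprobsLists

lists = LogprobsLists(
    logprob_token_ids=[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
                       [11, 12], [13, 14], [15, 16], [17, 18]],
    logprobs=[[0.1, 0.2] for _ in range(9)],   # placeholder values
    sampled_token_ranks=list(range(9)),        # placeholder values
    cu_num_generated_tokens=[0, 2, 5, 9],      # 2 + 3 + 4 = 9 tokens
)

# Start at request 1 (flat offset cu[1] == 2) and take 3 positions.
sliced = lists.slice_request(1, num_positions=3)
assert sliced.logprob_token_ids == [[5, 6], [7, 8], [9, 10]]
assert sliced.cu_num_generated_tokens is None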


@@ -234,11 +234,15 @@ class Scheduler(SchedulerInterface):
                 num_new_tokens = self.scheduler_config.long_prefill_token_threshold
             num_new_tokens = min(num_new_tokens, token_budget)
 
-            # Make sure the input position does not exceed the max model len or
-            # request's max_tokens.
-            # This is necessary when using spec decoding and/or async scheduling.
+            num_spec_placeholders = max(0, request.num_output_placeholders - 1)
             max_total_tokens = min(
-                request.num_prompt_tokens + request.max_tokens, self.max_model_len
+                # Avoid scheduling tokens that we're sure won't be needed based
+                # on request.max_tokens. For this calculation we assume
+                # placeholder speculated output tokens are rejected.
+                request.num_prompt_tokens + request.max_tokens + num_spec_placeholders,
+                # Make sure the input position does not exceed the max model len.
+                # This is necessary when using spec decoding.
+                self.max_model_len,
             )
             num_new_tokens = min(
                 num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens
@@ -1089,7 +1093,7 @@ class Scheduler(SchedulerInterface):
                 and request.sampling_params.logprobs is not None
                 and logprobs
             ):
-                new_logprobs = logprobs.slice(req_index, req_index + 1)
+                new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))
 
             if new_token_ids and self.structured_output_manager.should_advance(request):
                 struct_output_request = request.structured_output_request
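The first hunk is the core of the perf fix. With async scheduling, request.num_computed_tokens already counts num_output_placeholders in-flight tokens, all but one of which are unverified speculative outputs; bounding the schedule by num_prompt_tokens + max_tokens alone therefore trimmed speculative tokens as a request approached its max_tokens limit, even when rejected placeholders would have kept it under that limit. The new bound adds the speculative placeholders back, per the comment's assumption that they are rejected. A standalone sketch of the arithmetic (clamp_num_new_tokens and all numbers are hypothetical, for illustration only):

def clamp_num_new_tokens(
    num_new_tokens: int,
    num_computed_tokens: int,
    num_prompt_tokens: int,
    max_tokens: int,
    num_output_placeholders: int,
    max_model_len: int,
) -> int:
    # All but one in-flight placeholder are speculative and may be
    # rejected, so they should not tighten the max_tokens bound.
    num_spec_placeholders = max(0, num_output_placeholders - 1)
    max_total_tokens = min(
        num_prompt_tokens + max_tokens + num_spec_placeholders,
        max_model_len,
    )
    return min(num_new_tokens, max_total_tokens - 1 - num_computed_tokens)

# 100-token prompt, max_tokens=32, 4 placeholders in flight (1 sampled
# + 3 speculative), 131 tokens counted as computed, max_model_len=8192.
# Old bound: min(132, 8192) - 1 - 131 = 0  -> nothing schedulable.
# New bound: min(135, 8192) - 1 - 131 = 3  -> spec decoding can proceed.
assert clamp_num_new_tokens(4, 131, 100, 32, 4, 8192) == 3

The second hunk updates the logprobs call site to the renamed method, passing len(new_token_ids) so the slice length matches the tokens actually emitted for this request in this step.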


@@ -29,27 +29,15 @@ class LogprobsLists(NamedTuple):
     # different for each request.
     cu_num_generated_tokens: list[int] | None = None
 
-    def slice(self, start_req_idx: int, end_req_idx: int):
-        if self.cu_num_generated_tokens:
-            start = self.cu_num_generated_tokens[start_req_idx]
-            end = self.cu_num_generated_tokens[end_req_idx]
-            # Recompute cumulative array starting from 0
-            cu_num_offset = self.cu_num_generated_tokens[start_req_idx]
-            sliced_cu_num_generated_tokens = [
-                cu_num - cu_num_offset
-                for cu_num in self.cu_num_generated_tokens[
-                    start_req_idx : end_req_idx + 1
-                ]
-            ]
-        else:
-            start = start_req_idx
-            end = end_req_idx
-            sliced_cu_num_generated_tokens = None
+    def slice_request(self, req_idx: int, num_positions: int):
+        if self.cu_num_generated_tokens is not None:
+            req_idx = self.cu_num_generated_tokens[req_idx]
+        end_idx = req_idx + num_positions
         return LogprobsLists(
-            self.logprob_token_ids[start:end],
-            self.logprobs[start:end],
-            self.sampled_token_ranks[start:end],
-            sliced_cu_num_generated_tokens,
+            self.logprob_token_ids[req_idx:end_idx],
+            self.logprobs[req_idx:end_idx],
+            self.sampled_token_ranks[req_idx:end_idx],
+            None,
         )
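Dropping the cumulative-array rebuild is where the per-step savings come from: the old slice ran a list comprehension over cu_num_generated_tokens for every logprobs-enabled request on every scheduler step, only for the caller to extract a single request. The replacement is one lookup and one addition. A small before/after sketch with assumed values:

cu_num_generated_tokens = [0, 2, 5, 9]  # 3 requests: 2, 3, 4 tokens each

# Old slice(1, 2): flat bounds from two cu lookups, plus a rebuilt
# zero-based cumulative array for the slice.
offset = cu_num_generated_tokens[1]
rebuilt = [c - offset for c in cu_num_generated_tokens[1:3]]  # [0, 3]

# New slice_request(1, num_positions=3): one cu lookup and an addition;
# the caller already knows the length (len(new_token_ids)), so no
# cumulative array is carried on the sliced result.
start = cu_num_generated_tokens[1]
end = start + 3
assert (start, end) == (2, 5) and rebuilt == [0, 3]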